From 10bdf8212c50fb7f1c831b11fabb3206a8447860 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannes=20H=C3=B6ke?= Date: Tue, 25 Oct 2016 19:28:34 +0200 Subject: [PATCH] Add pass_user_data and pass_chat_data to Handler (#436) * initial commit for user_data * add chat_data and use defaultdict * fix chat_data copy-paste error * add test for user_data and chat_data * fix case where chat is None * remove braces from import line --- examples/conversationbot2.py | 152 ++++++++++++++++++++++ telegram/ext/callbackqueryhandler.py | 20 ++- telegram/ext/choseninlineresulthandler.py | 23 +++- telegram/ext/commandhandler.py | 20 ++- telegram/ext/conversationhandler.py | 25 +--- telegram/ext/dispatcher.py | 6 + telegram/ext/handler.py | 29 ++++- telegram/ext/inlinequeryhandler.py | 20 ++- telegram/ext/messagehandler.py | 20 ++- telegram/ext/regexhandler.py | 22 +++- telegram/utils/helpers.py | 44 +++++++ tests/test_updater.py | 21 ++- 12 files changed, 357 insertions(+), 45 deletions(-) create mode 100644 examples/conversationbot2.py create mode 100644 telegram/utils/helpers.py diff --git a/examples/conversationbot2.py b/examples/conversationbot2.py new file mode 100644 index 000000000..3b82dfb16 --- /dev/null +++ b/examples/conversationbot2.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Simple Bot to reply to Telegram messages +# This program is dedicated to the public domain under the CC0 license. +""" +This Bot uses the Updater class to handle the bot. + +First, a few callback functions are defined. Then, those functions are passed to +the Dispatcher and registered at their respective places. +Then, the bot is started and runs until we press Ctrl-C on the command line. + +Usage: +Example of a bot-user conversation using ConversationHandler. +Send /start to initiate the conversation. +Press Ctrl-C on the command line or send a signal to the process to stop the +bot. +""" + +from telegram import ReplyKeyboardMarkup +from telegram.ext import (Updater, CommandHandler, MessageHandler, Filters, RegexHandler, + ConversationHandler) + +import logging + +# Enable logging +logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + level=logging.INFO) + +logger = logging.getLogger(__name__) + +CHOOSING, TYPING_REPLY, TYPING_CHOICE = range(3) + +reply_keyboard = [['Age', 'Favourite colour'], + ['Number of siblings', 'Something else...'], + ['Done']] +markup = ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True) + + +def facts_to_str(user_data): + facts = list() + + for key, value in user_data.items(): + facts.append('%s - %s' % (key, value)) + + return "\n".join(facts).join(['\n', '\n']) + + +def start(bot, update): + update.message.reply_text( + "Hi! My name is Doctor Botter. I will hold a more complex conversation with you. " + "Why don't you tell me something about yourself?", + reply_markup=markup) + + return CHOOSING + + +def regular_choice(bot, update, user_data): + text = update.message.text + user_data['choice'] = text + update.message.reply_text('Your %s? Yes, I would love to hear about that!' % text.lower()) + + return TYPING_REPLY + + +def custom_choice(bot, update): + update.message.reply_text('Alright, please send me the category first, ' + 'for example "Most impressive skill"') + + return TYPING_CHOICE + + +def received_information(bot, update, user_data): + text = update.message.text + category = user_data['choice'] + user_data[category] = text + del user_data['choice'] + + update.message.reply_text("Neat! Just so you know, this is what you already told me:" + "%s" + "You can tell me more, or change your opinion on something." + % facts_to_str(user_data), + reply_markup=markup) + + return CHOOSING + + +def done(bot, update, user_data): + if 'choice' in user_data: + del user_data['choice'] + + update.message.reply_text("I learned these facts about you:" + "%s" + "Until next time!" % facts_to_str(user_data)) + + user_data.clear() + return ConversationHandler.END + + +def error(bot, update, error): + logger.warn('Update "%s" caused error "%s"' % (update, error)) + + +def main(): + # Create the Updater and pass it your bot's token. + updater = Updater("TOKEN") + + # Get the dispatcher to register handlers + dp = updater.dispatcher + + # Add conversation handler with the states GENDER, PHOTO, LOCATION and BIO + conv_handler = ConversationHandler( + entry_points=[CommandHandler('start', start)], + + states={ + CHOOSING: [RegexHandler('^(Age|Favourite colour|Number of siblings)$', + regular_choice, + pass_user_data=True), + RegexHandler('^Something else...$', + custom_choice), + ], + + TYPING_CHOICE: [MessageHandler([Filters.text], + regular_choice, + pass_user_data=True), + ], + + TYPING_REPLY: [MessageHandler([Filters.text], + received_information, + pass_user_data=True), + ], + }, + + fallbacks=[RegexHandler('^Done$', done, pass_user_data=True)] + ) + + dp.add_handler(conv_handler) + + # log all errors + dp.add_error_handler(error) + + # Start the Bot + updater.start_polling() + + # Run the bot until the you presses Ctrl-C or the process receives SIGINT, + # SIGTERM or SIGABRT. This should be used most of the time, since + # start_polling() is non-blocking and will stop the bot gracefully. + updater.idle() + + +if __name__ == '__main__': + main() diff --git a/telegram/ext/callbackqueryhandler.py b/telegram/ext/callbackqueryhandler.py index 6e2caf38c..850bacc3a 100644 --- a/telegram/ext/callbackqueryhandler.py +++ b/telegram/ext/callbackqueryhandler.py @@ -52,6 +52,14 @@ class CallbackQueryHandler(Handler): pass_groupdict (optional[bool]): If the callback should be passed the result of ``re.match(pattern, data).groupdict()`` as a keyword argument called ``groupdict``. Default is ``False`` + pass_user_data (optional[bool]): If set to ``True``, a keyword argument called + ``user_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the user that sent the update. For each update of + the same user, it will be the same ``dict``. Default is ``False``. + pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called + ``chat_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the chat that the update was sent in. + For each update in the same chat, it will be the same ``dict``. Default is ``False``. """ def __init__(self, @@ -60,9 +68,15 @@ class CallbackQueryHandler(Handler): pass_job_queue=False, pattern=None, pass_groups=False, - pass_groupdict=False): + pass_groupdict=False, + pass_user_data=False, + pass_chat_data=False): super(CallbackQueryHandler, self).__init__( - callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue) + callback, + pass_update_queue=pass_update_queue, + pass_job_queue=pass_job_queue, + pass_user_data=pass_user_data, + pass_chat_data=pass_chat_data) if isinstance(pattern, string_types): pattern = re.compile(pattern) @@ -81,7 +95,7 @@ class CallbackQueryHandler(Handler): return True def handle_update(self, update, dispatcher): - optional_args = self.collect_optional_args(dispatcher) + optional_args = self.collect_optional_args(dispatcher, update) if self.pattern: match = re.match(self.pattern, update.callback_query.data) diff --git a/telegram/ext/choseninlineresulthandler.py b/telegram/ext/choseninlineresulthandler.py index e3a673133..4c9771955 100644 --- a/telegram/ext/choseninlineresulthandler.py +++ b/telegram/ext/choseninlineresulthandler.py @@ -40,17 +40,34 @@ class ChosenInlineResultHandler(Handler): ``job_queue`` will be passed to the callback function. It will be a ``JobQueue`` instance created by the ``Updater`` which can be used to schedule new jobs. Default is ``False``. + pass_user_data (optional[bool]): If set to ``True``, a keyword argument called + ``user_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the user that sent the update. For each update of + the same user, it will be the same ``dict``. Default is ``False``. + pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called + ``chat_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the chat that the update was sent in. + For each update in the same chat, it will be the same ``dict``. Default is ``False``. """ - def __init__(self, callback, pass_update_queue=False, pass_job_queue=False): + def __init__(self, + callback, + pass_update_queue=False, + pass_job_queue=False, + pass_user_data=False, + pass_chat_data=False): super(ChosenInlineResultHandler, self).__init__( - callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue) + callback, + pass_update_queue=pass_update_queue, + pass_job_queue=pass_job_queue, + pass_user_data=pass_user_data, + pass_chat_data=pass_chat_data) def check_update(self, update): return isinstance(update, Update) and update.chosen_inline_result def handle_update(self, update, dispatcher): - optional_args = self.collect_optional_args(dispatcher) + optional_args = self.collect_optional_args(dispatcher, update) return self.callback(dispatcher.bot, update, **optional_args) diff --git a/telegram/ext/commandhandler.py b/telegram/ext/commandhandler.py index 1b1e1e74e..28dc767cc 100644 --- a/telegram/ext/commandhandler.py +++ b/telegram/ext/commandhandler.py @@ -49,6 +49,14 @@ class CommandHandler(Handler): ``job_queue`` will be passed to the callback function. It will be a ``JobQueue`` instance created by the ``Updater`` which can be used to schedule new jobs. Default is ``False``. + pass_user_data (optional[bool]): If set to ``True``, a keyword argument called + ``user_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the user that sent the update. For each update of + the same user, it will be the same ``dict``. Default is ``False``. + pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called + ``chat_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the chat that the update was sent in. + For each update in the same chat, it will be the same ``dict``. Default is ``False``. """ def __init__(self, @@ -57,9 +65,15 @@ class CommandHandler(Handler): allow_edited=False, pass_args=False, pass_update_queue=False, - pass_job_queue=False): + pass_job_queue=False, + pass_user_data=False, + pass_chat_data=False): super(CommandHandler, self).__init__( - callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue) + callback, + pass_update_queue=pass_update_queue, + pass_job_queue=pass_job_queue, + pass_user_data=pass_user_data, + pass_chat_data=pass_chat_data) self.command = command self.allow_edited = allow_edited self.pass_args = pass_args @@ -76,7 +90,7 @@ class CommandHandler(Handler): return False def handle_update(self, update, dispatcher): - optional_args = self.collect_optional_args(dispatcher) + optional_args = self.collect_optional_args(dispatcher, update) message = update.message or update.edited_message diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index 6afa90575..387ae296f 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -22,6 +22,7 @@ import logging from telegram import Update from telegram.ext import Handler +from telegram.utils.helpers import extract_chat_and_user from telegram.utils.promise import Promise @@ -115,29 +116,7 @@ class ConversationHandler(Handler): if not isinstance(update, Update): return False - user = None - chat = None - - if update.message: - user = update.message.from_user - chat = update.message.chat - - elif update.edited_message: - user = update.edited_message.from_user - chat = update.edited_message.chat - - elif update.inline_query: - user = update.inline_query.from_user - - elif update.chosen_inline_result: - user = update.chosen_inline_result.from_user - - elif update.callback_query: - user = update.callback_query.from_user - chat = update.callback_query.message.chat if update.callback_query.message else None - - else: - return False + chat, user = extract_chat_and_user(update) key = (chat.id, user.id) if chat else (None, user.id) state = self.conversations.get(key) diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py index bd31975a6..445a4e171 100644 --- a/telegram/ext/dispatcher.py +++ b/telegram/ext/dispatcher.py @@ -24,6 +24,7 @@ from functools import wraps from threading import Thread, Lock, Event, current_thread, BoundedSemaphore from time import sleep from uuid import uuid4 +from collections import defaultdict from queue import Queue, Empty @@ -86,6 +87,11 @@ class Dispatcher(object): self.job_queue = job_queue self.workers = workers + self.user_data = defaultdict(dict) + """:type: dict[int, dict]""" + self.chat_data = defaultdict(dict) + """:type: dict[int, dict]""" + self.handlers = {} """:type: dict[int, list[Handler]""" self.groups = [] diff --git a/telegram/ext/handler.py b/telegram/ext/handler.py index c38a4668b..d1bb0d05f 100644 --- a/telegram/ext/handler.py +++ b/telegram/ext/handler.py @@ -20,6 +20,7 @@ Dispatcher """ from telegram.utils.deprecate import deprecate +from telegram.utils.helpers import extract_chat_and_user class Handler(object): @@ -39,12 +40,27 @@ class Handler(object): ``job_queue`` will be passed to the callback function. It will be a ``JobQueue`` instance created by the ``Updater`` which can be used to schedule new jobs. Default is ``False``. + pass_user_data (optional[bool]): If set to ``True``, a keyword argument called + ``user_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the user that sent the update. For each update of + the same user, it will be the same ``dict``. Default is ``False``. + pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called + ``chat_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the chat that the update was sent in. + For each update in the same chat, it will be the same ``dict``. Default is ``False``. """ - def __init__(self, callback, pass_update_queue=False, pass_job_queue=False): + def __init__(self, + callback, + pass_update_queue=False, + pass_job_queue=False, + pass_user_data=False, + pass_chat_data=False): self.callback = callback self.pass_update_queue = pass_update_queue self.pass_job_queue = pass_job_queue + self.pass_user_data = pass_user_data + self.pass_chat_data = pass_chat_data def check_update(self, update): """ @@ -74,7 +90,7 @@ class Handler(object): """ raise NotImplementedError - def collect_optional_args(self, dispatcher): + def collect_optional_args(self, dispatcher, update=None): """ Prepares the optional arguments that are the same for all types of handlers @@ -83,10 +99,19 @@ class Handler(object): dispatcher (Dispatcher): """ optional_args = dict() + if self.pass_update_queue: optional_args['update_queue'] = dispatcher.update_queue if self.pass_job_queue: optional_args['job_queue'] = dispatcher.job_queue + if self.pass_user_data or self.pass_chat_data: + chat, user = extract_chat_and_user(update) + + if self.pass_user_data: + optional_args['user_data'] = dispatcher.user_data[user.id] + + if self.pass_chat_data: + optional_args['chat_data'] = dispatcher.chat_data[chat.id if chat else None] return optional_args diff --git a/telegram/ext/inlinequeryhandler.py b/telegram/ext/inlinequeryhandler.py index ce29f9c3a..c346b9139 100644 --- a/telegram/ext/inlinequeryhandler.py +++ b/telegram/ext/inlinequeryhandler.py @@ -51,6 +51,14 @@ class InlineQueryHandler(Handler): pass_groupdict (optional[bool]): If the callback should be passed the result of ``re.match(pattern, query).groupdict()`` as a keyword argument called ``groupdict``. Default is ``False`` + pass_user_data (optional[bool]): If set to ``True``, a keyword argument called + ``user_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the user that sent the update. For each update of + the same user, it will be the same ``dict``. Default is ``False``. + pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called + ``chat_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the chat that the update was sent in. + For each update in the same chat, it will be the same ``dict``. Default is ``False``. """ def __init__(self, @@ -59,9 +67,15 @@ class InlineQueryHandler(Handler): pass_job_queue=False, pattern=None, pass_groups=False, - pass_groupdict=False): + pass_groupdict=False, + pass_user_data=False, + pass_chat_data=False): super(InlineQueryHandler, self).__init__( - callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue) + callback, + pass_update_queue=pass_update_queue, + pass_job_queue=pass_job_queue, + pass_user_data=pass_user_data, + pass_chat_data=pass_chat_data) if isinstance(pattern, string_types): pattern = re.compile(pattern) @@ -80,7 +94,7 @@ class InlineQueryHandler(Handler): return True def handle_update(self, update, dispatcher): - optional_args = self.collect_optional_args(dispatcher) + optional_args = self.collect_optional_args(dispatcher, update) if self.pattern: match = re.match(self.pattern, update.inline_query.query) diff --git a/telegram/ext/messagehandler.py b/telegram/ext/messagehandler.py index a534e892d..e37ae811f 100644 --- a/telegram/ext/messagehandler.py +++ b/telegram/ext/messagehandler.py @@ -43,6 +43,14 @@ class MessageHandler(Handler): pass_update_queue (optional[bool]): If the handler should be passed the update queue as a keyword argument called ``update_queue``. It can be used to insert updates. Default is ``False`` + pass_user_data (optional[bool]): If set to ``True``, a keyword argument called + ``user_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the user that sent the update. For each update of + the same user, it will be the same ``dict``. Default is ``False``. + pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called + ``chat_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the chat that the update was sent in. + For each update in the same chat, it will be the same ``dict``. Default is ``False``. """ def __init__(self, @@ -50,9 +58,15 @@ class MessageHandler(Handler): callback, allow_edited=False, pass_update_queue=False, - pass_job_queue=False): + pass_job_queue=False, + pass_user_data=False, + pass_chat_data=False): super(MessageHandler, self).__init__( - callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue) + callback, + pass_update_queue=pass_update_queue, + pass_job_queue=pass_job_queue, + pass_user_data=pass_user_data, + pass_chat_data=pass_chat_data) self.filters = filters self.allow_edited = allow_edited @@ -83,7 +97,7 @@ class MessageHandler(Handler): return res def handle_update(self, update, dispatcher): - optional_args = self.collect_optional_args(dispatcher) + optional_args = self.collect_optional_args(dispatcher, update) return self.callback(dispatcher.bot, update, **optional_args) diff --git a/telegram/ext/regexhandler.py b/telegram/ext/regexhandler.py index f96928956..743648054 100644 --- a/telegram/ext/regexhandler.py +++ b/telegram/ext/regexhandler.py @@ -53,6 +53,14 @@ class RegexHandler(Handler): ``job_queue`` will be passed to the callback function. It will be a ``JobQueue`` instance created by the ``Updater`` which can be used to schedule new jobs. Default is ``False``. + pass_user_data (optional[bool]): If set to ``True``, a keyword argument called + ``user_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the user that sent the update. For each update of + the same user, it will be the same ``dict``. Default is ``False``. + pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called + ``chat_data`` will be passed to the callback function. It will be a ``dict`` you + can use to keep any data related to the chat that the update was sent in. + For each update in the same chat, it will be the same ``dict``. Default is ``False``. """ def __init__(self, @@ -61,9 +69,15 @@ class RegexHandler(Handler): pass_groups=False, pass_groupdict=False, pass_update_queue=False, - pass_job_queue=False): + pass_job_queue=False, + pass_user_data=False, + pass_chat_data=False): super(RegexHandler, self).__init__( - callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue) + callback, + pass_update_queue=pass_update_queue, + pass_job_queue=pass_job_queue, + pass_user_data=pass_user_data, + pass_chat_data=pass_chat_data) if isinstance(pattern, string_types): pattern = re.compile(pattern) @@ -73,14 +87,14 @@ class RegexHandler(Handler): self.pass_groupdict = pass_groupdict def check_update(self, update): - if (isinstance(update, Update) and update.message and update.message.text): + if isinstance(update, Update) and update.message and update.message.text: match = re.match(self.pattern, update.message.text) return bool(match) else: return False def handle_update(self, update, dispatcher): - optional_args = self.collect_optional_args(dispatcher) + optional_args = self.collect_optional_args(dispatcher, update) match = re.match(self.pattern, update.message.text) if self.pass_groups: diff --git a/telegram/utils/helpers.py b/telegram/utils/helpers.py new file mode 100644 index 000000000..e1105422d --- /dev/null +++ b/telegram/utils/helpers.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2016 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +""" This module contains helper functions """ + + +def extract_chat_and_user(update): + user = None + chat = None + + if update.message: + user = update.message.from_user + chat = update.message.chat + + elif update.edited_message: + user = update.edited_message.from_user + chat = update.edited_message.chat + + elif update.inline_query: + user = update.inline_query.from_user + + elif update.chosen_inline_result: + user = update.chosen_inline_result.from_user + + elif update.callback_query: + user = update.callback_query.from_user + chat = update.callback_query.message.chat if update.callback_query.message else None + + return chat, user diff --git a/tests/test_updater.py b/tests/test_updater.py index 7e5e0c183..7d3150549 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -145,6 +145,12 @@ class UpdaterTest(BaseTest, unittest.TestCase): elif args[0] == 'noresend': pass + def userAndChatDataTest(self, bot, update, user_data, chat_data): + user_data['text'] = update.message.text + chat_data['text'] = update.message.text + self.received_message = update.message.text + self.message_count += 1 + @run_async def asyncAdditionalHandlerTest(self, bot, update, update_queue=None): sleep(1) @@ -483,6 +489,19 @@ class UpdaterTest(BaseTest, unittest.TestCase): self.assertEqual(self.received_message, '/test5 noresend') self.assertEqual(self.message_count, 2) + def test_user_and_chat_data(self): + self._setup_updater('/test_data', messages=1) + handler = CommandHandler( + 'test_data', self.userAndChatDataTest, pass_user_data=True, pass_chat_data=True) + self.updater.dispatcher.add_handler(handler) + + self.updater.start_polling(0.01) + sleep(.1) + self.assertEqual(self.received_message, '/test_data') + self.assertEqual(self.message_count, 1) + self.assertDictEqual(dict(self.updater.dispatcher.user_data), {0: {'text': '/test_data'}}) + self.assertDictEqual(dict(self.updater.dispatcher.chat_data), {0: {'text': '/test_data'}}) + def test_regexGroupHandler(self): self._setup_updater('', messages=0) d = self.updater.dispatcher @@ -771,7 +790,7 @@ class MockBot(object): self.edited = edited def mockUpdate(self, text): - message = Message(0, None, None, None) + message = Message(0, User(0, 'Testuser'), None, Chat(0, Chat.GROUP)) message.text = text update = Update(0)