Add pass_user_data and pass_chat_data to Handler (#436)

* initial commit for user_data

* add chat_data and use defaultdict

* fix chat_data copy-paste error

* add test for user_data and chat_data

* fix case where chat is None

* remove braces from import line
This commit is contained in:
Jannes Höke 2016-10-25 19:28:34 +02:00 committed by GitHub
parent 45936c9982
commit 10bdf8212c
12 changed files with 357 additions and 45 deletions

View file

@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Simple Bot to reply to Telegram messages
# This program is dedicated to the public domain under the CC0 license.
"""
This Bot uses the Updater class to handle the bot.
First, a few callback functions are defined. Then, those functions are passed to
the Dispatcher and registered at their respective places.
Then, the bot is started and runs until we press Ctrl-C on the command line.
Usage:
Example of a bot-user conversation using ConversationHandler.
Send /start to initiate the conversation.
Press Ctrl-C on the command line or send a signal to the process to stop the
bot.
"""
from telegram import ReplyKeyboardMarkup
from telegram.ext import (Updater, CommandHandler, MessageHandler, Filters, RegexHandler,
ConversationHandler)
import logging
# Enable logging
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
level=logging.INFO)
logger = logging.getLogger(__name__)
CHOOSING, TYPING_REPLY, TYPING_CHOICE = range(3)
reply_keyboard = [['Age', 'Favourite colour'],
['Number of siblings', 'Something else...'],
['Done']]
markup = ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True)
def facts_to_str(user_data):
facts = list()
for key, value in user_data.items():
facts.append('%s - %s' % (key, value))
return "\n".join(facts).join(['\n', '\n'])
def start(bot, update):
update.message.reply_text(
"Hi! My name is Doctor Botter. I will hold a more complex conversation with you. "
"Why don't you tell me something about yourself?",
reply_markup=markup)
return CHOOSING
def regular_choice(bot, update, user_data):
text = update.message.text
user_data['choice'] = text
update.message.reply_text('Your %s? Yes, I would love to hear about that!' % text.lower())
return TYPING_REPLY
def custom_choice(bot, update):
update.message.reply_text('Alright, please send me the category first, '
'for example "Most impressive skill"')
return TYPING_CHOICE
def received_information(bot, update, user_data):
text = update.message.text
category = user_data['choice']
user_data[category] = text
del user_data['choice']
update.message.reply_text("Neat! Just so you know, this is what you already told me:"
"%s"
"You can tell me more, or change your opinion on something."
% facts_to_str(user_data),
reply_markup=markup)
return CHOOSING
def done(bot, update, user_data):
if 'choice' in user_data:
del user_data['choice']
update.message.reply_text("I learned these facts about you:"
"%s"
"Until next time!" % facts_to_str(user_data))
user_data.clear()
return ConversationHandler.END
def error(bot, update, error):
logger.warn('Update "%s" caused error "%s"' % (update, error))
def main():
# Create the Updater and pass it your bot's token.
updater = Updater("TOKEN")
# Get the dispatcher to register handlers
dp = updater.dispatcher
# Add conversation handler with the states GENDER, PHOTO, LOCATION and BIO
conv_handler = ConversationHandler(
entry_points=[CommandHandler('start', start)],
states={
CHOOSING: [RegexHandler('^(Age|Favourite colour|Number of siblings)$',
regular_choice,
pass_user_data=True),
RegexHandler('^Something else...$',
custom_choice),
],
TYPING_CHOICE: [MessageHandler([Filters.text],
regular_choice,
pass_user_data=True),
],
TYPING_REPLY: [MessageHandler([Filters.text],
received_information,
pass_user_data=True),
],
},
fallbacks=[RegexHandler('^Done$', done, pass_user_data=True)]
)
dp.add_handler(conv_handler)
# log all errors
dp.add_error_handler(error)
# Start the Bot
updater.start_polling()
# Run the bot until the you presses Ctrl-C or the process receives SIGINT,
# SIGTERM or SIGABRT. This should be used most of the time, since
# start_polling() is non-blocking and will stop the bot gracefully.
updater.idle()
if __name__ == '__main__':
main()

View file

@ -52,6 +52,14 @@ class CallbackQueryHandler(Handler):
pass_groupdict (optional[bool]): If the callback should be passed the
result of ``re.match(pattern, data).groupdict()`` as a keyword
argument called ``groupdict``. Default is ``False``
pass_user_data (optional[bool]): If set to ``True``, a keyword argument called
``user_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the user that sent the update. For each update of
the same user, it will be the same ``dict``. Default is ``False``.
pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called
``chat_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the chat that the update was sent in.
For each update in the same chat, it will be the same ``dict``. Default is ``False``.
"""
def __init__(self,
@ -60,9 +68,15 @@ class CallbackQueryHandler(Handler):
pass_job_queue=False,
pattern=None,
pass_groups=False,
pass_groupdict=False):
pass_groupdict=False,
pass_user_data=False,
pass_chat_data=False):
super(CallbackQueryHandler, self).__init__(
callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue)
callback,
pass_update_queue=pass_update_queue,
pass_job_queue=pass_job_queue,
pass_user_data=pass_user_data,
pass_chat_data=pass_chat_data)
if isinstance(pattern, string_types):
pattern = re.compile(pattern)
@ -81,7 +95,7 @@ class CallbackQueryHandler(Handler):
return True
def handle_update(self, update, dispatcher):
optional_args = self.collect_optional_args(dispatcher)
optional_args = self.collect_optional_args(dispatcher, update)
if self.pattern:
match = re.match(self.pattern, update.callback_query.data)

View file

@ -40,17 +40,34 @@ class ChosenInlineResultHandler(Handler):
``job_queue`` will be passed to the callback function. It will be a ``JobQueue``
instance created by the ``Updater`` which can be used to schedule new jobs.
Default is ``False``.
pass_user_data (optional[bool]): If set to ``True``, a keyword argument called
``user_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the user that sent the update. For each update of
the same user, it will be the same ``dict``. Default is ``False``.
pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called
``chat_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the chat that the update was sent in.
For each update in the same chat, it will be the same ``dict``. Default is ``False``.
"""
def __init__(self, callback, pass_update_queue=False, pass_job_queue=False):
def __init__(self,
callback,
pass_update_queue=False,
pass_job_queue=False,
pass_user_data=False,
pass_chat_data=False):
super(ChosenInlineResultHandler, self).__init__(
callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue)
callback,
pass_update_queue=pass_update_queue,
pass_job_queue=pass_job_queue,
pass_user_data=pass_user_data,
pass_chat_data=pass_chat_data)
def check_update(self, update):
return isinstance(update, Update) and update.chosen_inline_result
def handle_update(self, update, dispatcher):
optional_args = self.collect_optional_args(dispatcher)
optional_args = self.collect_optional_args(dispatcher, update)
return self.callback(dispatcher.bot, update, **optional_args)

View file

@ -49,6 +49,14 @@ class CommandHandler(Handler):
``job_queue`` will be passed to the callback function. It will be a ``JobQueue``
instance created by the ``Updater`` which can be used to schedule new jobs.
Default is ``False``.
pass_user_data (optional[bool]): If set to ``True``, a keyword argument called
``user_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the user that sent the update. For each update of
the same user, it will be the same ``dict``. Default is ``False``.
pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called
``chat_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the chat that the update was sent in.
For each update in the same chat, it will be the same ``dict``. Default is ``False``.
"""
def __init__(self,
@ -57,9 +65,15 @@ class CommandHandler(Handler):
allow_edited=False,
pass_args=False,
pass_update_queue=False,
pass_job_queue=False):
pass_job_queue=False,
pass_user_data=False,
pass_chat_data=False):
super(CommandHandler, self).__init__(
callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue)
callback,
pass_update_queue=pass_update_queue,
pass_job_queue=pass_job_queue,
pass_user_data=pass_user_data,
pass_chat_data=pass_chat_data)
self.command = command
self.allow_edited = allow_edited
self.pass_args = pass_args
@ -76,7 +90,7 @@ class CommandHandler(Handler):
return False
def handle_update(self, update, dispatcher):
optional_args = self.collect_optional_args(dispatcher)
optional_args = self.collect_optional_args(dispatcher, update)
message = update.message or update.edited_message

View file

@ -22,6 +22,7 @@ import logging
from telegram import Update
from telegram.ext import Handler
from telegram.utils.helpers import extract_chat_and_user
from telegram.utils.promise import Promise
@ -115,29 +116,7 @@ class ConversationHandler(Handler):
if not isinstance(update, Update):
return False
user = None
chat = None
if update.message:
user = update.message.from_user
chat = update.message.chat
elif update.edited_message:
user = update.edited_message.from_user
chat = update.edited_message.chat
elif update.inline_query:
user = update.inline_query.from_user
elif update.chosen_inline_result:
user = update.chosen_inline_result.from_user
elif update.callback_query:
user = update.callback_query.from_user
chat = update.callback_query.message.chat if update.callback_query.message else None
else:
return False
chat, user = extract_chat_and_user(update)
key = (chat.id, user.id) if chat else (None, user.id)
state = self.conversations.get(key)

View file

@ -24,6 +24,7 @@ from functools import wraps
from threading import Thread, Lock, Event, current_thread, BoundedSemaphore
from time import sleep
from uuid import uuid4
from collections import defaultdict
from queue import Queue, Empty
@ -86,6 +87,11 @@ class Dispatcher(object):
self.job_queue = job_queue
self.workers = workers
self.user_data = defaultdict(dict)
""":type: dict[int, dict]"""
self.chat_data = defaultdict(dict)
""":type: dict[int, dict]"""
self.handlers = {}
""":type: dict[int, list[Handler]"""
self.groups = []

View file

@ -20,6 +20,7 @@
Dispatcher """
from telegram.utils.deprecate import deprecate
from telegram.utils.helpers import extract_chat_and_user
class Handler(object):
@ -39,12 +40,27 @@ class Handler(object):
``job_queue`` will be passed to the callback function. It will be a ``JobQueue``
instance created by the ``Updater`` which can be used to schedule new jobs.
Default is ``False``.
pass_user_data (optional[bool]): If set to ``True``, a keyword argument called
``user_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the user that sent the update. For each update of
the same user, it will be the same ``dict``. Default is ``False``.
pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called
``chat_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the chat that the update was sent in.
For each update in the same chat, it will be the same ``dict``. Default is ``False``.
"""
def __init__(self, callback, pass_update_queue=False, pass_job_queue=False):
def __init__(self,
callback,
pass_update_queue=False,
pass_job_queue=False,
pass_user_data=False,
pass_chat_data=False):
self.callback = callback
self.pass_update_queue = pass_update_queue
self.pass_job_queue = pass_job_queue
self.pass_user_data = pass_user_data
self.pass_chat_data = pass_chat_data
def check_update(self, update):
"""
@ -74,7 +90,7 @@ class Handler(object):
"""
raise NotImplementedError
def collect_optional_args(self, dispatcher):
def collect_optional_args(self, dispatcher, update=None):
"""
Prepares the optional arguments that are the same for all types of
handlers
@ -83,10 +99,19 @@ class Handler(object):
dispatcher (Dispatcher):
"""
optional_args = dict()
if self.pass_update_queue:
optional_args['update_queue'] = dispatcher.update_queue
if self.pass_job_queue:
optional_args['job_queue'] = dispatcher.job_queue
if self.pass_user_data or self.pass_chat_data:
chat, user = extract_chat_and_user(update)
if self.pass_user_data:
optional_args['user_data'] = dispatcher.user_data[user.id]
if self.pass_chat_data:
optional_args['chat_data'] = dispatcher.chat_data[chat.id if chat else None]
return optional_args

View file

@ -51,6 +51,14 @@ class InlineQueryHandler(Handler):
pass_groupdict (optional[bool]): If the callback should be passed the
result of ``re.match(pattern, query).groupdict()`` as a keyword
argument called ``groupdict``. Default is ``False``
pass_user_data (optional[bool]): If set to ``True``, a keyword argument called
``user_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the user that sent the update. For each update of
the same user, it will be the same ``dict``. Default is ``False``.
pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called
``chat_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the chat that the update was sent in.
For each update in the same chat, it will be the same ``dict``. Default is ``False``.
"""
def __init__(self,
@ -59,9 +67,15 @@ class InlineQueryHandler(Handler):
pass_job_queue=False,
pattern=None,
pass_groups=False,
pass_groupdict=False):
pass_groupdict=False,
pass_user_data=False,
pass_chat_data=False):
super(InlineQueryHandler, self).__init__(
callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue)
callback,
pass_update_queue=pass_update_queue,
pass_job_queue=pass_job_queue,
pass_user_data=pass_user_data,
pass_chat_data=pass_chat_data)
if isinstance(pattern, string_types):
pattern = re.compile(pattern)
@ -80,7 +94,7 @@ class InlineQueryHandler(Handler):
return True
def handle_update(self, update, dispatcher):
optional_args = self.collect_optional_args(dispatcher)
optional_args = self.collect_optional_args(dispatcher, update)
if self.pattern:
match = re.match(self.pattern, update.inline_query.query)

View file

@ -43,6 +43,14 @@ class MessageHandler(Handler):
pass_update_queue (optional[bool]): If the handler should be passed the
update queue as a keyword argument called ``update_queue``. It can
be used to insert updates. Default is ``False``
pass_user_data (optional[bool]): If set to ``True``, a keyword argument called
``user_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the user that sent the update. For each update of
the same user, it will be the same ``dict``. Default is ``False``.
pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called
``chat_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the chat that the update was sent in.
For each update in the same chat, it will be the same ``dict``. Default is ``False``.
"""
def __init__(self,
@ -50,9 +58,15 @@ class MessageHandler(Handler):
callback,
allow_edited=False,
pass_update_queue=False,
pass_job_queue=False):
pass_job_queue=False,
pass_user_data=False,
pass_chat_data=False):
super(MessageHandler, self).__init__(
callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue)
callback,
pass_update_queue=pass_update_queue,
pass_job_queue=pass_job_queue,
pass_user_data=pass_user_data,
pass_chat_data=pass_chat_data)
self.filters = filters
self.allow_edited = allow_edited
@ -83,7 +97,7 @@ class MessageHandler(Handler):
return res
def handle_update(self, update, dispatcher):
optional_args = self.collect_optional_args(dispatcher)
optional_args = self.collect_optional_args(dispatcher, update)
return self.callback(dispatcher.bot, update, **optional_args)

View file

@ -53,6 +53,14 @@ class RegexHandler(Handler):
``job_queue`` will be passed to the callback function. It will be a ``JobQueue``
instance created by the ``Updater`` which can be used to schedule new jobs.
Default is ``False``.
pass_user_data (optional[bool]): If set to ``True``, a keyword argument called
``user_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the user that sent the update. For each update of
the same user, it will be the same ``dict``. Default is ``False``.
pass_chat_data (optional[bool]): If set to ``True``, a keyword argument called
``chat_data`` will be passed to the callback function. It will be a ``dict`` you
can use to keep any data related to the chat that the update was sent in.
For each update in the same chat, it will be the same ``dict``. Default is ``False``.
"""
def __init__(self,
@ -61,9 +69,15 @@ class RegexHandler(Handler):
pass_groups=False,
pass_groupdict=False,
pass_update_queue=False,
pass_job_queue=False):
pass_job_queue=False,
pass_user_data=False,
pass_chat_data=False):
super(RegexHandler, self).__init__(
callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue)
callback,
pass_update_queue=pass_update_queue,
pass_job_queue=pass_job_queue,
pass_user_data=pass_user_data,
pass_chat_data=pass_chat_data)
if isinstance(pattern, string_types):
pattern = re.compile(pattern)
@ -73,14 +87,14 @@ class RegexHandler(Handler):
self.pass_groupdict = pass_groupdict
def check_update(self, update):
if (isinstance(update, Update) and update.message and update.message.text):
if isinstance(update, Update) and update.message and update.message.text:
match = re.match(self.pattern, update.message.text)
return bool(match)
else:
return False
def handle_update(self, update, dispatcher):
optional_args = self.collect_optional_args(dispatcher)
optional_args = self.collect_optional_args(dispatcher, update)
match = re.match(self.pattern, update.message.text)
if self.pass_groups:

44
telegram/utils/helpers.py Normal file
View file

@ -0,0 +1,44 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2016
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
""" This module contains helper functions """
def extract_chat_and_user(update):
user = None
chat = None
if update.message:
user = update.message.from_user
chat = update.message.chat
elif update.edited_message:
user = update.edited_message.from_user
chat = update.edited_message.chat
elif update.inline_query:
user = update.inline_query.from_user
elif update.chosen_inline_result:
user = update.chosen_inline_result.from_user
elif update.callback_query:
user = update.callback_query.from_user
chat = update.callback_query.message.chat if update.callback_query.message else None
return chat, user

View file

@ -145,6 +145,12 @@ class UpdaterTest(BaseTest, unittest.TestCase):
elif args[0] == 'noresend':
pass
def userAndChatDataTest(self, bot, update, user_data, chat_data):
user_data['text'] = update.message.text
chat_data['text'] = update.message.text
self.received_message = update.message.text
self.message_count += 1
@run_async
def asyncAdditionalHandlerTest(self, bot, update, update_queue=None):
sleep(1)
@ -483,6 +489,19 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.assertEqual(self.received_message, '/test5 noresend')
self.assertEqual(self.message_count, 2)
def test_user_and_chat_data(self):
self._setup_updater('/test_data', messages=1)
handler = CommandHandler(
'test_data', self.userAndChatDataTest, pass_user_data=True, pass_chat_data=True)
self.updater.dispatcher.add_handler(handler)
self.updater.start_polling(0.01)
sleep(.1)
self.assertEqual(self.received_message, '/test_data')
self.assertEqual(self.message_count, 1)
self.assertDictEqual(dict(self.updater.dispatcher.user_data), {0: {'text': '/test_data'}})
self.assertDictEqual(dict(self.updater.dispatcher.chat_data), {0: {'text': '/test_data'}})
def test_regexGroupHandler(self):
self._setup_updater('', messages=0)
d = self.updater.dispatcher
@ -771,7 +790,7 @@ class MockBot(object):
self.edited = edited
def mockUpdate(self, text):
message = Message(0, None, None, None)
message = Message(0, User(0, 'Testuser'), None, Chat(0, Chat.GROUP))
message.text = text
update = Update(0)