diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index 1611f8483..3a60570bd 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -21,7 +21,8 @@ import logging from telegram import Update -from telegram.ext import Handler +from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler, + ChosenInlineResultHandler) from telegram.utils.promise import Promise @@ -87,7 +88,10 @@ class ConversationHandler(Handler): fallbacks, allow_reentry=False, run_async_timeout=None, - timed_out_behavior=None): + timed_out_behavior=None, + per_chat=True, + per_user=True, + per_message=False): self.entry_points = entry_points """:type: list[telegram.ext.Handler]""" @@ -105,22 +109,66 @@ class ConversationHandler(Handler): """:type: list[telegram.ext.Handler]""" self.conversations = dict() - """:type: dict[(int, int): str]""" + self.per_user = per_user + self.per_chat = per_chat + self.per_message = per_message + """:type: dict[tuple: object]""" self.current_conversation = None self.current_handler = None self.logger = logging.getLogger(__name__) + if not any((self.per_user, self.per_chat, self.per_message)): + raise ValueError("'per_user', 'per_chat' and 'per_message' can't all be 'False'") + + all_handlers = list() + all_handlers.extend(entry_points) + all_handlers.extend(fallbacks) + + for state_handlers in states.values(): + all_handlers.extend(state_handlers) + + if self.per_message: + for handler in all_handlers: + if not isinstance(handler, CallbackQueryHandler): + raise ValueError("If 'per_message=True', all entry points and state handlers" + " must be 'CallbackQueryHandler'") + else: + for handler in all_handlers: + if isinstance(handler, CallbackQueryHandler): + raise ValueError("If 'per_message=False', 'CallbackQueryHandler' doesn't work") + + if self.per_chat: + for handler in all_handlers: + if isinstance(handler, (InlineQueryHandler, ChosenInlineResultHandler)): + raise ValueError("If 'per_chat=True', 'InlineQueryHandler' doesn't work") + + def _get_key(self, update): + chat, user = update.extract_chat_and_user() + key = list() + + if self.per_chat: + key.append(chat.id) + + if self.per_user: + key.append(user.id) + + if self.per_message: + key.append(update.callback_query.inline_message_id + or update.callback_query.message.message_id) + + return tuple(key) + def check_update(self, update): # Ignore messages in channels - if not isinstance(update, Update) or update.channel_post: + if (not isinstance(update, Update) or update.channel_post or self.per_chat + and (update.inline_query or update.chosen_inline_result) or self.per_message + and not update.callback_query): return False - chat, user = update.extract_chat_and_user() - - key = (chat.id, user.id) if chat else (None, user.id) + key = self._get_key(update) state = self.conversations.get(key) # Resolve promises diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py index 655e7a866..f62109049 100644 --- a/tests/test_conversationhandler.py +++ b/tests/test_conversationhandler.py @@ -35,8 +35,9 @@ except ImportError: sys.path.append('.') -from telegram import Update, Message, TelegramError, User, Chat, Bot -from telegram.ext import Updater, ConversationHandler, CommandHandler +from telegram import Update, Message, TelegramError, User, Chat, Bot, CallbackQuery +from telegram.ext import (Updater, ConversationHandler, CommandHandler, CallbackQueryHandler, + InlineQueryHandler) from tests.base import BaseTest from tests.test_updater import MockBot @@ -79,6 +80,12 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase): } self.fallbacks = [CommandHandler('eat', self.start)] + self.group = Chat(0, Chat.GROUP) + self.second_group = Chat(1, Chat.GROUP) + + def _chat(self, user): + return Chat(user.id, Chat.GROUP) + def _setup_updater(self, *args, **kwargs): self.bot = MockBot(*args, **kwargs) self.updater = Updater(workers=2, bot=self.bot) @@ -137,36 +144,164 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase): queue = self.updater.start_polling(0.01) # User one, starts the state machine. - message = Message(0, user, None, None, text="/start", bot=self.bot) + message = Message(0, user, None, self.group, text="/start", bot=self.bot) queue.put(Update(update_id=0, message=message)) sleep(.1) self.assertTrue(self.current_state[user.id] == self.THIRSTY) # The user is thirsty and wants to brew coffee. - message = Message(0, user, None, None, text="/brew", bot=self.bot) + message = Message(0, user, None, self.group, text="/brew", bot=self.bot) queue.put(Update(update_id=0, message=message)) sleep(.1) self.assertTrue(self.current_state[user.id] == self.BREWING) # Lets see if an invalid command makes sure, no state is changed. - message = Message(0, user, None, None, text="/nothing", bot=self.bot) + message = Message(0, user, None, self.group, text="/nothing", bot=self.bot) queue.put(Update(update_id=0, message=message)) sleep(.1) self.assertTrue(self.current_state[user.id] == self.BREWING) # Lets see if the state machine still works by pouring coffee. - message = Message(0, user, None, None, text="/pourCoffee", bot=self.bot) + message = Message(0, user, None, self.group, text="/pourCoffee", bot=self.bot) queue.put(Update(update_id=0, message=message)) sleep(.1) self.assertTrue(self.current_state[user.id] == self.DRINKING) # Let's now verify that for another user, who did not start yet, # the state has not been changed. - message = Message(0, second_user, None, None, text="/brew", bot=self.bot) + message = Message(0, second_user, None, self.group, text="/brew", bot=self.bot) queue.put(Update(update_id=0, message=message)) sleep(.1) self.assertRaises(KeyError, self._get_state, user_id=second_user.id) + def test_addConversationHandlerPerChat(self): + self._setup_updater('', messages=0) + d = self.updater.dispatcher + user = User(first_name="Misses Test", id=123) + second_user = User(first_name="Mister Test", id=124) + + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + per_user=False) + d.add_handler(handler) + queue = self.updater.start_polling(0.01) + + # User one, starts the state machine. + message = Message(0, user, None, self.group, text="/start", bot=self.bot) + queue.put(Update(update_id=0, message=message)) + sleep(.1) + + # The user is thirsty and wants to brew coffee. + message = Message(0, user, None, self.group, text="/brew", bot=self.bot) + queue.put(Update(update_id=0, message=message)) + sleep(.1) + + # Let's now verify that for another user, who did not start yet, + # the state will be changed because they are in the same group. + message = Message(0, second_user, None, self.group, text="/pourCoffee", bot=self.bot) + queue.put(Update(update_id=0, message=message)) + sleep(.1) + self.assertEquals(handler.conversations[(self.group.id,)], self.DRINKING) + + def test_addConversationHandlerPerUser(self): + self._setup_updater('', messages=0) + d = self.updater.dispatcher + user = User(first_name="Misses Test", id=123) + + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + per_chat=False) + d.add_handler(handler) + queue = self.updater.start_polling(0.01) + + # User one, starts the state machine. + message = Message(0, user, None, self.group, text="/start", bot=self.bot) + queue.put(Update(update_id=0, message=message)) + sleep(.1) + + # The user is thirsty and wants to brew coffee. + message = Message(0, user, None, self.group, text="/brew", bot=self.bot) + queue.put(Update(update_id=0, message=message)) + sleep(.1) + + # Let's now verify that for the same user in a different group, the state will still be + # updated + message = Message(0, user, None, self.second_group, text="/pourCoffee", bot=self.bot) + queue.put(Update(update_id=0, message=message)) + sleep(.1) + + self.assertEquals(handler.conversations[(user.id,)], self.DRINKING) + + def test_addConversationHandlerPerMessage(self): + self._setup_updater('', messages=0) + d = self.updater.dispatcher + user = User(first_name="Misses Test", id=123) + second_user = User(first_name="Mister Test", id=124) + + def entry(bot, update): + return 1 + + def one(bot, update): + return 2 + + def two(bot, update): + return ConversationHandler.END + + handler = ConversationHandler( + entry_points=[CallbackQueryHandler(entry)], + states={1: [CallbackQueryHandler(one)], + 2: [CallbackQueryHandler(two)]}, + fallbacks=[], + per_message=True) + d.add_handler(handler) + queue = self.updater.start_polling(0.01) + + # User one, starts the state machine. + message = Message(0, user, None, self.group, text="msg w/ inlinekeyboard", bot=self.bot) + + cbq = CallbackQuery(0, user, None, message=message, data='data', bot=self.bot) + queue.put(Update(update_id=0, callback_query=cbq)) + sleep(.1) + self.assertEquals(handler.conversations[(self.group.id, user.id, message.message_id)], 1) + + cbq = CallbackQuery(0, user, None, message=message, data='data', bot=self.bot) + queue.put(Update(update_id=0, callback_query=cbq)) + sleep(.1) + self.assertEquals(handler.conversations[(self.group.id, user.id, message.message_id)], 2) + + # Let's now verify that for a different user in the same group, the state will not be + # updated + cbq = CallbackQuery(0, second_user, None, message=message, data='data', bot=self.bot) + queue.put(Update(update_id=0, callback_query=cbq)) + sleep(.1) + self.assertEquals(handler.conversations[(self.group.id, user.id, message.message_id)], 2) + + def test_illegal_handlers(self): + with self.assertRaises(ValueError): + ConversationHandler( + entry_points=[CommandHandler('/test', lambda bot, update: None)], + states={}, + fallbacks=[], + per_message=True) + + with self.assertRaises(ValueError): + ConversationHandler( + entry_points=[CallbackQueryHandler(lambda bot, update: None)], + states={}, + fallbacks=[], + per_message=False) + + with self.assertRaises(ValueError): + ConversationHandler( + entry_points=[InlineQueryHandler(lambda bot, update: None)], + states={}, + fallbacks=[], + per_chat=True) + def test_endOnFirstMessage(self): self._setup_updater('', messages=0) d = self.updater.dispatcher @@ -178,7 +313,7 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase): queue = self.updater.start_polling(0.01) # User starts the state machine and immediately ends it. - message = Message(0, user, None, None, text="/start") + message = Message(0, user, None, self.group, text="/start", bot=self.bot) queue.put(Update(update_id=0, message=message)) sleep(.1) self.assertEquals(len(handler.conversations), 0) @@ -197,13 +332,13 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase): # User starts the state machine with an async function that immediately ends the # conversation. Async results are resolved when the users state is queried next time. - message = Message(0, user, None, None, text="/start", bot=self.bot) + message = Message(0, user, None, self.group, text="/start", bot=self.bot) queue.put(Update(update_id=0, message=message)) sleep(.1) # Assert that the Promise has been accepted as the new state self.assertEquals(len(handler.conversations), 1) - message = Message(0, user, None, None, text="resolve promise pls", bot=self.bot) + message = Message(0, user, None, self.group, text="resolve promise pls", bot=self.bot) queue.put(Update(update_id=0, message=message)) sleep(.1) # Assert that the Promise has been resolved and the conversation ended.