diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index 0119e52f0..36a2901b0 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -24,7 +24,7 @@ from threading import Lock from telegram import Update from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler, - ChosenInlineResultHandler, CallbackContext) + ChosenInlineResultHandler, CallbackContext, DispatcherHandlerStop) from telegram.utils.promise import Promise @@ -454,6 +454,7 @@ class ConversationHandler(Handler): """ conversation_key, handler, check_result = check_result + raise_dp_handler_stop = False with self._timeout_jobs_lock: # Remove the old timeout job (if present) @@ -462,7 +463,11 @@ class ConversationHandler(Handler): if timeout_job is not None: timeout_job.schedule_removal() - new_state = handler.handle_update(update, dispatcher, check_result, context) + try: + new_state = handler.handle_update(update, dispatcher, check_result, context) + except DispatcherHandlerStop as e: + new_state = e.state + raise_dp_handler_stop = True with self._timeout_jobs_lock: if self.conversation_timeout and new_state != self.END: @@ -474,9 +479,16 @@ class ConversationHandler(Handler): if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent: self.update_state(self.END, conversation_key) - return self.map_to_parent.get(new_state) + if raise_dp_handler_stop: + raise DispatcherHandlerStop(self.map_to_parent.get(new_state)) + else: + return self.map_to_parent.get(new_state) else: self.update_state(new_state, conversation_key) + if raise_dp_handler_stop: + # Don't pass the new state here. If we're in a nested conversation, the parent is + # expecting None as return value. + raise DispatcherHandlerStop() def update_state(self, new_state, key): if new_state == self.END: @@ -522,5 +534,10 @@ class ConversationHandler(Handler): for handler in handlers: check = handler.check_update(context.update) if check is not None and check is not False: - handler.handle_update(context.update, context.dispatcher, check, callback_context) + try: + handler.handle_update(context.update, context.dispatcher, check, + callback_context) + except DispatcherHandlerStop: + self.logger.warning('DispatcherHandlerStop in TIMEOUT state of ' + 'ConversationHandler has no effect. Ignoring.') self.update_state(self.END, context.conversation_key) diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py index 5dd61ed28..78e0daee4 100644 --- a/telegram/ext/dispatcher.py +++ b/telegram/ext/dispatcher.py @@ -60,8 +60,27 @@ def run_async(func): class DispatcherHandlerStop(Exception): - """Raise this in handler to prevent execution any other handler (even in different group).""" - pass + """ + Raise this in handler to prevent execution any other handler (even in different group). + + In order to use this exception in a :class:`telegram.ext.ConversationHandler`, pass the + optional ``state`` parameter instead of returning the next state: + + .. code-block:: python + + def callback(update, context): + ... + raise DispatcherHandlerStop(next_state) + + Attributes: + state (:obj:`object`): Optional. The next state of the conversation. + + Args: + state (:obj:`object`, optional): The next state of the conversation. + """ + def __init__(self, state=None): + super().__init__() + self.state = state class Dispatcher: diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py index 0bdbc4eb4..4d452f3d0 100644 --- a/tests/test_conversationhandler.py +++ b/tests/test_conversationhandler.py @@ -24,7 +24,8 @@ import pytest from telegram import (CallbackQuery, Chat, ChosenInlineResult, InlineQuery, Message, PreCheckoutQuery, ShippingQuery, Update, User, MessageEntity) from telegram.ext import (ConversationHandler, CommandHandler, CallbackQueryHandler, - MessageHandler, Filters, InlineQueryHandler, CallbackContext) + MessageHandler, Filters, InlineQueryHandler, CallbackContext, + DispatcherHandlerStop, TypeHandler) @pytest.fixture(scope='class') @@ -37,6 +38,17 @@ def user2(): return User(first_name='Mister Test', id=124, is_bot=False) +def raise_dphs(func): + def decorator(self, *args, **kwargs): + result = func(self, *args, **kwargs) + if self.raise_dp_handler_stop: + raise DispatcherHandlerStop(result) + else: + return result + + return decorator + + class TestConversationHandler: # State definitions # At first we're thirsty. Then we brew coffee, we drink it @@ -51,9 +63,14 @@ class TestConversationHandler: group = Chat(0, Chat.GROUP) second_group = Chat(1, Chat.GROUP) + raise_dp_handler_stop = False + test_flag = False + # Test related @pytest.fixture(autouse=True) def reset(self): + self.raise_dp_handler_stop = False + self.test_flag = False self.current_state = dict() self.entry_points = [CommandHandler('start', self.start)] self.states = { @@ -116,65 +133,81 @@ class TestConversationHandler: return state # Actions + @raise_dphs def start(self, bot, update): if isinstance(update, Update): return self._set_state(update, self.THIRSTY) else: return self._set_state(bot, self.THIRSTY) + @raise_dphs def end(self, bot, update): return self._set_state(update, self.END) + @raise_dphs def start_end(self, bot, update): return self._set_state(update, self.END) + @raise_dphs def start_none(self, bot, update): return self._set_state(update, None) + @raise_dphs def brew(self, bot, update): if isinstance(update, Update): return self._set_state(update, self.BREWING) else: return self._set_state(bot, self.BREWING) + @raise_dphs def drink(self, bot, update): return self._set_state(update, self.DRINKING) + @raise_dphs def code(self, bot, update): return self._set_state(update, self.CODING) + @raise_dphs def passout(self, bot, update): assert update.message.text == '/brew' assert isinstance(update, Update) self.is_timeout = True + @raise_dphs def passout2(self, bot, update): assert isinstance(update, Update) self.is_timeout = True + @raise_dphs def passout_context(self, update, context): assert update.message.text == '/brew' assert isinstance(context, CallbackContext) self.is_timeout = True + @raise_dphs def passout2_context(self, update, context): assert isinstance(context, CallbackContext) self.is_timeout = True # Drinking actions (nested) + @raise_dphs def hold(self, bot, update): return self._set_state(update, self.HOLDING) + @raise_dphs def sip(self, bot, update): return self._set_state(update, self.SIPPING) + @raise_dphs def swallow(self, bot, update): return self._set_state(update, self.SWALLOWING) + @raise_dphs def replenish(self, bot, update): return self._set_state(update, self.REPLENISHING) + @raise_dphs def stop(self, bot, update): return self._set_state(update, self.STOPPING) @@ -546,6 +579,32 @@ class TestConversationHandler: dp.job_queue.tick() assert handler.conversations.get((self.group.id, user1.id)) is None + def test_conversation_timeout_dispatcher_handler_stop(self, dp, bot, user1, caplog): + handler = ConversationHandler(entry_points=self.entry_points, states=self.states, + fallbacks=self.fallbacks, conversation_timeout=0.5) + + def timeout(*args, **kwargs): + raise DispatcherHandlerStop() + + self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]}) + dp.add_handler(handler) + + # Start state machine, then reach timeout + message = Message(0, user1, None, self.group, text='/start', + entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, + offset=0, length=len('/start'))], + bot=bot) + + with caplog.at_level(logging.WARNING): + dp.process_update(Update(update_id=0, message=message)) + assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY + sleep(0.5) + dp.job_queue.tick() + assert handler.conversations.get((self.group.id, user1.id)) is None + assert len(caplog.records) == 1 + rec = caplog.records[-1] + assert rec.msg.startswith('DispatcherHandlerStop in TIMEOUT') + def test_conversation_handler_timeout_update_and_context(self, cdp, bot, user1): context = None @@ -953,3 +1012,129 @@ class TestConversationHandler: dp.process_update(Update(update_id=0, message=message)) assert self.current_state[user1.id] == self.STOPPING assert handler.conversations.get((0, user1.id)) is None + + def test_conversation_dispatcher_handler_stop(self, dp, bot, user1, user2): + self.nested_states[self.DRINKING] = [ConversationHandler( + entry_points=self.drinking_entry_points, + states=self.drinking_states, + fallbacks=self.drinking_fallbacks, + map_to_parent=self.drinking_map_to_parent)] + handler = ConversationHandler(entry_points=self.entry_points, + states=self.nested_states, + fallbacks=self.fallbacks) + + def test_callback(u, c): + self.test_flag = True + + dp.add_handler(handler) + dp.add_handler(TypeHandler(Update, test_callback), group=1) + self.raise_dp_handler_stop = True + + # User one, starts the state machine. + message = Message(0, user1, None, self.group, text='/start', bot=bot, + entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, + offset=0, length=len('/start'))]) + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.THIRSTY + assert not self.test_flag + + # The user is thirsty and wants to brew coffee. + message.text = '/brew' + message.entities[0].length = len('/brew') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.BREWING + assert not self.test_flag + + # Lets pour some coffee. + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + assert not self.test_flag + + # The user is holding the cup + message.text = '/hold' + message.entities[0].length = len('/hold') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.HOLDING + assert not self.test_flag + + # The user is sipping coffee + message.text = '/sip' + message.entities[0].length = len('/sip') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.SIPPING + assert not self.test_flag + + # The user is swallowing + message.text = '/swallow' + message.entities[0].length = len('/swallow') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.SWALLOWING + assert not self.test_flag + + # The user is holding the cup again + message.text = '/hold' + message.entities[0].length = len('/hold') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.HOLDING + assert not self.test_flag + + # The user wants to replenish the coffee supply + message.text = '/replenish' + message.entities[0].length = len('/replenish') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.REPLENISHING + assert handler.conversations[(0, user1.id)] == self.BREWING + assert not self.test_flag + + # The user wants to drink their coffee again + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + assert not self.test_flag + + # The user is now ready to start coding + message.text = '/startCoding' + message.entities[0].length = len('/startCoding') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.CODING + assert not self.test_flag + + # The user decides it's time to drink again + message.text = '/drinkMore' + message.entities[0].length = len('/drinkMore') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + assert not self.test_flag + + # The user is holding their cup + message.text = '/hold' + message.entities[0].length = len('/hold') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.HOLDING + assert not self.test_flag + + # The user wants to end with the drinking and go back to coding + message.text = '/end' + message.entities[0].length = len('/end') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.END + assert handler.conversations[(0, user1.id)] == self.CODING + assert not self.test_flag + + # The user wants to drink once more + message.text = '/drinkMore' + message.entities[0].length = len('/drinkMore') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + assert not self.test_flag + + # The user wants to stop altogether + message.text = '/stop' + message.entities[0].length = len('/stop') + dp.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.STOPPING + assert handler.conversations.get((0, user1.id)) is None + assert not self.test_flag