Allow DispatcherHandlerStop in ConversationHandler (#2059)

* First go

* Fix bug with nested convs
This commit is contained in:
Bibo-Joshi 2020-08-21 23:20:28 +02:00 committed by GitHub
parent 3304cc5c90
commit da452df07d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 228 additions and 7 deletions

View file

@ -24,7 +24,7 @@ from threading import Lock
from telegram import Update from telegram import Update
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler, from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
ChosenInlineResultHandler, CallbackContext) ChosenInlineResultHandler, CallbackContext, DispatcherHandlerStop)
from telegram.utils.promise import Promise from telegram.utils.promise import Promise
@ -454,6 +454,7 @@ class ConversationHandler(Handler):
""" """
conversation_key, handler, check_result = check_result conversation_key, handler, check_result = check_result
raise_dp_handler_stop = False
with self._timeout_jobs_lock: with self._timeout_jobs_lock:
# Remove the old timeout job (if present) # Remove the old timeout job (if present)
@ -462,7 +463,11 @@ class ConversationHandler(Handler):
if timeout_job is not None: if timeout_job is not None:
timeout_job.schedule_removal() timeout_job.schedule_removal()
new_state = handler.handle_update(update, dispatcher, check_result, context) try:
new_state = handler.handle_update(update, dispatcher, check_result, context)
except DispatcherHandlerStop as e:
new_state = e.state
raise_dp_handler_stop = True
with self._timeout_jobs_lock: with self._timeout_jobs_lock:
if self.conversation_timeout and new_state != self.END: if self.conversation_timeout and new_state != self.END:
@ -474,9 +479,16 @@ class ConversationHandler(Handler):
if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent: if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent:
self.update_state(self.END, conversation_key) self.update_state(self.END, conversation_key)
return self.map_to_parent.get(new_state) if raise_dp_handler_stop:
raise DispatcherHandlerStop(self.map_to_parent.get(new_state))
else:
return self.map_to_parent.get(new_state)
else: else:
self.update_state(new_state, conversation_key) self.update_state(new_state, conversation_key)
if raise_dp_handler_stop:
# Don't pass the new state here. If we're in a nested conversation, the parent is
# expecting None as return value.
raise DispatcherHandlerStop()
def update_state(self, new_state, key): def update_state(self, new_state, key):
if new_state == self.END: if new_state == self.END:
@ -522,5 +534,10 @@ class ConversationHandler(Handler):
for handler in handlers: for handler in handlers:
check = handler.check_update(context.update) check = handler.check_update(context.update)
if check is not None and check is not False: if check is not None and check is not False:
handler.handle_update(context.update, context.dispatcher, check, callback_context) try:
handler.handle_update(context.update, context.dispatcher, check,
callback_context)
except DispatcherHandlerStop:
self.logger.warning('DispatcherHandlerStop in TIMEOUT state of '
'ConversationHandler has no effect. Ignoring.')
self.update_state(self.END, context.conversation_key) self.update_state(self.END, context.conversation_key)

View file

@ -60,8 +60,27 @@ def run_async(func):
class DispatcherHandlerStop(Exception): class DispatcherHandlerStop(Exception):
"""Raise this in handler to prevent execution any other handler (even in different group).""" """
pass Raise this in handler to prevent execution any other handler (even in different group).
In order to use this exception in a :class:`telegram.ext.ConversationHandler`, pass the
optional ``state`` parameter instead of returning the next state:
.. code-block:: python
def callback(update, context):
...
raise DispatcherHandlerStop(next_state)
Attributes:
state (:obj:`object`): Optional. The next state of the conversation.
Args:
state (:obj:`object`, optional): The next state of the conversation.
"""
def __init__(self, state=None):
super().__init__()
self.state = state
class Dispatcher: class Dispatcher:

View file

@ -24,7 +24,8 @@ import pytest
from telegram import (CallbackQuery, Chat, ChosenInlineResult, InlineQuery, Message, from telegram import (CallbackQuery, Chat, ChosenInlineResult, InlineQuery, Message,
PreCheckoutQuery, ShippingQuery, Update, User, MessageEntity) PreCheckoutQuery, ShippingQuery, Update, User, MessageEntity)
from telegram.ext import (ConversationHandler, CommandHandler, CallbackQueryHandler, from telegram.ext import (ConversationHandler, CommandHandler, CallbackQueryHandler,
MessageHandler, Filters, InlineQueryHandler, CallbackContext) MessageHandler, Filters, InlineQueryHandler, CallbackContext,
DispatcherHandlerStop, TypeHandler)
@pytest.fixture(scope='class') @pytest.fixture(scope='class')
@ -37,6 +38,17 @@ def user2():
return User(first_name='Mister Test', id=124, is_bot=False) return User(first_name='Mister Test', id=124, is_bot=False)
def raise_dphs(func):
def decorator(self, *args, **kwargs):
result = func(self, *args, **kwargs)
if self.raise_dp_handler_stop:
raise DispatcherHandlerStop(result)
else:
return result
return decorator
class TestConversationHandler: class TestConversationHandler:
# State definitions # State definitions
# At first we're thirsty. Then we brew coffee, we drink it # At first we're thirsty. Then we brew coffee, we drink it
@ -51,9 +63,14 @@ class TestConversationHandler:
group = Chat(0, Chat.GROUP) group = Chat(0, Chat.GROUP)
second_group = Chat(1, Chat.GROUP) second_group = Chat(1, Chat.GROUP)
raise_dp_handler_stop = False
test_flag = False
# Test related # Test related
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset(self): def reset(self):
self.raise_dp_handler_stop = False
self.test_flag = False
self.current_state = dict() self.current_state = dict()
self.entry_points = [CommandHandler('start', self.start)] self.entry_points = [CommandHandler('start', self.start)]
self.states = { self.states = {
@ -116,65 +133,81 @@ class TestConversationHandler:
return state return state
# Actions # Actions
@raise_dphs
def start(self, bot, update): def start(self, bot, update):
if isinstance(update, Update): if isinstance(update, Update):
return self._set_state(update, self.THIRSTY) return self._set_state(update, self.THIRSTY)
else: else:
return self._set_state(bot, self.THIRSTY) return self._set_state(bot, self.THIRSTY)
@raise_dphs
def end(self, bot, update): def end(self, bot, update):
return self._set_state(update, self.END) return self._set_state(update, self.END)
@raise_dphs
def start_end(self, bot, update): def start_end(self, bot, update):
return self._set_state(update, self.END) return self._set_state(update, self.END)
@raise_dphs
def start_none(self, bot, update): def start_none(self, bot, update):
return self._set_state(update, None) return self._set_state(update, None)
@raise_dphs
def brew(self, bot, update): def brew(self, bot, update):
if isinstance(update, Update): if isinstance(update, Update):
return self._set_state(update, self.BREWING) return self._set_state(update, self.BREWING)
else: else:
return self._set_state(bot, self.BREWING) return self._set_state(bot, self.BREWING)
@raise_dphs
def drink(self, bot, update): def drink(self, bot, update):
return self._set_state(update, self.DRINKING) return self._set_state(update, self.DRINKING)
@raise_dphs
def code(self, bot, update): def code(self, bot, update):
return self._set_state(update, self.CODING) return self._set_state(update, self.CODING)
@raise_dphs
def passout(self, bot, update): def passout(self, bot, update):
assert update.message.text == '/brew' assert update.message.text == '/brew'
assert isinstance(update, Update) assert isinstance(update, Update)
self.is_timeout = True self.is_timeout = True
@raise_dphs
def passout2(self, bot, update): def passout2(self, bot, update):
assert isinstance(update, Update) assert isinstance(update, Update)
self.is_timeout = True self.is_timeout = True
@raise_dphs
def passout_context(self, update, context): def passout_context(self, update, context):
assert update.message.text == '/brew' assert update.message.text == '/brew'
assert isinstance(context, CallbackContext) assert isinstance(context, CallbackContext)
self.is_timeout = True self.is_timeout = True
@raise_dphs
def passout2_context(self, update, context): def passout2_context(self, update, context):
assert isinstance(context, CallbackContext) assert isinstance(context, CallbackContext)
self.is_timeout = True self.is_timeout = True
# Drinking actions (nested) # Drinking actions (nested)
@raise_dphs
def hold(self, bot, update): def hold(self, bot, update):
return self._set_state(update, self.HOLDING) return self._set_state(update, self.HOLDING)
@raise_dphs
def sip(self, bot, update): def sip(self, bot, update):
return self._set_state(update, self.SIPPING) return self._set_state(update, self.SIPPING)
@raise_dphs
def swallow(self, bot, update): def swallow(self, bot, update):
return self._set_state(update, self.SWALLOWING) return self._set_state(update, self.SWALLOWING)
@raise_dphs
def replenish(self, bot, update): def replenish(self, bot, update):
return self._set_state(update, self.REPLENISHING) return self._set_state(update, self.REPLENISHING)
@raise_dphs
def stop(self, bot, update): def stop(self, bot, update):
return self._set_state(update, self.STOPPING) return self._set_state(update, self.STOPPING)
@ -546,6 +579,32 @@ class TestConversationHandler:
dp.job_queue.tick() dp.job_queue.tick()
assert handler.conversations.get((self.group.id, user1.id)) is None assert handler.conversations.get((self.group.id, user1.id)) is None
def test_conversation_timeout_dispatcher_handler_stop(self, dp, bot, user1, caplog):
handler = ConversationHandler(entry_points=self.entry_points, states=self.states,
fallbacks=self.fallbacks, conversation_timeout=0.5)
def timeout(*args, **kwargs):
raise DispatcherHandlerStop()
self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]})
dp.add_handler(handler)
# Start state machine, then reach timeout
message = Message(0, user1, None, self.group, text='/start',
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND,
offset=0, length=len('/start'))],
bot=bot)
with caplog.at_level(logging.WARNING):
dp.process_update(Update(update_id=0, message=message))
assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY
sleep(0.5)
dp.job_queue.tick()
assert handler.conversations.get((self.group.id, user1.id)) is None
assert len(caplog.records) == 1
rec = caplog.records[-1]
assert rec.msg.startswith('DispatcherHandlerStop in TIMEOUT')
def test_conversation_handler_timeout_update_and_context(self, cdp, bot, user1): def test_conversation_handler_timeout_update_and_context(self, cdp, bot, user1):
context = None context = None
@ -953,3 +1012,129 @@ class TestConversationHandler:
dp.process_update(Update(update_id=0, message=message)) dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.STOPPING assert self.current_state[user1.id] == self.STOPPING
assert handler.conversations.get((0, user1.id)) is None assert handler.conversations.get((0, user1.id)) is None
def test_conversation_dispatcher_handler_stop(self, dp, bot, user1, user2):
self.nested_states[self.DRINKING] = [ConversationHandler(
entry_points=self.drinking_entry_points,
states=self.drinking_states,
fallbacks=self.drinking_fallbacks,
map_to_parent=self.drinking_map_to_parent)]
handler = ConversationHandler(entry_points=self.entry_points,
states=self.nested_states,
fallbacks=self.fallbacks)
def test_callback(u, c):
self.test_flag = True
dp.add_handler(handler)
dp.add_handler(TypeHandler(Update, test_callback), group=1)
self.raise_dp_handler_stop = True
# User one, starts the state machine.
message = Message(0, user1, None, self.group, text='/start', bot=bot,
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND,
offset=0, length=len('/start'))])
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.THIRSTY
assert not self.test_flag
# The user is thirsty and wants to brew coffee.
message.text = '/brew'
message.entities[0].length = len('/brew')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.BREWING
assert not self.test_flag
# Lets pour some coffee.
message.text = '/pourCoffee'
message.entities[0].length = len('/pourCoffee')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.DRINKING
assert not self.test_flag
# The user is holding the cup
message.text = '/hold'
message.entities[0].length = len('/hold')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.HOLDING
assert not self.test_flag
# The user is sipping coffee
message.text = '/sip'
message.entities[0].length = len('/sip')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.SIPPING
assert not self.test_flag
# The user is swallowing
message.text = '/swallow'
message.entities[0].length = len('/swallow')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.SWALLOWING
assert not self.test_flag
# The user is holding the cup again
message.text = '/hold'
message.entities[0].length = len('/hold')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.HOLDING
assert not self.test_flag
# The user wants to replenish the coffee supply
message.text = '/replenish'
message.entities[0].length = len('/replenish')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.REPLENISHING
assert handler.conversations[(0, user1.id)] == self.BREWING
assert not self.test_flag
# The user wants to drink their coffee again
message.text = '/pourCoffee'
message.entities[0].length = len('/pourCoffee')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.DRINKING
assert not self.test_flag
# The user is now ready to start coding
message.text = '/startCoding'
message.entities[0].length = len('/startCoding')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.CODING
assert not self.test_flag
# The user decides it's time to drink again
message.text = '/drinkMore'
message.entities[0].length = len('/drinkMore')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.DRINKING
assert not self.test_flag
# The user is holding their cup
message.text = '/hold'
message.entities[0].length = len('/hold')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.HOLDING
assert not self.test_flag
# The user wants to end with the drinking and go back to coding
message.text = '/end'
message.entities[0].length = len('/end')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.END
assert handler.conversations[(0, user1.id)] == self.CODING
assert not self.test_flag
# The user wants to drink once more
message.text = '/drinkMore'
message.entities[0].length = len('/drinkMore')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.DRINKING
assert not self.test_flag
# The user wants to stop altogether
message.text = '/stop'
message.entities[0].length = len('/stop')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.STOPPING
assert handler.conversations.get((0, user1.id)) is None
assert not self.test_flag