Allow DispatcherHandlerStop in ConversationHandler (#2059)

* First go

* Fix bug with nested convs
This commit is contained in:
Bibo-Joshi 2020-08-21 23:20:28 +02:00 committed by GitHub
parent 3304cc5c90
commit da452df07d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 228 additions and 7 deletions

View file

@ -24,7 +24,7 @@ from threading import Lock
from telegram import Update
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
ChosenInlineResultHandler, CallbackContext)
ChosenInlineResultHandler, CallbackContext, DispatcherHandlerStop)
from telegram.utils.promise import Promise
@ -454,6 +454,7 @@ class ConversationHandler(Handler):
"""
conversation_key, handler, check_result = check_result
raise_dp_handler_stop = False
with self._timeout_jobs_lock:
# Remove the old timeout job (if present)
@ -462,7 +463,11 @@ class ConversationHandler(Handler):
if timeout_job is not None:
timeout_job.schedule_removal()
new_state = handler.handle_update(update, dispatcher, check_result, context)
try:
new_state = handler.handle_update(update, dispatcher, check_result, context)
except DispatcherHandlerStop as e:
new_state = e.state
raise_dp_handler_stop = True
with self._timeout_jobs_lock:
if self.conversation_timeout and new_state != self.END:
@ -474,9 +479,16 @@ class ConversationHandler(Handler):
if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent:
self.update_state(self.END, conversation_key)
return self.map_to_parent.get(new_state)
if raise_dp_handler_stop:
raise DispatcherHandlerStop(self.map_to_parent.get(new_state))
else:
return self.map_to_parent.get(new_state)
else:
self.update_state(new_state, conversation_key)
if raise_dp_handler_stop:
# Don't pass the new state here. If we're in a nested conversation, the parent is
# expecting None as return value.
raise DispatcherHandlerStop()
def update_state(self, new_state, key):
if new_state == self.END:
@ -522,5 +534,10 @@ class ConversationHandler(Handler):
for handler in handlers:
check = handler.check_update(context.update)
if check is not None and check is not False:
handler.handle_update(context.update, context.dispatcher, check, callback_context)
try:
handler.handle_update(context.update, context.dispatcher, check,
callback_context)
except DispatcherHandlerStop:
self.logger.warning('DispatcherHandlerStop in TIMEOUT state of '
'ConversationHandler has no effect. Ignoring.')
self.update_state(self.END, context.conversation_key)

View file

@ -60,8 +60,27 @@ def run_async(func):
class DispatcherHandlerStop(Exception):
"""Raise this in handler to prevent execution any other handler (even in different group)."""
pass
"""
Raise this in handler to prevent execution any other handler (even in different group).
In order to use this exception in a :class:`telegram.ext.ConversationHandler`, pass the
optional ``state`` parameter instead of returning the next state:
.. code-block:: python
def callback(update, context):
...
raise DispatcherHandlerStop(next_state)
Attributes:
state (:obj:`object`): Optional. The next state of the conversation.
Args:
state (:obj:`object`, optional): The next state of the conversation.
"""
def __init__(self, state=None):
super().__init__()
self.state = state
class Dispatcher:

View file

@ -24,7 +24,8 @@ import pytest
from telegram import (CallbackQuery, Chat, ChosenInlineResult, InlineQuery, Message,
PreCheckoutQuery, ShippingQuery, Update, User, MessageEntity)
from telegram.ext import (ConversationHandler, CommandHandler, CallbackQueryHandler,
MessageHandler, Filters, InlineQueryHandler, CallbackContext)
MessageHandler, Filters, InlineQueryHandler, CallbackContext,
DispatcherHandlerStop, TypeHandler)
@pytest.fixture(scope='class')
@ -37,6 +38,17 @@ def user2():
return User(first_name='Mister Test', id=124, is_bot=False)
def raise_dphs(func):
def decorator(self, *args, **kwargs):
result = func(self, *args, **kwargs)
if self.raise_dp_handler_stop:
raise DispatcherHandlerStop(result)
else:
return result
return decorator
class TestConversationHandler:
# State definitions
# At first we're thirsty. Then we brew coffee, we drink it
@ -51,9 +63,14 @@ class TestConversationHandler:
group = Chat(0, Chat.GROUP)
second_group = Chat(1, Chat.GROUP)
raise_dp_handler_stop = False
test_flag = False
# Test related
@pytest.fixture(autouse=True)
def reset(self):
self.raise_dp_handler_stop = False
self.test_flag = False
self.current_state = dict()
self.entry_points = [CommandHandler('start', self.start)]
self.states = {
@ -116,65 +133,81 @@ class TestConversationHandler:
return state
# Actions
@raise_dphs
def start(self, bot, update):
if isinstance(update, Update):
return self._set_state(update, self.THIRSTY)
else:
return self._set_state(bot, self.THIRSTY)
@raise_dphs
def end(self, bot, update):
return self._set_state(update, self.END)
@raise_dphs
def start_end(self, bot, update):
return self._set_state(update, self.END)
@raise_dphs
def start_none(self, bot, update):
return self._set_state(update, None)
@raise_dphs
def brew(self, bot, update):
if isinstance(update, Update):
return self._set_state(update, self.BREWING)
else:
return self._set_state(bot, self.BREWING)
@raise_dphs
def drink(self, bot, update):
return self._set_state(update, self.DRINKING)
@raise_dphs
def code(self, bot, update):
return self._set_state(update, self.CODING)
@raise_dphs
def passout(self, bot, update):
assert update.message.text == '/brew'
assert isinstance(update, Update)
self.is_timeout = True
@raise_dphs
def passout2(self, bot, update):
assert isinstance(update, Update)
self.is_timeout = True
@raise_dphs
def passout_context(self, update, context):
assert update.message.text == '/brew'
assert isinstance(context, CallbackContext)
self.is_timeout = True
@raise_dphs
def passout2_context(self, update, context):
assert isinstance(context, CallbackContext)
self.is_timeout = True
# Drinking actions (nested)
@raise_dphs
def hold(self, bot, update):
return self._set_state(update, self.HOLDING)
@raise_dphs
def sip(self, bot, update):
return self._set_state(update, self.SIPPING)
@raise_dphs
def swallow(self, bot, update):
return self._set_state(update, self.SWALLOWING)
@raise_dphs
def replenish(self, bot, update):
return self._set_state(update, self.REPLENISHING)
@raise_dphs
def stop(self, bot, update):
return self._set_state(update, self.STOPPING)
@ -546,6 +579,32 @@ class TestConversationHandler:
dp.job_queue.tick()
assert handler.conversations.get((self.group.id, user1.id)) is None
def test_conversation_timeout_dispatcher_handler_stop(self, dp, bot, user1, caplog):
handler = ConversationHandler(entry_points=self.entry_points, states=self.states,
fallbacks=self.fallbacks, conversation_timeout=0.5)
def timeout(*args, **kwargs):
raise DispatcherHandlerStop()
self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]})
dp.add_handler(handler)
# Start state machine, then reach timeout
message = Message(0, user1, None, self.group, text='/start',
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND,
offset=0, length=len('/start'))],
bot=bot)
with caplog.at_level(logging.WARNING):
dp.process_update(Update(update_id=0, message=message))
assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY
sleep(0.5)
dp.job_queue.tick()
assert handler.conversations.get((self.group.id, user1.id)) is None
assert len(caplog.records) == 1
rec = caplog.records[-1]
assert rec.msg.startswith('DispatcherHandlerStop in TIMEOUT')
def test_conversation_handler_timeout_update_and_context(self, cdp, bot, user1):
context = None
@ -953,3 +1012,129 @@ class TestConversationHandler:
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.STOPPING
assert handler.conversations.get((0, user1.id)) is None
def test_conversation_dispatcher_handler_stop(self, dp, bot, user1, user2):
self.nested_states[self.DRINKING] = [ConversationHandler(
entry_points=self.drinking_entry_points,
states=self.drinking_states,
fallbacks=self.drinking_fallbacks,
map_to_parent=self.drinking_map_to_parent)]
handler = ConversationHandler(entry_points=self.entry_points,
states=self.nested_states,
fallbacks=self.fallbacks)
def test_callback(u, c):
self.test_flag = True
dp.add_handler(handler)
dp.add_handler(TypeHandler(Update, test_callback), group=1)
self.raise_dp_handler_stop = True
# User one, starts the state machine.
message = Message(0, user1, None, self.group, text='/start', bot=bot,
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND,
offset=0, length=len('/start'))])
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.THIRSTY
assert not self.test_flag
# The user is thirsty and wants to brew coffee.
message.text = '/brew'
message.entities[0].length = len('/brew')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.BREWING
assert not self.test_flag
# Lets pour some coffee.
message.text = '/pourCoffee'
message.entities[0].length = len('/pourCoffee')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.DRINKING
assert not self.test_flag
# The user is holding the cup
message.text = '/hold'
message.entities[0].length = len('/hold')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.HOLDING
assert not self.test_flag
# The user is sipping coffee
message.text = '/sip'
message.entities[0].length = len('/sip')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.SIPPING
assert not self.test_flag
# The user is swallowing
message.text = '/swallow'
message.entities[0].length = len('/swallow')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.SWALLOWING
assert not self.test_flag
# The user is holding the cup again
message.text = '/hold'
message.entities[0].length = len('/hold')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.HOLDING
assert not self.test_flag
# The user wants to replenish the coffee supply
message.text = '/replenish'
message.entities[0].length = len('/replenish')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.REPLENISHING
assert handler.conversations[(0, user1.id)] == self.BREWING
assert not self.test_flag
# The user wants to drink their coffee again
message.text = '/pourCoffee'
message.entities[0].length = len('/pourCoffee')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.DRINKING
assert not self.test_flag
# The user is now ready to start coding
message.text = '/startCoding'
message.entities[0].length = len('/startCoding')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.CODING
assert not self.test_flag
# The user decides it's time to drink again
message.text = '/drinkMore'
message.entities[0].length = len('/drinkMore')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.DRINKING
assert not self.test_flag
# The user is holding their cup
message.text = '/hold'
message.entities[0].length = len('/hold')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.HOLDING
assert not self.test_flag
# The user wants to end with the drinking and go back to coding
message.text = '/end'
message.entities[0].length = len('/end')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.END
assert handler.conversations[(0, user1.id)] == self.CODING
assert not self.test_flag
# The user wants to drink once more
message.text = '/drinkMore'
message.entities[0].length = len('/drinkMore')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.DRINKING
assert not self.test_flag
# The user wants to stop altogether
message.text = '/stop'
message.entities[0].length = len('/stop')
dp.process_update(Update(update_id=0, message=message))
assert self.current_state[user1.id] == self.STOPPING
assert handler.conversations.get((0, user1.id)) is None
assert not self.test_flag