mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-11-23 07:38:58 +01:00
Allow DispatcherHandlerStop in ConversationHandler (#2059)
* First go * Fix bug with nested convs
This commit is contained in:
parent
3304cc5c90
commit
da452df07d
3 changed files with 228 additions and 7 deletions
|
@ -24,7 +24,7 @@ from threading import Lock
|
|||
|
||||
from telegram import Update
|
||||
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
|
||||
ChosenInlineResultHandler, CallbackContext)
|
||||
ChosenInlineResultHandler, CallbackContext, DispatcherHandlerStop)
|
||||
from telegram.utils.promise import Promise
|
||||
|
||||
|
||||
|
@ -454,6 +454,7 @@ class ConversationHandler(Handler):
|
|||
|
||||
"""
|
||||
conversation_key, handler, check_result = check_result
|
||||
raise_dp_handler_stop = False
|
||||
|
||||
with self._timeout_jobs_lock:
|
||||
# Remove the old timeout job (if present)
|
||||
|
@ -462,7 +463,11 @@ class ConversationHandler(Handler):
|
|||
if timeout_job is not None:
|
||||
timeout_job.schedule_removal()
|
||||
|
||||
new_state = handler.handle_update(update, dispatcher, check_result, context)
|
||||
try:
|
||||
new_state = handler.handle_update(update, dispatcher, check_result, context)
|
||||
except DispatcherHandlerStop as e:
|
||||
new_state = e.state
|
||||
raise_dp_handler_stop = True
|
||||
|
||||
with self._timeout_jobs_lock:
|
||||
if self.conversation_timeout and new_state != self.END:
|
||||
|
@ -474,9 +479,16 @@ class ConversationHandler(Handler):
|
|||
|
||||
if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent:
|
||||
self.update_state(self.END, conversation_key)
|
||||
return self.map_to_parent.get(new_state)
|
||||
if raise_dp_handler_stop:
|
||||
raise DispatcherHandlerStop(self.map_to_parent.get(new_state))
|
||||
else:
|
||||
return self.map_to_parent.get(new_state)
|
||||
else:
|
||||
self.update_state(new_state, conversation_key)
|
||||
if raise_dp_handler_stop:
|
||||
# Don't pass the new state here. If we're in a nested conversation, the parent is
|
||||
# expecting None as return value.
|
||||
raise DispatcherHandlerStop()
|
||||
|
||||
def update_state(self, new_state, key):
|
||||
if new_state == self.END:
|
||||
|
@ -522,5 +534,10 @@ class ConversationHandler(Handler):
|
|||
for handler in handlers:
|
||||
check = handler.check_update(context.update)
|
||||
if check is not None and check is not False:
|
||||
handler.handle_update(context.update, context.dispatcher, check, callback_context)
|
||||
try:
|
||||
handler.handle_update(context.update, context.dispatcher, check,
|
||||
callback_context)
|
||||
except DispatcherHandlerStop:
|
||||
self.logger.warning('DispatcherHandlerStop in TIMEOUT state of '
|
||||
'ConversationHandler has no effect. Ignoring.')
|
||||
self.update_state(self.END, context.conversation_key)
|
||||
|
|
|
@ -60,8 +60,27 @@ def run_async(func):
|
|||
|
||||
|
||||
class DispatcherHandlerStop(Exception):
|
||||
"""Raise this in handler to prevent execution any other handler (even in different group)."""
|
||||
pass
|
||||
"""
|
||||
Raise this in handler to prevent execution any other handler (even in different group).
|
||||
|
||||
In order to use this exception in a :class:`telegram.ext.ConversationHandler`, pass the
|
||||
optional ``state`` parameter instead of returning the next state:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def callback(update, context):
|
||||
...
|
||||
raise DispatcherHandlerStop(next_state)
|
||||
|
||||
Attributes:
|
||||
state (:obj:`object`): Optional. The next state of the conversation.
|
||||
|
||||
Args:
|
||||
state (:obj:`object`, optional): The next state of the conversation.
|
||||
"""
|
||||
def __init__(self, state=None):
|
||||
super().__init__()
|
||||
self.state = state
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
|
|
|
@ -24,7 +24,8 @@ import pytest
|
|||
from telegram import (CallbackQuery, Chat, ChosenInlineResult, InlineQuery, Message,
|
||||
PreCheckoutQuery, ShippingQuery, Update, User, MessageEntity)
|
||||
from telegram.ext import (ConversationHandler, CommandHandler, CallbackQueryHandler,
|
||||
MessageHandler, Filters, InlineQueryHandler, CallbackContext)
|
||||
MessageHandler, Filters, InlineQueryHandler, CallbackContext,
|
||||
DispatcherHandlerStop, TypeHandler)
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
|
@ -37,6 +38,17 @@ def user2():
|
|||
return User(first_name='Mister Test', id=124, is_bot=False)
|
||||
|
||||
|
||||
def raise_dphs(func):
|
||||
def decorator(self, *args, **kwargs):
|
||||
result = func(self, *args, **kwargs)
|
||||
if self.raise_dp_handler_stop:
|
||||
raise DispatcherHandlerStop(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class TestConversationHandler:
|
||||
# State definitions
|
||||
# At first we're thirsty. Then we brew coffee, we drink it
|
||||
|
@ -51,9 +63,14 @@ class TestConversationHandler:
|
|||
group = Chat(0, Chat.GROUP)
|
||||
second_group = Chat(1, Chat.GROUP)
|
||||
|
||||
raise_dp_handler_stop = False
|
||||
test_flag = False
|
||||
|
||||
# Test related
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset(self):
|
||||
self.raise_dp_handler_stop = False
|
||||
self.test_flag = False
|
||||
self.current_state = dict()
|
||||
self.entry_points = [CommandHandler('start', self.start)]
|
||||
self.states = {
|
||||
|
@ -116,65 +133,81 @@ class TestConversationHandler:
|
|||
return state
|
||||
|
||||
# Actions
|
||||
@raise_dphs
|
||||
def start(self, bot, update):
|
||||
if isinstance(update, Update):
|
||||
return self._set_state(update, self.THIRSTY)
|
||||
else:
|
||||
return self._set_state(bot, self.THIRSTY)
|
||||
|
||||
@raise_dphs
|
||||
def end(self, bot, update):
|
||||
return self._set_state(update, self.END)
|
||||
|
||||
@raise_dphs
|
||||
def start_end(self, bot, update):
|
||||
return self._set_state(update, self.END)
|
||||
|
||||
@raise_dphs
|
||||
def start_none(self, bot, update):
|
||||
return self._set_state(update, None)
|
||||
|
||||
@raise_dphs
|
||||
def brew(self, bot, update):
|
||||
if isinstance(update, Update):
|
||||
return self._set_state(update, self.BREWING)
|
||||
else:
|
||||
return self._set_state(bot, self.BREWING)
|
||||
|
||||
@raise_dphs
|
||||
def drink(self, bot, update):
|
||||
return self._set_state(update, self.DRINKING)
|
||||
|
||||
@raise_dphs
|
||||
def code(self, bot, update):
|
||||
return self._set_state(update, self.CODING)
|
||||
|
||||
@raise_dphs
|
||||
def passout(self, bot, update):
|
||||
assert update.message.text == '/brew'
|
||||
assert isinstance(update, Update)
|
||||
self.is_timeout = True
|
||||
|
||||
@raise_dphs
|
||||
def passout2(self, bot, update):
|
||||
assert isinstance(update, Update)
|
||||
self.is_timeout = True
|
||||
|
||||
@raise_dphs
|
||||
def passout_context(self, update, context):
|
||||
assert update.message.text == '/brew'
|
||||
assert isinstance(context, CallbackContext)
|
||||
self.is_timeout = True
|
||||
|
||||
@raise_dphs
|
||||
def passout2_context(self, update, context):
|
||||
assert isinstance(context, CallbackContext)
|
||||
self.is_timeout = True
|
||||
|
||||
# Drinking actions (nested)
|
||||
|
||||
@raise_dphs
|
||||
def hold(self, bot, update):
|
||||
return self._set_state(update, self.HOLDING)
|
||||
|
||||
@raise_dphs
|
||||
def sip(self, bot, update):
|
||||
return self._set_state(update, self.SIPPING)
|
||||
|
||||
@raise_dphs
|
||||
def swallow(self, bot, update):
|
||||
return self._set_state(update, self.SWALLOWING)
|
||||
|
||||
@raise_dphs
|
||||
def replenish(self, bot, update):
|
||||
return self._set_state(update, self.REPLENISHING)
|
||||
|
||||
@raise_dphs
|
||||
def stop(self, bot, update):
|
||||
return self._set_state(update, self.STOPPING)
|
||||
|
||||
|
@ -546,6 +579,32 @@ class TestConversationHandler:
|
|||
dp.job_queue.tick()
|
||||
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||
|
||||
def test_conversation_timeout_dispatcher_handler_stop(self, dp, bot, user1, caplog):
|
||||
handler = ConversationHandler(entry_points=self.entry_points, states=self.states,
|
||||
fallbacks=self.fallbacks, conversation_timeout=0.5)
|
||||
|
||||
def timeout(*args, **kwargs):
|
||||
raise DispatcherHandlerStop()
|
||||
|
||||
self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]})
|
||||
dp.add_handler(handler)
|
||||
|
||||
# Start state machine, then reach timeout
|
||||
message = Message(0, user1, None, self.group, text='/start',
|
||||
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND,
|
||||
offset=0, length=len('/start'))],
|
||||
bot=bot)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY
|
||||
sleep(0.5)
|
||||
dp.job_queue.tick()
|
||||
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||
assert len(caplog.records) == 1
|
||||
rec = caplog.records[-1]
|
||||
assert rec.msg.startswith('DispatcherHandlerStop in TIMEOUT')
|
||||
|
||||
def test_conversation_handler_timeout_update_and_context(self, cdp, bot, user1):
|
||||
context = None
|
||||
|
||||
|
@ -953,3 +1012,129 @@ class TestConversationHandler:
|
|||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.STOPPING
|
||||
assert handler.conversations.get((0, user1.id)) is None
|
||||
|
||||
def test_conversation_dispatcher_handler_stop(self, dp, bot, user1, user2):
|
||||
self.nested_states[self.DRINKING] = [ConversationHandler(
|
||||
entry_points=self.drinking_entry_points,
|
||||
states=self.drinking_states,
|
||||
fallbacks=self.drinking_fallbacks,
|
||||
map_to_parent=self.drinking_map_to_parent)]
|
||||
handler = ConversationHandler(entry_points=self.entry_points,
|
||||
states=self.nested_states,
|
||||
fallbacks=self.fallbacks)
|
||||
|
||||
def test_callback(u, c):
|
||||
self.test_flag = True
|
||||
|
||||
dp.add_handler(handler)
|
||||
dp.add_handler(TypeHandler(Update, test_callback), group=1)
|
||||
self.raise_dp_handler_stop = True
|
||||
|
||||
# User one, starts the state machine.
|
||||
message = Message(0, user1, None, self.group, text='/start', bot=bot,
|
||||
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND,
|
||||
offset=0, length=len('/start'))])
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.THIRSTY
|
||||
assert not self.test_flag
|
||||
|
||||
# The user is thirsty and wants to brew coffee.
|
||||
message.text = '/brew'
|
||||
message.entities[0].length = len('/brew')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.BREWING
|
||||
assert not self.test_flag
|
||||
|
||||
# Lets pour some coffee.
|
||||
message.text = '/pourCoffee'
|
||||
message.entities[0].length = len('/pourCoffee')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.DRINKING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user is holding the cup
|
||||
message.text = '/hold'
|
||||
message.entities[0].length = len('/hold')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.HOLDING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user is sipping coffee
|
||||
message.text = '/sip'
|
||||
message.entities[0].length = len('/sip')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.SIPPING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user is swallowing
|
||||
message.text = '/swallow'
|
||||
message.entities[0].length = len('/swallow')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.SWALLOWING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user is holding the cup again
|
||||
message.text = '/hold'
|
||||
message.entities[0].length = len('/hold')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.HOLDING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user wants to replenish the coffee supply
|
||||
message.text = '/replenish'
|
||||
message.entities[0].length = len('/replenish')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.REPLENISHING
|
||||
assert handler.conversations[(0, user1.id)] == self.BREWING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user wants to drink their coffee again
|
||||
message.text = '/pourCoffee'
|
||||
message.entities[0].length = len('/pourCoffee')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.DRINKING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user is now ready to start coding
|
||||
message.text = '/startCoding'
|
||||
message.entities[0].length = len('/startCoding')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.CODING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user decides it's time to drink again
|
||||
message.text = '/drinkMore'
|
||||
message.entities[0].length = len('/drinkMore')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.DRINKING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user is holding their cup
|
||||
message.text = '/hold'
|
||||
message.entities[0].length = len('/hold')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.HOLDING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user wants to end with the drinking and go back to coding
|
||||
message.text = '/end'
|
||||
message.entities[0].length = len('/end')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.END
|
||||
assert handler.conversations[(0, user1.id)] == self.CODING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user wants to drink once more
|
||||
message.text = '/drinkMore'
|
||||
message.entities[0].length = len('/drinkMore')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.DRINKING
|
||||
assert not self.test_flag
|
||||
|
||||
# The user wants to stop altogether
|
||||
message.text = '/stop'
|
||||
message.entities[0].length = len('/stop')
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
assert self.current_state[user1.id] == self.STOPPING
|
||||
assert handler.conversations.get((0, user1.id)) is None
|
||||
assert not self.test_flag
|
||||
|
|
Loading…
Reference in a new issue