mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-10-23 17:36:26 +02:00
Allow DispatcherHandlerStop in ConversationHandler (#2059)
* First go * Fix bug with nested convs
This commit is contained in:
parent
3304cc5c90
commit
da452df07d
3 changed files with 228 additions and 7 deletions
|
@ -24,7 +24,7 @@ from threading import Lock
|
||||||
|
|
||||||
from telegram import Update
|
from telegram import Update
|
||||||
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
|
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
|
||||||
ChosenInlineResultHandler, CallbackContext)
|
ChosenInlineResultHandler, CallbackContext, DispatcherHandlerStop)
|
||||||
from telegram.utils.promise import Promise
|
from telegram.utils.promise import Promise
|
||||||
|
|
||||||
|
|
||||||
|
@ -454,6 +454,7 @@ class ConversationHandler(Handler):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
conversation_key, handler, check_result = check_result
|
conversation_key, handler, check_result = check_result
|
||||||
|
raise_dp_handler_stop = False
|
||||||
|
|
||||||
with self._timeout_jobs_lock:
|
with self._timeout_jobs_lock:
|
||||||
# Remove the old timeout job (if present)
|
# Remove the old timeout job (if present)
|
||||||
|
@ -462,7 +463,11 @@ class ConversationHandler(Handler):
|
||||||
if timeout_job is not None:
|
if timeout_job is not None:
|
||||||
timeout_job.schedule_removal()
|
timeout_job.schedule_removal()
|
||||||
|
|
||||||
new_state = handler.handle_update(update, dispatcher, check_result, context)
|
try:
|
||||||
|
new_state = handler.handle_update(update, dispatcher, check_result, context)
|
||||||
|
except DispatcherHandlerStop as e:
|
||||||
|
new_state = e.state
|
||||||
|
raise_dp_handler_stop = True
|
||||||
|
|
||||||
with self._timeout_jobs_lock:
|
with self._timeout_jobs_lock:
|
||||||
if self.conversation_timeout and new_state != self.END:
|
if self.conversation_timeout and new_state != self.END:
|
||||||
|
@ -474,9 +479,16 @@ class ConversationHandler(Handler):
|
||||||
|
|
||||||
if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent:
|
if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent:
|
||||||
self.update_state(self.END, conversation_key)
|
self.update_state(self.END, conversation_key)
|
||||||
return self.map_to_parent.get(new_state)
|
if raise_dp_handler_stop:
|
||||||
|
raise DispatcherHandlerStop(self.map_to_parent.get(new_state))
|
||||||
|
else:
|
||||||
|
return self.map_to_parent.get(new_state)
|
||||||
else:
|
else:
|
||||||
self.update_state(new_state, conversation_key)
|
self.update_state(new_state, conversation_key)
|
||||||
|
if raise_dp_handler_stop:
|
||||||
|
# Don't pass the new state here. If we're in a nested conversation, the parent is
|
||||||
|
# expecting None as return value.
|
||||||
|
raise DispatcherHandlerStop()
|
||||||
|
|
||||||
def update_state(self, new_state, key):
|
def update_state(self, new_state, key):
|
||||||
if new_state == self.END:
|
if new_state == self.END:
|
||||||
|
@ -522,5 +534,10 @@ class ConversationHandler(Handler):
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
check = handler.check_update(context.update)
|
check = handler.check_update(context.update)
|
||||||
if check is not None and check is not False:
|
if check is not None and check is not False:
|
||||||
handler.handle_update(context.update, context.dispatcher, check, callback_context)
|
try:
|
||||||
|
handler.handle_update(context.update, context.dispatcher, check,
|
||||||
|
callback_context)
|
||||||
|
except DispatcherHandlerStop:
|
||||||
|
self.logger.warning('DispatcherHandlerStop in TIMEOUT state of '
|
||||||
|
'ConversationHandler has no effect. Ignoring.')
|
||||||
self.update_state(self.END, context.conversation_key)
|
self.update_state(self.END, context.conversation_key)
|
||||||
|
|
|
@ -60,8 +60,27 @@ def run_async(func):
|
||||||
|
|
||||||
|
|
||||||
class DispatcherHandlerStop(Exception):
|
class DispatcherHandlerStop(Exception):
|
||||||
"""Raise this in handler to prevent execution any other handler (even in different group)."""
|
"""
|
||||||
pass
|
Raise this in handler to prevent execution any other handler (even in different group).
|
||||||
|
|
||||||
|
In order to use this exception in a :class:`telegram.ext.ConversationHandler`, pass the
|
||||||
|
optional ``state`` parameter instead of returning the next state:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def callback(update, context):
|
||||||
|
...
|
||||||
|
raise DispatcherHandlerStop(next_state)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
state (:obj:`object`): Optional. The next state of the conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (:obj:`object`, optional): The next state of the conversation.
|
||||||
|
"""
|
||||||
|
def __init__(self, state=None):
|
||||||
|
super().__init__()
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
|
|
|
@ -24,7 +24,8 @@ import pytest
|
||||||
from telegram import (CallbackQuery, Chat, ChosenInlineResult, InlineQuery, Message,
|
from telegram import (CallbackQuery, Chat, ChosenInlineResult, InlineQuery, Message,
|
||||||
PreCheckoutQuery, ShippingQuery, Update, User, MessageEntity)
|
PreCheckoutQuery, ShippingQuery, Update, User, MessageEntity)
|
||||||
from telegram.ext import (ConversationHandler, CommandHandler, CallbackQueryHandler,
|
from telegram.ext import (ConversationHandler, CommandHandler, CallbackQueryHandler,
|
||||||
MessageHandler, Filters, InlineQueryHandler, CallbackContext)
|
MessageHandler, Filters, InlineQueryHandler, CallbackContext,
|
||||||
|
DispatcherHandlerStop, TypeHandler)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='class')
|
@pytest.fixture(scope='class')
|
||||||
|
@ -37,6 +38,17 @@ def user2():
|
||||||
return User(first_name='Mister Test', id=124, is_bot=False)
|
return User(first_name='Mister Test', id=124, is_bot=False)
|
||||||
|
|
||||||
|
|
||||||
|
def raise_dphs(func):
|
||||||
|
def decorator(self, *args, **kwargs):
|
||||||
|
result = func(self, *args, **kwargs)
|
||||||
|
if self.raise_dp_handler_stop:
|
||||||
|
raise DispatcherHandlerStop(result)
|
||||||
|
else:
|
||||||
|
return result
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class TestConversationHandler:
|
class TestConversationHandler:
|
||||||
# State definitions
|
# State definitions
|
||||||
# At first we're thirsty. Then we brew coffee, we drink it
|
# At first we're thirsty. Then we brew coffee, we drink it
|
||||||
|
@ -51,9 +63,14 @@ class TestConversationHandler:
|
||||||
group = Chat(0, Chat.GROUP)
|
group = Chat(0, Chat.GROUP)
|
||||||
second_group = Chat(1, Chat.GROUP)
|
second_group = Chat(1, Chat.GROUP)
|
||||||
|
|
||||||
|
raise_dp_handler_stop = False
|
||||||
|
test_flag = False
|
||||||
|
|
||||||
# Test related
|
# Test related
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
self.raise_dp_handler_stop = False
|
||||||
|
self.test_flag = False
|
||||||
self.current_state = dict()
|
self.current_state = dict()
|
||||||
self.entry_points = [CommandHandler('start', self.start)]
|
self.entry_points = [CommandHandler('start', self.start)]
|
||||||
self.states = {
|
self.states = {
|
||||||
|
@ -116,65 +133,81 @@ class TestConversationHandler:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
# Actions
|
# Actions
|
||||||
|
@raise_dphs
|
||||||
def start(self, bot, update):
|
def start(self, bot, update):
|
||||||
if isinstance(update, Update):
|
if isinstance(update, Update):
|
||||||
return self._set_state(update, self.THIRSTY)
|
return self._set_state(update, self.THIRSTY)
|
||||||
else:
|
else:
|
||||||
return self._set_state(bot, self.THIRSTY)
|
return self._set_state(bot, self.THIRSTY)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def end(self, bot, update):
|
def end(self, bot, update):
|
||||||
return self._set_state(update, self.END)
|
return self._set_state(update, self.END)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def start_end(self, bot, update):
|
def start_end(self, bot, update):
|
||||||
return self._set_state(update, self.END)
|
return self._set_state(update, self.END)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def start_none(self, bot, update):
|
def start_none(self, bot, update):
|
||||||
return self._set_state(update, None)
|
return self._set_state(update, None)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def brew(self, bot, update):
|
def brew(self, bot, update):
|
||||||
if isinstance(update, Update):
|
if isinstance(update, Update):
|
||||||
return self._set_state(update, self.BREWING)
|
return self._set_state(update, self.BREWING)
|
||||||
else:
|
else:
|
||||||
return self._set_state(bot, self.BREWING)
|
return self._set_state(bot, self.BREWING)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def drink(self, bot, update):
|
def drink(self, bot, update):
|
||||||
return self._set_state(update, self.DRINKING)
|
return self._set_state(update, self.DRINKING)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def code(self, bot, update):
|
def code(self, bot, update):
|
||||||
return self._set_state(update, self.CODING)
|
return self._set_state(update, self.CODING)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def passout(self, bot, update):
|
def passout(self, bot, update):
|
||||||
assert update.message.text == '/brew'
|
assert update.message.text == '/brew'
|
||||||
assert isinstance(update, Update)
|
assert isinstance(update, Update)
|
||||||
self.is_timeout = True
|
self.is_timeout = True
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def passout2(self, bot, update):
|
def passout2(self, bot, update):
|
||||||
assert isinstance(update, Update)
|
assert isinstance(update, Update)
|
||||||
self.is_timeout = True
|
self.is_timeout = True
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def passout_context(self, update, context):
|
def passout_context(self, update, context):
|
||||||
assert update.message.text == '/brew'
|
assert update.message.text == '/brew'
|
||||||
assert isinstance(context, CallbackContext)
|
assert isinstance(context, CallbackContext)
|
||||||
self.is_timeout = True
|
self.is_timeout = True
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def passout2_context(self, update, context):
|
def passout2_context(self, update, context):
|
||||||
assert isinstance(context, CallbackContext)
|
assert isinstance(context, CallbackContext)
|
||||||
self.is_timeout = True
|
self.is_timeout = True
|
||||||
|
|
||||||
# Drinking actions (nested)
|
# Drinking actions (nested)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def hold(self, bot, update):
|
def hold(self, bot, update):
|
||||||
return self._set_state(update, self.HOLDING)
|
return self._set_state(update, self.HOLDING)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def sip(self, bot, update):
|
def sip(self, bot, update):
|
||||||
return self._set_state(update, self.SIPPING)
|
return self._set_state(update, self.SIPPING)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def swallow(self, bot, update):
|
def swallow(self, bot, update):
|
||||||
return self._set_state(update, self.SWALLOWING)
|
return self._set_state(update, self.SWALLOWING)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def replenish(self, bot, update):
|
def replenish(self, bot, update):
|
||||||
return self._set_state(update, self.REPLENISHING)
|
return self._set_state(update, self.REPLENISHING)
|
||||||
|
|
||||||
|
@raise_dphs
|
||||||
def stop(self, bot, update):
|
def stop(self, bot, update):
|
||||||
return self._set_state(update, self.STOPPING)
|
return self._set_state(update, self.STOPPING)
|
||||||
|
|
||||||
|
@ -546,6 +579,32 @@ class TestConversationHandler:
|
||||||
dp.job_queue.tick()
|
dp.job_queue.tick()
|
||||||
assert handler.conversations.get((self.group.id, user1.id)) is None
|
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||||
|
|
||||||
|
def test_conversation_timeout_dispatcher_handler_stop(self, dp, bot, user1, caplog):
|
||||||
|
handler = ConversationHandler(entry_points=self.entry_points, states=self.states,
|
||||||
|
fallbacks=self.fallbacks, conversation_timeout=0.5)
|
||||||
|
|
||||||
|
def timeout(*args, **kwargs):
|
||||||
|
raise DispatcherHandlerStop()
|
||||||
|
|
||||||
|
self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]})
|
||||||
|
dp.add_handler(handler)
|
||||||
|
|
||||||
|
# Start state machine, then reach timeout
|
||||||
|
message = Message(0, user1, None, self.group, text='/start',
|
||||||
|
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND,
|
||||||
|
offset=0, length=len('/start'))],
|
||||||
|
bot=bot)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY
|
||||||
|
sleep(0.5)
|
||||||
|
dp.job_queue.tick()
|
||||||
|
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||||
|
assert len(caplog.records) == 1
|
||||||
|
rec = caplog.records[-1]
|
||||||
|
assert rec.msg.startswith('DispatcherHandlerStop in TIMEOUT')
|
||||||
|
|
||||||
def test_conversation_handler_timeout_update_and_context(self, cdp, bot, user1):
|
def test_conversation_handler_timeout_update_and_context(self, cdp, bot, user1):
|
||||||
context = None
|
context = None
|
||||||
|
|
||||||
|
@ -953,3 +1012,129 @@ class TestConversationHandler:
|
||||||
dp.process_update(Update(update_id=0, message=message))
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
assert self.current_state[user1.id] == self.STOPPING
|
assert self.current_state[user1.id] == self.STOPPING
|
||||||
assert handler.conversations.get((0, user1.id)) is None
|
assert handler.conversations.get((0, user1.id)) is None
|
||||||
|
|
||||||
|
def test_conversation_dispatcher_handler_stop(self, dp, bot, user1, user2):
|
||||||
|
self.nested_states[self.DRINKING] = [ConversationHandler(
|
||||||
|
entry_points=self.drinking_entry_points,
|
||||||
|
states=self.drinking_states,
|
||||||
|
fallbacks=self.drinking_fallbacks,
|
||||||
|
map_to_parent=self.drinking_map_to_parent)]
|
||||||
|
handler = ConversationHandler(entry_points=self.entry_points,
|
||||||
|
states=self.nested_states,
|
||||||
|
fallbacks=self.fallbacks)
|
||||||
|
|
||||||
|
def test_callback(u, c):
|
||||||
|
self.test_flag = True
|
||||||
|
|
||||||
|
dp.add_handler(handler)
|
||||||
|
dp.add_handler(TypeHandler(Update, test_callback), group=1)
|
||||||
|
self.raise_dp_handler_stop = True
|
||||||
|
|
||||||
|
# User one, starts the state machine.
|
||||||
|
message = Message(0, user1, None, self.group, text='/start', bot=bot,
|
||||||
|
entities=[MessageEntity(type=MessageEntity.BOT_COMMAND,
|
||||||
|
offset=0, length=len('/start'))])
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.THIRSTY
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user is thirsty and wants to brew coffee.
|
||||||
|
message.text = '/brew'
|
||||||
|
message.entities[0].length = len('/brew')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.BREWING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# Lets pour some coffee.
|
||||||
|
message.text = '/pourCoffee'
|
||||||
|
message.entities[0].length = len('/pourCoffee')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.DRINKING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user is holding the cup
|
||||||
|
message.text = '/hold'
|
||||||
|
message.entities[0].length = len('/hold')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.HOLDING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user is sipping coffee
|
||||||
|
message.text = '/sip'
|
||||||
|
message.entities[0].length = len('/sip')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.SIPPING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user is swallowing
|
||||||
|
message.text = '/swallow'
|
||||||
|
message.entities[0].length = len('/swallow')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.SWALLOWING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user is holding the cup again
|
||||||
|
message.text = '/hold'
|
||||||
|
message.entities[0].length = len('/hold')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.HOLDING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user wants to replenish the coffee supply
|
||||||
|
message.text = '/replenish'
|
||||||
|
message.entities[0].length = len('/replenish')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.REPLENISHING
|
||||||
|
assert handler.conversations[(0, user1.id)] == self.BREWING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user wants to drink their coffee again
|
||||||
|
message.text = '/pourCoffee'
|
||||||
|
message.entities[0].length = len('/pourCoffee')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.DRINKING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user is now ready to start coding
|
||||||
|
message.text = '/startCoding'
|
||||||
|
message.entities[0].length = len('/startCoding')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.CODING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user decides it's time to drink again
|
||||||
|
message.text = '/drinkMore'
|
||||||
|
message.entities[0].length = len('/drinkMore')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.DRINKING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user is holding their cup
|
||||||
|
message.text = '/hold'
|
||||||
|
message.entities[0].length = len('/hold')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.HOLDING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user wants to end with the drinking and go back to coding
|
||||||
|
message.text = '/end'
|
||||||
|
message.entities[0].length = len('/end')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.END
|
||||||
|
assert handler.conversations[(0, user1.id)] == self.CODING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user wants to drink once more
|
||||||
|
message.text = '/drinkMore'
|
||||||
|
message.entities[0].length = len('/drinkMore')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.DRINKING
|
||||||
|
assert not self.test_flag
|
||||||
|
|
||||||
|
# The user wants to stop altogether
|
||||||
|
message.text = '/stop'
|
||||||
|
message.entities[0].length = len('/stop')
|
||||||
|
dp.process_update(Update(update_id=0, message=message))
|
||||||
|
assert self.current_state[user1.id] == self.STOPPING
|
||||||
|
assert handler.conversations.get((0, user1.id)) is None
|
||||||
|
assert not self.test_flag
|
||||||
|
|
Loading…
Reference in a new issue