mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-22 14:35:00 +01:00
Fix Non-Blocking Entry Point in ConversationHandler
(#3068)
This commit is contained in:
parent
42276338b1
commit
22419c0464
2 changed files with 65 additions and 9 deletions
|
@ -92,8 +92,9 @@ class PendingState:
|
|||
|
||||
def resolve(self) -> object:
|
||||
"""Returns the new state of the :class:`ConversationHandler` if available. If there was an
|
||||
exception during the task execution, then return the old state. If the returned state was
|
||||
:obj:`None`, then end the conversation.
|
||||
exception during the task execution, then return the old state. If both the new and old
|
||||
state are :obj:`None`, return `CH.END`. If only the new state is :obj:`None`, return the
|
||||
old state.
|
||||
|
||||
Raises:
|
||||
:exc:`RuntimeError`: If the current task has not yet finished.
|
||||
|
@ -106,13 +107,15 @@ class PendingState:
|
|||
_logger.exception(
|
||||
"Task function raised exception. Falling back to old state %s",
|
||||
self.old_state,
|
||||
exc_info=exc,
|
||||
)
|
||||
return self.old_state
|
||||
|
||||
res = self.task.result()
|
||||
if res is None and self.old_state is None:
|
||||
res = ConversationHandler.END
|
||||
elif res is None:
|
||||
# returning None from a callback means that we want to stay in the old state
|
||||
return self.old_state
|
||||
|
||||
return res
|
||||
|
||||
|
@ -708,6 +711,11 @@ class ConversationHandler(BaseHandler[Update, CCT]):
|
|||
# check if future is finished or not
|
||||
if state.done():
|
||||
res = state.resolve()
|
||||
# Special case if an error was raised in a non-blocking entry-point
|
||||
if state.old_state is None and state.task.exception():
|
||||
self._conversations.pop(key, None)
|
||||
state = None
|
||||
else:
|
||||
self._update_state(res, key)
|
||||
state = self._conversations.get(key)
|
||||
|
||||
|
|
|
@ -1050,9 +1050,10 @@ class TestConversationHandler:
|
|||
|
||||
await app.stop()
|
||||
|
||||
async def test_non_blocking_exception(self, app, bot, user1, caplog):
|
||||
@pytest.mark.parametrize(argnames="test_type", argvalues=["none", "exception"])
|
||||
async def test_non_blocking_exception_or_none(self, app, bot, user1, caplog, test_type):
|
||||
"""Here we make sure that when a non-blocking handler raises an
|
||||
exception, the state isn't changed.
|
||||
exception or returns None, the state isn't changed.
|
||||
"""
|
||||
error = Exception("task exception")
|
||||
|
||||
|
@ -1060,6 +1061,8 @@ class TestConversationHandler:
|
|||
return 1
|
||||
|
||||
async def raise_error(*a, **kw):
|
||||
if test_type == "none":
|
||||
return None
|
||||
raise error
|
||||
|
||||
handler = ConversationHandler(
|
||||
|
@ -1092,12 +1095,57 @@ class TestConversationHandler:
|
|||
with caplog.at_level(logging.ERROR):
|
||||
# This also makes sure that we're still in the same state
|
||||
assert handler.check_update(Update(0, message=message))
|
||||
if test_type == "exception":
|
||||
assert len(caplog.records) == 1
|
||||
assert (
|
||||
caplog.records[0].message
|
||||
== "Task function raised exception. Falling back to old state 1"
|
||||
)
|
||||
assert caplog.records[0].exc_info[1] is error
|
||||
assert caplog.records[0].exc_info[1] is None
|
||||
else:
|
||||
assert len(caplog.records) == 0
|
||||
|
||||
async def test_non_blocking_entry_point_exception(self, app, bot, user1, caplog):
|
||||
"""Here we make sure that when a non-blocking entry point raises an
|
||||
exception, the state isn't changed.
|
||||
"""
|
||||
error = Exception("task exception")
|
||||
|
||||
async def raise_error(*a, **kw):
|
||||
raise error
|
||||
|
||||
handler = ConversationHandler(
|
||||
entry_points=[CommandHandler("start", raise_error, block=False)],
|
||||
states={},
|
||||
fallbacks=self.fallbacks,
|
||||
)
|
||||
app.add_handler(handler)
|
||||
|
||||
message = Message(
|
||||
0,
|
||||
None,
|
||||
self.group,
|
||||
from_user=user1,
|
||||
text="/start",
|
||||
entities=[
|
||||
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len("/start"))
|
||||
],
|
||||
bot=bot,
|
||||
)
|
||||
# start the conversation
|
||||
async with app:
|
||||
await app.process_update(Update(update_id=0, message=message))
|
||||
await asyncio.sleep(0.1)
|
||||
caplog.clear()
|
||||
with caplog.at_level(logging.ERROR):
|
||||
# This also makes sure that we're still in the same state
|
||||
assert handler.check_update(Update(0, message=message))
|
||||
assert len(caplog.records) == 1
|
||||
assert (
|
||||
caplog.records[0].message
|
||||
== "Task function raised exception. Falling back to old state None"
|
||||
)
|
||||
assert caplog.records[0].exc_info[1] is None
|
||||
|
||||
async def test_conversation_timeout(self, app, bot, user1):
|
||||
handler = ConversationHandler(
|
||||
|
|
Loading…
Reference in a new issue