Fix Non-Blocking Entry Point in ConversationHandler (#3068)

This commit is contained in:
Bibo-Joshi 2022-06-07 17:48:26 +02:00 committed by GitHub
parent 42276338b1
commit 22419c0464
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 9 deletions

View file

@ -92,8 +92,9 @@ class PendingState:
def resolve(self) -> object:
"""Returns the new state of the :class:`ConversationHandler` if available. If there was an
exception during the task execution, then return the old state. If the returned state was
:obj:`None`, then end the conversation.
exception during the task execution, then return the old state. If both the new and old
state are :obj:`None`, return `CH.END`. If only the new state is :obj:`None`, return the
old state.
Raises:
:exc:`RuntimeError`: If the current task has not yet finished.
@ -106,13 +107,15 @@ class PendingState:
_logger.exception(
"Task function raised exception. Falling back to old state %s",
self.old_state,
exc_info=exc,
)
return self.old_state
res = self.task.result()
if res is None and self.old_state is None:
res = ConversationHandler.END
elif res is None:
# returning None from a callback means that we want to stay in the old state
return self.old_state
return res
@ -708,6 +711,11 @@ class ConversationHandler(BaseHandler[Update, CCT]):
# check if future is finished or not
if state.done():
res = state.resolve()
# Special case if an error was raised in a non-blocking entry-point
if state.old_state is None and state.task.exception():
self._conversations.pop(key, None)
state = None
else:
self._update_state(res, key)
state = self._conversations.get(key)

View file

@ -1050,9 +1050,10 @@ class TestConversationHandler:
await app.stop()
async def test_non_blocking_exception(self, app, bot, user1, caplog):
@pytest.mark.parametrize(argnames="test_type", argvalues=["none", "exception"])
async def test_non_blocking_exception_or_none(self, app, bot, user1, caplog, test_type):
"""Here we make sure that when a non-blocking handler raises an
exception, the state isn't changed.
exception or returns None, the state isn't changed.
"""
error = Exception("task exception")
@ -1060,6 +1061,8 @@ class TestConversationHandler:
return 1
async def raise_error(*a, **kw):
if test_type == "none":
return None
raise error
handler = ConversationHandler(
@ -1092,12 +1095,57 @@ class TestConversationHandler:
with caplog.at_level(logging.ERROR):
# This also makes sure that we're still in the same state
assert handler.check_update(Update(0, message=message))
if test_type == "exception":
assert len(caplog.records) == 1
assert (
caplog.records[0].message
== "Task function raised exception. Falling back to old state 1"
)
assert caplog.records[0].exc_info[1] is error
assert caplog.records[0].exc_info[1] is None
else:
assert len(caplog.records) == 0
async def test_non_blocking_entry_point_exception(self, app, bot, user1, caplog):
"""Here we make sure that when a non-blocking entry point raises an
exception, the state isn't changed.
"""
error = Exception("task exception")
async def raise_error(*a, **kw):
raise error
handler = ConversationHandler(
entry_points=[CommandHandler("start", raise_error, block=False)],
states={},
fallbacks=self.fallbacks,
)
app.add_handler(handler)
message = Message(
0,
None,
self.group,
from_user=user1,
text="/start",
entities=[
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len("/start"))
],
bot=bot,
)
# start the conversation
async with app:
await app.process_update(Update(update_id=0, message=message))
await asyncio.sleep(0.1)
caplog.clear()
with caplog.at_level(logging.ERROR):
# This also makes sure that we're still in the same state
assert handler.check_update(Update(0, message=message))
assert len(caplog.records) == 1
assert (
caplog.records[0].message
== "Task function raised exception. Falling back to old state None"
)
assert caplog.records[0].exc_info[1] is None
async def test_conversation_timeout(self, app, bot, user1):
handler = ConversationHandler(