Fix Non-Blocking Entry Point in ConversationHandler (#3068)

This commit is contained in:
Bibo-Joshi 2022-06-07 17:48:26 +02:00 committed by GitHub
parent 42276338b1
commit 22419c0464
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 9 deletions

View file

@ -92,8 +92,9 @@ class PendingState:
def resolve(self) -> object: def resolve(self) -> object:
"""Returns the new state of the :class:`ConversationHandler` if available. If there was an """Returns the new state of the :class:`ConversationHandler` if available. If there was an
exception during the task execution, then return the old state. If the returned state was exception during the task execution, then return the old state. If both the new and old
:obj:`None`, then end the conversation. state are :obj:`None`, return `CH.END`. If only the new state is :obj:`None`, return the
old state.
Raises: Raises:
:exc:`RuntimeError`: If the current task has not yet finished. :exc:`RuntimeError`: If the current task has not yet finished.
@ -106,13 +107,15 @@ class PendingState:
_logger.exception( _logger.exception(
"Task function raised exception. Falling back to old state %s", "Task function raised exception. Falling back to old state %s",
self.old_state, self.old_state,
exc_info=exc,
) )
return self.old_state return self.old_state
res = self.task.result() res = self.task.result()
if res is None and self.old_state is None: if res is None and self.old_state is None:
res = ConversationHandler.END res = ConversationHandler.END
elif res is None:
# returning None from a callback means that we want to stay in the old state
return self.old_state
return res return res
@ -708,6 +711,11 @@ class ConversationHandler(BaseHandler[Update, CCT]):
# check if future is finished or not # check if future is finished or not
if state.done(): if state.done():
res = state.resolve() res = state.resolve()
# Special case if an error was raised in a non-blocking entry-point
if state.old_state is None and state.task.exception():
self._conversations.pop(key, None)
state = None
else:
self._update_state(res, key) self._update_state(res, key)
state = self._conversations.get(key) state = self._conversations.get(key)

View file

@ -1050,9 +1050,10 @@ class TestConversationHandler:
await app.stop() await app.stop()
async def test_non_blocking_exception(self, app, bot, user1, caplog): @pytest.mark.parametrize(argnames="test_type", argvalues=["none", "exception"])
async def test_non_blocking_exception_or_none(self, app, bot, user1, caplog, test_type):
"""Here we make sure that when a non-blocking handler raises an """Here we make sure that when a non-blocking handler raises an
exception, the state isn't changed. exception or returns None, the state isn't changed.
""" """
error = Exception("task exception") error = Exception("task exception")
@ -1060,6 +1061,8 @@ class TestConversationHandler:
return 1 return 1
async def raise_error(*a, **kw): async def raise_error(*a, **kw):
if test_type == "none":
return None
raise error raise error
handler = ConversationHandler( handler = ConversationHandler(
@ -1092,12 +1095,57 @@ class TestConversationHandler:
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
# This also makes sure that we're still in the same state # This also makes sure that we're still in the same state
assert handler.check_update(Update(0, message=message)) assert handler.check_update(Update(0, message=message))
if test_type == "exception":
assert len(caplog.records) == 1 assert len(caplog.records) == 1
assert ( assert (
caplog.records[0].message caplog.records[0].message
== "Task function raised exception. Falling back to old state 1" == "Task function raised exception. Falling back to old state 1"
) )
assert caplog.records[0].exc_info[1] is error assert caplog.records[0].exc_info[1] is None
else:
assert len(caplog.records) == 0
async def test_non_blocking_entry_point_exception(self, app, bot, user1, caplog):
"""Here we make sure that when a non-blocking entry point raises an
exception, the state isn't changed.
"""
error = Exception("task exception")
async def raise_error(*a, **kw):
raise error
handler = ConversationHandler(
entry_points=[CommandHandler("start", raise_error, block=False)],
states={},
fallbacks=self.fallbacks,
)
app.add_handler(handler)
message = Message(
0,
None,
self.group,
from_user=user1,
text="/start",
entities=[
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len("/start"))
],
bot=bot,
)
# start the conversation
async with app:
await app.process_update(Update(update_id=0, message=message))
await asyncio.sleep(0.1)
caplog.clear()
with caplog.at_level(logging.ERROR):
# This also makes sure that we're still in the same state
assert handler.check_update(Update(0, message=message))
assert len(caplog.records) == 1
assert (
caplog.records[0].message
== "Task function raised exception. Falling back to old state None"
)
assert caplog.records[0].exc_info[1] is None
async def test_conversation_timeout(self, app, bot, user1): async def test_conversation_timeout(self, app, bot, user1):
handler = ConversationHandler( handler = ConversationHandler(