mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-22 14:35:00 +01:00
Fix Non-Blocking Entry Point in ConversationHandler
(#3068)
This commit is contained in:
parent
42276338b1
commit
22419c0464
2 changed files with 65 additions and 9 deletions
|
@ -92,8 +92,9 @@ class PendingState:
|
||||||
|
|
||||||
def resolve(self) -> object:
|
def resolve(self) -> object:
|
||||||
"""Returns the new state of the :class:`ConversationHandler` if available. If there was an
|
"""Returns the new state of the :class:`ConversationHandler` if available. If there was an
|
||||||
exception during the task execution, then return the old state. If the returned state was
|
exception during the task execution, then return the old state. If both the new and old
|
||||||
:obj:`None`, then end the conversation.
|
state are :obj:`None`, return `CH.END`. If only the new state is :obj:`None`, return the
|
||||||
|
old state.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
:exc:`RuntimeError`: If the current task has not yet finished.
|
:exc:`RuntimeError`: If the current task has not yet finished.
|
||||||
|
@ -106,13 +107,15 @@ class PendingState:
|
||||||
_logger.exception(
|
_logger.exception(
|
||||||
"Task function raised exception. Falling back to old state %s",
|
"Task function raised exception. Falling back to old state %s",
|
||||||
self.old_state,
|
self.old_state,
|
||||||
exc_info=exc,
|
|
||||||
)
|
)
|
||||||
return self.old_state
|
return self.old_state
|
||||||
|
|
||||||
res = self.task.result()
|
res = self.task.result()
|
||||||
if res is None and self.old_state is None:
|
if res is None and self.old_state is None:
|
||||||
res = ConversationHandler.END
|
res = ConversationHandler.END
|
||||||
|
elif res is None:
|
||||||
|
# returning None from a callback means that we want to stay in the old state
|
||||||
|
return self.old_state
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -708,8 +711,13 @@ class ConversationHandler(BaseHandler[Update, CCT]):
|
||||||
# check if future is finished or not
|
# check if future is finished or not
|
||||||
if state.done():
|
if state.done():
|
||||||
res = state.resolve()
|
res = state.resolve()
|
||||||
self._update_state(res, key)
|
# Special case if an error was raised in a non-blocking entry-point
|
||||||
state = self._conversations.get(key)
|
if state.old_state is None and state.task.exception():
|
||||||
|
self._conversations.pop(key, None)
|
||||||
|
state = None
|
||||||
|
else:
|
||||||
|
self._update_state(res, key)
|
||||||
|
state = self._conversations.get(key)
|
||||||
|
|
||||||
# if not then handle WAITING state instead
|
# if not then handle WAITING state instead
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1050,9 +1050,10 @@ class TestConversationHandler:
|
||||||
|
|
||||||
await app.stop()
|
await app.stop()
|
||||||
|
|
||||||
async def test_non_blocking_exception(self, app, bot, user1, caplog):
|
@pytest.mark.parametrize(argnames="test_type", argvalues=["none", "exception"])
|
||||||
|
async def test_non_blocking_exception_or_none(self, app, bot, user1, caplog, test_type):
|
||||||
"""Here we make sure that when a non-blocking handler raises an
|
"""Here we make sure that when a non-blocking handler raises an
|
||||||
exception, the state isn't changed.
|
exception or returns None, the state isn't changed.
|
||||||
"""
|
"""
|
||||||
error = Exception("task exception")
|
error = Exception("task exception")
|
||||||
|
|
||||||
|
@ -1060,6 +1061,8 @@ class TestConversationHandler:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
async def raise_error(*a, **kw):
|
async def raise_error(*a, **kw):
|
||||||
|
if test_type == "none":
|
||||||
|
return None
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
handler = ConversationHandler(
|
handler = ConversationHandler(
|
||||||
|
@ -1086,6 +1089,51 @@ class TestConversationHandler:
|
||||||
await app.process_update(Update(update_id=0, message=message))
|
await app.process_update(Update(update_id=0, message=message))
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
message.text = "error"
|
message.text = "error"
|
||||||
|
await app.process_update(Update(update_id=0, message=message))
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
caplog.clear()
|
||||||
|
with caplog.at_level(logging.ERROR):
|
||||||
|
# This also makes sure that we're still in the same state
|
||||||
|
assert handler.check_update(Update(0, message=message))
|
||||||
|
if test_type == "exception":
|
||||||
|
assert len(caplog.records) == 1
|
||||||
|
assert (
|
||||||
|
caplog.records[0].message
|
||||||
|
== "Task function raised exception. Falling back to old state 1"
|
||||||
|
)
|
||||||
|
assert caplog.records[0].exc_info[1] is None
|
||||||
|
else:
|
||||||
|
assert len(caplog.records) == 0
|
||||||
|
|
||||||
|
async def test_non_blocking_entry_point_exception(self, app, bot, user1, caplog):
|
||||||
|
"""Here we make sure that when a non-blocking entry point raises an
|
||||||
|
exception, the state isn't changed.
|
||||||
|
"""
|
||||||
|
error = Exception("task exception")
|
||||||
|
|
||||||
|
async def raise_error(*a, **kw):
|
||||||
|
raise error
|
||||||
|
|
||||||
|
handler = ConversationHandler(
|
||||||
|
entry_points=[CommandHandler("start", raise_error, block=False)],
|
||||||
|
states={},
|
||||||
|
fallbacks=self.fallbacks,
|
||||||
|
)
|
||||||
|
app.add_handler(handler)
|
||||||
|
|
||||||
|
message = Message(
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
self.group,
|
||||||
|
from_user=user1,
|
||||||
|
text="/start",
|
||||||
|
entities=[
|
||||||
|
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len("/start"))
|
||||||
|
],
|
||||||
|
bot=bot,
|
||||||
|
)
|
||||||
|
# start the conversation
|
||||||
|
async with app:
|
||||||
await app.process_update(Update(update_id=0, message=message))
|
await app.process_update(Update(update_id=0, message=message))
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
caplog.clear()
|
caplog.clear()
|
||||||
|
@ -1095,9 +1143,9 @@ class TestConversationHandler:
|
||||||
assert len(caplog.records) == 1
|
assert len(caplog.records) == 1
|
||||||
assert (
|
assert (
|
||||||
caplog.records[0].message
|
caplog.records[0].message
|
||||||
== "Task function raised exception. Falling back to old state 1"
|
== "Task function raised exception. Falling back to old state None"
|
||||||
)
|
)
|
||||||
assert caplog.records[0].exc_info[1] is error
|
assert caplog.records[0].exc_info[1] is None
|
||||||
|
|
||||||
async def test_conversation_timeout(self, app, bot, user1):
|
async def test_conversation_timeout(self, app, bot, user1):
|
||||||
handler = ConversationHandler(
|
handler = ConversationHandler(
|
||||||
|
|
Loading…
Reference in a new issue