mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-03-16 20:29:55 +01:00
Improve Timeouts in ConversationHandler (#2417)
* Handle promise states in conversation timeout Signed-off-by: starry69 <starry369126@outlook.com> * warn if nested conversation & timeout Signed-off-by: starry69 <starry369126@outlook.com> * Add notes and test for conversation_timeout Signed-off-by: starry69 <starry369126@outlook.com> * Try to fix pre-commit Signed-off-by: starry69 <starry369126@outlook.com> * Test promise exception Signed-off-by: starry69 <starry369126@outlook.com> * Welp Signed-off-by: starry69 <starry369126@outlook.com> * improve docs Signed-off-by: starry69 <starry369126@outlook.com> * typo Signed-off-by: starry69 <starry369126@outlook.com> * try to fix codecov Signed-off-by: starry69 <starry369126@outlook.com> * refactor timeout logic with promise.add_done_cb Signed-off-by: starry69 <starry369126@outlook.com> * small fix Signed-off-by: starry69 <starry369126@outlook.com> * Address review Signed-off-by: starry69 <starry369126@outlook.com> * Fix some type hinting * Few fixes Signed-off-by: starry69 <starry369126@outlook.com> * fix tests Signed-off-by: starry69 <starry369126@outlook.com> * minor nitpick Signed-off-by: starry69 <starry369126@outlook.com> Co-authored-by: Hinrich Mahler <hinrich.mahler@freenet.de>
This commit is contained in:
parent
b6a6d7f872
commit
7e554584b1
4 changed files with 387 additions and 39 deletions
|
@ -21,8 +21,10 @@
|
|||
|
||||
import logging
|
||||
import warnings
|
||||
import functools
|
||||
import datetime
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Tuple, cast, ClassVar
|
||||
from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Union, Tuple, cast, ClassVar
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import (
|
||||
|
@ -143,6 +145,13 @@ class ConversationHandler(Handler[Update]):
|
|||
received update and the corresponding ``context`` will be handled by ALL the handler's
|
||||
who's :attr:`check_update` method returns :obj:`True` that are in the state
|
||||
:attr:`ConversationHandler.TIMEOUT`.
|
||||
|
||||
Note:
|
||||
Using `conversation_timeout` with nested conversations is currently not
|
||||
supported. You can still try to use it, but it will likely behave differently
|
||||
from what you expect.
|
||||
|
||||
|
||||
name (:obj:`str`, optional): The name for this conversationhandler. Required for
|
||||
persistence.
|
||||
persistent (:obj:`bool`, optional): If the conversations dict for this handler should be
|
||||
|
@ -215,7 +224,7 @@ class ConversationHandler(Handler[Update]):
|
|||
per_chat: bool = True,
|
||||
per_user: bool = True,
|
||||
per_message: bool = False,
|
||||
conversation_timeout: int = None,
|
||||
conversation_timeout: Union[float, datetime.timedelta] = None,
|
||||
name: str = None,
|
||||
persistent: bool = False,
|
||||
map_to_parent: Dict[object, object] = None,
|
||||
|
@ -291,6 +300,16 @@ class ConversationHandler(Handler[Update]):
|
|||
)
|
||||
break
|
||||
|
||||
if self.conversation_timeout:
|
||||
for handler in all_handlers:
|
||||
if isinstance(handler, self.__class__):
|
||||
warnings.warn(
|
||||
"Using `conversation_timeout` with nested conversations is currently not "
|
||||
"supported. You can still try to use it, but it will likely behave "
|
||||
"differently from what you expect."
|
||||
)
|
||||
break
|
||||
|
||||
if self.run_async:
|
||||
for handler in all_handlers:
|
||||
handler.run_async = True
|
||||
|
@ -352,7 +371,9 @@ class ConversationHandler(Handler[Update]):
|
|||
raise ValueError('You can not assign a new value to per_message after initialization.')
|
||||
|
||||
@property
|
||||
def conversation_timeout(self) -> Optional[int]:
|
||||
def conversation_timeout(
|
||||
self,
|
||||
) -> Optional[Union[float, datetime.timedelta]]:
|
||||
return self._conversation_timeout
|
||||
|
||||
@conversation_timeout.setter
|
||||
|
@ -423,6 +444,45 @@ class ConversationHandler(Handler[Update]):
|
|||
|
||||
return tuple(key)
|
||||
|
||||
def _resolve_promise(self, state: Tuple) -> object:
|
||||
old_state, new_state = state
|
||||
try:
|
||||
res = new_state.result(0)
|
||||
res = res if res is not None else old_state
|
||||
except Exception as exc:
|
||||
self.logger.exception("Promise function raised exception")
|
||||
self.logger.exception("%s", exc)
|
||||
res = old_state
|
||||
finally:
|
||||
if res is None and old_state is None:
|
||||
res = self.END
|
||||
return res
|
||||
|
||||
def _schedule_job(
|
||||
self,
|
||||
new_state: object,
|
||||
dispatcher: 'Dispatcher',
|
||||
update: Update,
|
||||
context: Optional[CallbackContext],
|
||||
conversation_key: Tuple[int, ...],
|
||||
) -> None:
|
||||
if new_state != self.END:
|
||||
try:
|
||||
# both job_queue & conversation_timeout are checked before calling _schedule_job
|
||||
j_queue = dispatcher.job_queue
|
||||
self.timeout_jobs[conversation_key] = j_queue.run_once( # type: ignore[union-attr]
|
||||
self._trigger_timeout,
|
||||
self.conversation_timeout, # type: ignore[arg-type]
|
||||
context=_ConversationTimeoutContext(
|
||||
conversation_key, update, dispatcher, context
|
||||
),
|
||||
)
|
||||
except Exception as exc:
|
||||
self.logger.exception(
|
||||
"Failed to schedule timeout job due to the following exception:"
|
||||
)
|
||||
self.logger.exception("%s", exc)
|
||||
|
||||
def check_update(self, update: object) -> CheckUpdateType: # pylint: disable=R0911
|
||||
"""
|
||||
Determines whether an update should be handled by this conversationhandler, and if so in
|
||||
|
@ -455,21 +515,14 @@ class ConversationHandler(Handler[Update]):
|
|||
if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], Promise):
|
||||
self.logger.debug('waiting for promise...')
|
||||
|
||||
old_state, new_state = state
|
||||
if new_state.done.wait(0):
|
||||
try:
|
||||
res = new_state.result(0)
|
||||
res = res if res is not None else old_state
|
||||
except Exception as exc:
|
||||
self.logger.exception("Promise function raised exception")
|
||||
self.logger.exception("%s", exc)
|
||||
res = old_state
|
||||
finally:
|
||||
if res is None and old_state is None:
|
||||
res = self.END
|
||||
self.update_state(res, key)
|
||||
with self._conversations_lock:
|
||||
state = self.conversations.get(key)
|
||||
# check if promise is finished or not
|
||||
if state[1].done.wait(0):
|
||||
res = self._resolve_promise(state)
|
||||
self.update_state(res, key)
|
||||
with self._conversations_lock:
|
||||
state = self.conversations.get(key)
|
||||
|
||||
# if not then handle WAITING state instead
|
||||
else:
|
||||
hdlrs = self.states.get(self.WAITING, [])
|
||||
for hdlr in hdlrs:
|
||||
|
@ -551,15 +604,27 @@ class ConversationHandler(Handler[Update]):
|
|||
new_state = exception.state
|
||||
raise_dp_handler_stop = True
|
||||
with self._timeout_jobs_lock:
|
||||
if self.conversation_timeout and new_state != self.END and dispatcher.job_queue:
|
||||
# Add the new timeout job
|
||||
self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once(
|
||||
self._trigger_timeout, # type: ignore[arg-type]
|
||||
self.conversation_timeout,
|
||||
context=_ConversationTimeoutContext(
|
||||
conversation_key, update, dispatcher, context
|
||||
),
|
||||
)
|
||||
if self.conversation_timeout:
|
||||
if dispatcher.job_queue is not None:
|
||||
# Add the new timeout job
|
||||
if isinstance(new_state, Promise):
|
||||
new_state.add_done_callback(
|
||||
functools.partial(
|
||||
self._schedule_job,
|
||||
dispatcher=dispatcher,
|
||||
update=update,
|
||||
context=context,
|
||||
conversation_key=conversation_key,
|
||||
)
|
||||
)
|
||||
elif new_state != self.END:
|
||||
self._schedule_job(
|
||||
new_state, dispatcher, update, context, conversation_key
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
"Ignoring `conversation_timeout` because the Dispatcher has no JobQueue."
|
||||
)
|
||||
|
||||
if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent:
|
||||
self.update_state(self.END, conversation_key)
|
||||
|
@ -602,35 +667,35 @@ class ConversationHandler(Handler[Update]):
|
|||
if self.persistent and self.persistence and self.name:
|
||||
self.persistence.update_conversation(self.name, key, new_state)
|
||||
|
||||
def _trigger_timeout(self, context: _ConversationTimeoutContext, job: 'Job' = None) -> None:
|
||||
def _trigger_timeout(self, context: CallbackContext, job: 'Job' = None) -> None:
|
||||
self.logger.debug('conversation timeout was triggered!')
|
||||
|
||||
# Backward compatibility with bots that do not use CallbackContext
|
||||
callback_context = None
|
||||
if isinstance(context, CallbackContext):
|
||||
job = context.job
|
||||
ctxt = cast(_ConversationTimeoutContext, job.context) # type: ignore[union-attr]
|
||||
else:
|
||||
ctxt = cast(_ConversationTimeoutContext, job.context)
|
||||
|
||||
context = job.context # type:ignore[union-attr,assignment]
|
||||
callback_context = context.callback_context
|
||||
callback_context = ctxt.callback_context
|
||||
|
||||
with self._timeout_jobs_lock:
|
||||
found_job = self.timeout_jobs[context.conversation_key]
|
||||
found_job = self.timeout_jobs[ctxt.conversation_key]
|
||||
if found_job is not job:
|
||||
# The timeout has been canceled in handle_update
|
||||
# The timeout has been cancelled in handle_update
|
||||
return
|
||||
del self.timeout_jobs[context.conversation_key]
|
||||
del self.timeout_jobs[ctxt.conversation_key]
|
||||
|
||||
handlers = self.states.get(self.TIMEOUT, [])
|
||||
for handler in handlers:
|
||||
check = handler.check_update(context.update)
|
||||
check = handler.check_update(ctxt.update)
|
||||
if check is not None and check is not False:
|
||||
try:
|
||||
handler.handle_update(
|
||||
context.update, context.dispatcher, check, callback_context
|
||||
)
|
||||
handler.handle_update(ctxt.update, ctxt.dispatcher, check, callback_context)
|
||||
except DispatcherHandlerStop:
|
||||
self.logger.warning(
|
||||
'DispatcherHandlerStop in TIMEOUT state of '
|
||||
'ConversationHandler has no effect. Ignoring.'
|
||||
)
|
||||
self.update_state(self.END, context.conversation_key)
|
||||
|
||||
self.update_state(self.END, ctxt.conversation_key)
|
||||
|
|
|
@ -69,6 +69,7 @@ class Promise:
|
|||
self.update = update
|
||||
self.error_handling = error_handling
|
||||
self.done = Event()
|
||||
self._done_callback: Optional[Callable] = None
|
||||
self._result: Optional[RT] = None
|
||||
self._exception: Optional[Exception] = None
|
||||
|
||||
|
@ -83,6 +84,15 @@ class Promise:
|
|||
|
||||
finally:
|
||||
self.done.set()
|
||||
if self._done_callback:
|
||||
try:
|
||||
self._done_callback(self.result())
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"`done_callback` of a Promise raised the following exception."
|
||||
" The exception won't be handled by error handlers."
|
||||
)
|
||||
logger.warning("Full traceback:", exc_info=exc)
|
||||
|
||||
def __call__(self) -> None:
|
||||
self.run()
|
||||
|
@ -106,6 +116,20 @@ class Promise:
|
|||
raise self._exception # pylint: disable=raising-bad-type
|
||||
return self._result
|
||||
|
||||
def add_done_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
Callback to be run when :class:`telegram.ext.utils.promise.Promise` becomes done.
|
||||
|
||||
Args:
|
||||
callback (:obj:`callable`): The callable that will be called when promise is done.
|
||||
callback will be called by passing ``Promise.result()`` as only positional argument.
|
||||
|
||||
"""
|
||||
if self.done.wait(0):
|
||||
callback(self.result())
|
||||
else:
|
||||
self._done_callback = callback
|
||||
|
||||
@property
|
||||
def exception(self) -> Optional[Exception]:
|
||||
"""The exception raised by :attr:`pooled_function` or ``None`` if no exception has been
|
||||
|
|
|
@ -784,6 +784,125 @@ class TestConversationHandler:
|
|||
assert not handler.check_update(Update(0, pre_checkout_query=pre_checkout_query))
|
||||
assert not handler.check_update(Update(0, shipping_query=shipping_query))
|
||||
|
||||
def test_no_jobqueue_warning(self, dp, bot, user1, caplog):
|
||||
handler = ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
states=self.states,
|
||||
fallbacks=self.fallbacks,
|
||||
conversation_timeout=0.5,
|
||||
)
|
||||
# save dp.job_queue in temp variable jqueue
|
||||
# and then set dp.job_queue to None.
|
||||
jqueue = dp.job_queue
|
||||
dp.job_queue = None
|
||||
dp.add_handler(handler)
|
||||
|
||||
message = Message(
|
||||
0,
|
||||
None,
|
||||
self.group,
|
||||
from_user=user1,
|
||||
text='/start',
|
||||
entities=[
|
||||
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
|
||||
],
|
||||
bot=bot,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
sleep(0.5)
|
||||
assert len(caplog.records) == 1
|
||||
assert (
|
||||
caplog.records[0].message
|
||||
== "Ignoring `conversation_timeout` because the Dispatcher has no JobQueue."
|
||||
)
|
||||
# now set dp.job_queue back to it's original value
|
||||
dp.job_queue = jqueue
|
||||
|
||||
def test_schedule_job_exception(self, dp, bot, user1, monkeypatch, caplog):
|
||||
def mocked_run_once(*a, **kw):
|
||||
raise Exception("job error")
|
||||
|
||||
monkeypatch.setattr(dp.job_queue, "run_once", mocked_run_once)
|
||||
handler = ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
states=self.states,
|
||||
fallbacks=self.fallbacks,
|
||||
conversation_timeout=100,
|
||||
)
|
||||
dp.add_handler(handler)
|
||||
|
||||
message = Message(
|
||||
0,
|
||||
None,
|
||||
self.group,
|
||||
from_user=user1,
|
||||
text='/start',
|
||||
entities=[
|
||||
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
|
||||
],
|
||||
bot=bot,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
sleep(0.5)
|
||||
assert len(caplog.records) == 2
|
||||
assert (
|
||||
caplog.records[0].message
|
||||
== "Failed to schedule timeout job due to the following exception:"
|
||||
)
|
||||
assert caplog.records[1].message == "job error"
|
||||
|
||||
def test_promise_exception(self, dp, bot, user1, caplog):
|
||||
"""
|
||||
Here we make sure that when a run_async handle raises an
|
||||
exception, the state isn't changed.
|
||||
"""
|
||||
|
||||
def conv_entry(*a, **kw):
|
||||
return 1
|
||||
|
||||
def raise_error(*a, **kw):
|
||||
raise Exception("promise exception")
|
||||
|
||||
handler = ConversationHandler(
|
||||
entry_points=[CommandHandler("start", conv_entry)],
|
||||
states={1: [MessageHandler(Filters.all, raise_error)]},
|
||||
fallbacks=self.fallbacks,
|
||||
run_async=True,
|
||||
)
|
||||
dp.add_handler(handler)
|
||||
|
||||
message = Message(
|
||||
0,
|
||||
None,
|
||||
self.group,
|
||||
from_user=user1,
|
||||
text='/start',
|
||||
entities=[
|
||||
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
|
||||
],
|
||||
bot=bot,
|
||||
)
|
||||
# start the conversation
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
sleep(0.1)
|
||||
message.text = "error"
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
sleep(0.1)
|
||||
message.text = "resolve promise pls"
|
||||
caplog.clear()
|
||||
with caplog.at_level(logging.ERROR):
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
sleep(0.5)
|
||||
assert len(caplog.records) == 3
|
||||
assert caplog.records[0].message == "Promise function raised exception"
|
||||
assert caplog.records[1].message == "promise exception"
|
||||
# assert res is old state
|
||||
assert handler.conversations.get((self.group.id, user1.id))[0] == 1
|
||||
|
||||
def test_conversation_timeout(self, dp, bot, user1):
|
||||
handler = ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
|
@ -820,6 +939,49 @@ class TestConversationHandler:
|
|||
sleep(0.7)
|
||||
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||
|
||||
def test_timeout_not_triggered_on_conv_end_async(self, bot, dp, user1):
|
||||
def timeout(*a, **kw):
|
||||
self.test_flag = True
|
||||
|
||||
self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]})
|
||||
handler = ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
states=self.states,
|
||||
fallbacks=self.fallbacks,
|
||||
conversation_timeout=0.5,
|
||||
run_async=True,
|
||||
)
|
||||
dp.add_handler(handler)
|
||||
|
||||
message = Message(
|
||||
0,
|
||||
None,
|
||||
self.group,
|
||||
from_user=user1,
|
||||
text='/start',
|
||||
entities=[
|
||||
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
|
||||
],
|
||||
bot=bot,
|
||||
)
|
||||
# start the conversation
|
||||
dp.process_update(Update(update_id=0, message=message))
|
||||
sleep(0.1)
|
||||
message.text = '/brew'
|
||||
message.entities[0].length = len('/brew')
|
||||
dp.process_update(Update(update_id=1, message=message))
|
||||
sleep(0.1)
|
||||
message.text = '/pourCoffee'
|
||||
message.entities[0].length = len('/pourCoffee')
|
||||
dp.process_update(Update(update_id=2, message=message))
|
||||
sleep(0.1)
|
||||
message.text = '/end'
|
||||
message.entities[0].length = len('/end')
|
||||
dp.process_update(Update(update_id=3, message=message))
|
||||
sleep(1)
|
||||
# assert timeout handler didn't got called
|
||||
assert self.test_flag is False
|
||||
|
||||
def test_conversation_timeout_dispatcher_handler_stop(self, dp, bot, user1, caplog):
|
||||
handler = ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
|
@ -1157,6 +1319,39 @@ class TestConversationHandler:
|
|||
assert handler.conversations.get((self.group.id, user1.id)) is None
|
||||
assert self.is_timeout
|
||||
|
||||
def test_conversation_timeout_warning_only_shown_once(self, recwarn):
|
||||
ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
states={
|
||||
self.THIRSTY: [
|
||||
ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
states={
|
||||
self.BREWING: [CommandHandler('pourCoffee', self.drink)],
|
||||
},
|
||||
fallbacks=self.fallbacks,
|
||||
)
|
||||
],
|
||||
self.DRINKING: [
|
||||
ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
states={
|
||||
self.CODING: [CommandHandler('startCoding', self.code)],
|
||||
},
|
||||
fallbacks=self.fallbacks,
|
||||
)
|
||||
],
|
||||
},
|
||||
fallbacks=self.fallbacks,
|
||||
conversation_timeout=100,
|
||||
)
|
||||
assert len(recwarn) == 1
|
||||
assert str(recwarn[0].message) == (
|
||||
"Using `conversation_timeout` with nested conversations is currently not "
|
||||
"supported. You can still try to use it, but it will likely behave "
|
||||
"differently from what you expect."
|
||||
)
|
||||
|
||||
def test_per_message_warning_is_only_shown_once(self, recwarn):
|
||||
ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
import logging
|
||||
import pytest
|
||||
|
||||
from telegram import TelegramError
|
||||
|
@ -63,3 +64,66 @@ class TestPromise:
|
|||
|
||||
with pytest.raises(TelegramError, match='Error'):
|
||||
promise.result()
|
||||
|
||||
def test_done_cb_after_run(self):
|
||||
def callback():
|
||||
return "done!"
|
||||
|
||||
def done_callback(_):
|
||||
self.test_flag = True
|
||||
|
||||
promise = Promise(callback, [], {})
|
||||
promise.run()
|
||||
promise.add_done_callback(done_callback)
|
||||
assert promise.result() == "done!"
|
||||
assert self.test_flag is True
|
||||
|
||||
def test_done_cb_after_run_excp(self):
|
||||
def callback():
|
||||
return "done!"
|
||||
|
||||
def done_callback(_):
|
||||
raise Exception("Error!")
|
||||
|
||||
promise = Promise(callback, [], {})
|
||||
promise.run()
|
||||
assert promise.result() == "done!"
|
||||
with pytest.raises(Exception) as err:
|
||||
promise.add_done_callback(done_callback)
|
||||
assert str(err) == "Error!"
|
||||
|
||||
def test_done_cb_before_run(self):
|
||||
def callback():
|
||||
return "done!"
|
||||
|
||||
def done_callback(_):
|
||||
self.test_flag = True
|
||||
|
||||
promise = Promise(callback, [], {})
|
||||
promise.add_done_callback(done_callback)
|
||||
assert promise.result(0) != "done!"
|
||||
assert self.test_flag is False
|
||||
promise.run()
|
||||
assert promise.result() == "done!"
|
||||
assert self.test_flag is True
|
||||
|
||||
def test_done_cb_before_run_excp(self, caplog):
|
||||
def callback():
|
||||
return "done!"
|
||||
|
||||
def done_callback(_):
|
||||
raise Exception("Error!")
|
||||
|
||||
promise = Promise(callback, [], {})
|
||||
promise.add_done_callback(done_callback)
|
||||
assert promise.result(0) != "done!"
|
||||
caplog.clear()
|
||||
with caplog.at_level(logging.WARNING):
|
||||
promise.run()
|
||||
assert len(caplog.records) == 2
|
||||
assert caplog.records[0].message == (
|
||||
"`done_callback` of a Promise raised the following exception."
|
||||
" The exception won't be handled by error handlers."
|
||||
)
|
||||
assert caplog.records[1].message.startswith("Full traceback:")
|
||||
assert promise.result() == "done!"
|
||||
|
|
Loading…
Add table
Reference in a new issue