Improve Timeouts in ConversationHandler (#2417)

* Handle promise states in conversation timeout

Signed-off-by: starry69 <starry369126@outlook.com>

* warn if nested conversation & timeout

Signed-off-by: starry69 <starry369126@outlook.com>

* Add notes and test for conversation_timeout

Signed-off-by: starry69 <starry369126@outlook.com>

* Try to fix pre-commit

Signed-off-by: starry69 <starry369126@outlook.com>

* Test promise exception

Signed-off-by: starry69 <starry369126@outlook.com>

* Welp

Signed-off-by: starry69 <starry369126@outlook.com>

* improve docs

Signed-off-by: starry69 <starry369126@outlook.com>

* typo

Signed-off-by: starry69 <starry369126@outlook.com>

* try to fix codecov

Signed-off-by: starry69 <starry369126@outlook.com>

* refactor timeout logic with promise.add_done_cb

Signed-off-by: starry69 <starry369126@outlook.com>

* small fix

Signed-off-by: starry69 <starry369126@outlook.com>

* Address review

Signed-off-by: starry69 <starry369126@outlook.com>

* Fix some type hinting

* Few fixes

Signed-off-by: starry69 <starry369126@outlook.com>

* fix tests

Signed-off-by: starry69 <starry369126@outlook.com>

* minor nitpick

Signed-off-by: starry69 <starry369126@outlook.com>

Co-authored-by: Hinrich Mahler <hinrich.mahler@freenet.de>
This commit is contained in:
Stɑrry Shivɑm 2021-04-30 13:40:46 +05:30 committed by GitHub
parent b6a6d7f872
commit 7e554584b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 387 additions and 39 deletions

View file

@ -21,8 +21,10 @@
import logging
import warnings
import functools
import datetime
from threading import Lock
from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Tuple, cast, ClassVar
from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Union, Tuple, cast, ClassVar
from telegram import Update
from telegram.ext import (
@ -143,6 +145,13 @@ class ConversationHandler(Handler[Update]):
received update and the corresponding ``context`` will be handled by ALL the handler's
who's :attr:`check_update` method returns :obj:`True` that are in the state
:attr:`ConversationHandler.TIMEOUT`.
Note:
Using `conversation_timeout` with nested conversations is currently not
supported. You can still try to use it, but it will likely behave differently
from what you expect.
name (:obj:`str`, optional): The name for this conversationhandler. Required for
persistence.
persistent (:obj:`bool`, optional): If the conversations dict for this handler should be
@ -215,7 +224,7 @@ class ConversationHandler(Handler[Update]):
per_chat: bool = True,
per_user: bool = True,
per_message: bool = False,
conversation_timeout: int = None,
conversation_timeout: Union[float, datetime.timedelta] = None,
name: str = None,
persistent: bool = False,
map_to_parent: Dict[object, object] = None,
@ -291,6 +300,16 @@ class ConversationHandler(Handler[Update]):
)
break
if self.conversation_timeout:
for handler in all_handlers:
if isinstance(handler, self.__class__):
warnings.warn(
"Using `conversation_timeout` with nested conversations is currently not "
"supported. You can still try to use it, but it will likely behave "
"differently from what you expect."
)
break
if self.run_async:
for handler in all_handlers:
handler.run_async = True
@ -352,7 +371,9 @@ class ConversationHandler(Handler[Update]):
raise ValueError('You can not assign a new value to per_message after initialization.')
@property
def conversation_timeout(self) -> Optional[int]:
def conversation_timeout(
self,
) -> Optional[Union[float, datetime.timedelta]]:
return self._conversation_timeout
@conversation_timeout.setter
@ -423,6 +444,45 @@ class ConversationHandler(Handler[Update]):
return tuple(key)
def _resolve_promise(self, state: Tuple) -> object:
old_state, new_state = state
try:
res = new_state.result(0)
res = res if res is not None else old_state
except Exception as exc:
self.logger.exception("Promise function raised exception")
self.logger.exception("%s", exc)
res = old_state
finally:
if res is None and old_state is None:
res = self.END
return res
def _schedule_job(
self,
new_state: object,
dispatcher: 'Dispatcher',
update: Update,
context: Optional[CallbackContext],
conversation_key: Tuple[int, ...],
) -> None:
if new_state != self.END:
try:
# both job_queue & conversation_timeout are checked before calling _schedule_job
j_queue = dispatcher.job_queue
self.timeout_jobs[conversation_key] = j_queue.run_once( # type: ignore[union-attr]
self._trigger_timeout,
self.conversation_timeout, # type: ignore[arg-type]
context=_ConversationTimeoutContext(
conversation_key, update, dispatcher, context
),
)
except Exception as exc:
self.logger.exception(
"Failed to schedule timeout job due to the following exception:"
)
self.logger.exception("%s", exc)
def check_update(self, update: object) -> CheckUpdateType: # pylint: disable=R0911
"""
Determines whether an update should be handled by this conversationhandler, and if so in
@ -455,21 +515,14 @@ class ConversationHandler(Handler[Update]):
if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], Promise):
self.logger.debug('waiting for promise...')
old_state, new_state = state
if new_state.done.wait(0):
try:
res = new_state.result(0)
res = res if res is not None else old_state
except Exception as exc:
self.logger.exception("Promise function raised exception")
self.logger.exception("%s", exc)
res = old_state
finally:
if res is None and old_state is None:
res = self.END
self.update_state(res, key)
with self._conversations_lock:
state = self.conversations.get(key)
# check if promise is finished or not
if state[1].done.wait(0):
res = self._resolve_promise(state)
self.update_state(res, key)
with self._conversations_lock:
state = self.conversations.get(key)
# if not then handle WAITING state instead
else:
hdlrs = self.states.get(self.WAITING, [])
for hdlr in hdlrs:
@ -551,15 +604,27 @@ class ConversationHandler(Handler[Update]):
new_state = exception.state
raise_dp_handler_stop = True
with self._timeout_jobs_lock:
if self.conversation_timeout and new_state != self.END and dispatcher.job_queue:
# Add the new timeout job
self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once(
self._trigger_timeout, # type: ignore[arg-type]
self.conversation_timeout,
context=_ConversationTimeoutContext(
conversation_key, update, dispatcher, context
),
)
if self.conversation_timeout:
if dispatcher.job_queue is not None:
# Add the new timeout job
if isinstance(new_state, Promise):
new_state.add_done_callback(
functools.partial(
self._schedule_job,
dispatcher=dispatcher,
update=update,
context=context,
conversation_key=conversation_key,
)
)
elif new_state != self.END:
self._schedule_job(
new_state, dispatcher, update, context, conversation_key
)
else:
self.logger.warning(
"Ignoring `conversation_timeout` because the Dispatcher has no JobQueue."
)
if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent:
self.update_state(self.END, conversation_key)
@ -602,35 +667,35 @@ class ConversationHandler(Handler[Update]):
if self.persistent and self.persistence and self.name:
self.persistence.update_conversation(self.name, key, new_state)
def _trigger_timeout(self, context: _ConversationTimeoutContext, job: 'Job' = None) -> None:
def _trigger_timeout(self, context: CallbackContext, job: 'Job' = None) -> None:
self.logger.debug('conversation timeout was triggered!')
# Backward compatibility with bots that do not use CallbackContext
callback_context = None
if isinstance(context, CallbackContext):
job = context.job
ctxt = cast(_ConversationTimeoutContext, job.context) # type: ignore[union-attr]
else:
ctxt = cast(_ConversationTimeoutContext, job.context)
context = job.context # type:ignore[union-attr,assignment]
callback_context = context.callback_context
callback_context = ctxt.callback_context
with self._timeout_jobs_lock:
found_job = self.timeout_jobs[context.conversation_key]
found_job = self.timeout_jobs[ctxt.conversation_key]
if found_job is not job:
# The timeout has been canceled in handle_update
# The timeout has been cancelled in handle_update
return
del self.timeout_jobs[context.conversation_key]
del self.timeout_jobs[ctxt.conversation_key]
handlers = self.states.get(self.TIMEOUT, [])
for handler in handlers:
check = handler.check_update(context.update)
check = handler.check_update(ctxt.update)
if check is not None and check is not False:
try:
handler.handle_update(
context.update, context.dispatcher, check, callback_context
)
handler.handle_update(ctxt.update, ctxt.dispatcher, check, callback_context)
except DispatcherHandlerStop:
self.logger.warning(
'DispatcherHandlerStop in TIMEOUT state of '
'ConversationHandler has no effect. Ignoring.'
)
self.update_state(self.END, context.conversation_key)
self.update_state(self.END, ctxt.conversation_key)

View file

@ -69,6 +69,7 @@ class Promise:
self.update = update
self.error_handling = error_handling
self.done = Event()
self._done_callback: Optional[Callable] = None
self._result: Optional[RT] = None
self._exception: Optional[Exception] = None
@ -83,6 +84,15 @@ class Promise:
finally:
self.done.set()
if self._done_callback:
try:
self._done_callback(self.result())
except Exception as exc:
logger.warning(
"`done_callback` of a Promise raised the following exception."
" The exception won't be handled by error handlers."
)
logger.warning("Full traceback:", exc_info=exc)
def __call__(self) -> None:
self.run()
@ -106,6 +116,20 @@ class Promise:
raise self._exception # pylint: disable=raising-bad-type
return self._result
def add_done_callback(self, callback: Callable) -> None:
"""
Callback to be run when :class:`telegram.ext.utils.promise.Promise` becomes done.
Args:
callback (:obj:`callable`): The callable that will be called when promise is done.
callback will be called by passing ``Promise.result()`` as only positional argument.
"""
if self.done.wait(0):
callback(self.result())
else:
self._done_callback = callback
@property
def exception(self) -> Optional[Exception]:
"""The exception raised by :attr:`pooled_function` or ``None`` if no exception has been

View file

@ -784,6 +784,125 @@ class TestConversationHandler:
assert not handler.check_update(Update(0, pre_checkout_query=pre_checkout_query))
assert not handler.check_update(Update(0, shipping_query=shipping_query))
def test_no_jobqueue_warning(self, dp, bot, user1, caplog):
handler = ConversationHandler(
entry_points=self.entry_points,
states=self.states,
fallbacks=self.fallbacks,
conversation_timeout=0.5,
)
# save dp.job_queue in temp variable jqueue
# and then set dp.job_queue to None.
jqueue = dp.job_queue
dp.job_queue = None
dp.add_handler(handler)
message = Message(
0,
None,
self.group,
from_user=user1,
text='/start',
entities=[
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
],
bot=bot,
)
with caplog.at_level(logging.WARNING):
dp.process_update(Update(update_id=0, message=message))
sleep(0.5)
assert len(caplog.records) == 1
assert (
caplog.records[0].message
== "Ignoring `conversation_timeout` because the Dispatcher has no JobQueue."
)
# now set dp.job_queue back to it's original value
dp.job_queue = jqueue
def test_schedule_job_exception(self, dp, bot, user1, monkeypatch, caplog):
def mocked_run_once(*a, **kw):
raise Exception("job error")
monkeypatch.setattr(dp.job_queue, "run_once", mocked_run_once)
handler = ConversationHandler(
entry_points=self.entry_points,
states=self.states,
fallbacks=self.fallbacks,
conversation_timeout=100,
)
dp.add_handler(handler)
message = Message(
0,
None,
self.group,
from_user=user1,
text='/start',
entities=[
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
],
bot=bot,
)
with caplog.at_level(logging.ERROR):
dp.process_update(Update(update_id=0, message=message))
sleep(0.5)
assert len(caplog.records) == 2
assert (
caplog.records[0].message
== "Failed to schedule timeout job due to the following exception:"
)
assert caplog.records[1].message == "job error"
def test_promise_exception(self, dp, bot, user1, caplog):
"""
Here we make sure that when a run_async handle raises an
exception, the state isn't changed.
"""
def conv_entry(*a, **kw):
return 1
def raise_error(*a, **kw):
raise Exception("promise exception")
handler = ConversationHandler(
entry_points=[CommandHandler("start", conv_entry)],
states={1: [MessageHandler(Filters.all, raise_error)]},
fallbacks=self.fallbacks,
run_async=True,
)
dp.add_handler(handler)
message = Message(
0,
None,
self.group,
from_user=user1,
text='/start',
entities=[
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
],
bot=bot,
)
# start the conversation
dp.process_update(Update(update_id=0, message=message))
sleep(0.1)
message.text = "error"
dp.process_update(Update(update_id=0, message=message))
sleep(0.1)
message.text = "resolve promise pls"
caplog.clear()
with caplog.at_level(logging.ERROR):
dp.process_update(Update(update_id=0, message=message))
sleep(0.5)
assert len(caplog.records) == 3
assert caplog.records[0].message == "Promise function raised exception"
assert caplog.records[1].message == "promise exception"
# assert res is old state
assert handler.conversations.get((self.group.id, user1.id))[0] == 1
def test_conversation_timeout(self, dp, bot, user1):
handler = ConversationHandler(
entry_points=self.entry_points,
@ -820,6 +939,49 @@ class TestConversationHandler:
sleep(0.7)
assert handler.conversations.get((self.group.id, user1.id)) is None
def test_timeout_not_triggered_on_conv_end_async(self, bot, dp, user1):
def timeout(*a, **kw):
self.test_flag = True
self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]})
handler = ConversationHandler(
entry_points=self.entry_points,
states=self.states,
fallbacks=self.fallbacks,
conversation_timeout=0.5,
run_async=True,
)
dp.add_handler(handler)
message = Message(
0,
None,
self.group,
from_user=user1,
text='/start',
entities=[
MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start'))
],
bot=bot,
)
# start the conversation
dp.process_update(Update(update_id=0, message=message))
sleep(0.1)
message.text = '/brew'
message.entities[0].length = len('/brew')
dp.process_update(Update(update_id=1, message=message))
sleep(0.1)
message.text = '/pourCoffee'
message.entities[0].length = len('/pourCoffee')
dp.process_update(Update(update_id=2, message=message))
sleep(0.1)
message.text = '/end'
message.entities[0].length = len('/end')
dp.process_update(Update(update_id=3, message=message))
sleep(1)
# assert timeout handler didn't got called
assert self.test_flag is False
def test_conversation_timeout_dispatcher_handler_stop(self, dp, bot, user1, caplog):
handler = ConversationHandler(
entry_points=self.entry_points,
@ -1157,6 +1319,39 @@ class TestConversationHandler:
assert handler.conversations.get((self.group.id, user1.id)) is None
assert self.is_timeout
def test_conversation_timeout_warning_only_shown_once(self, recwarn):
ConversationHandler(
entry_points=self.entry_points,
states={
self.THIRSTY: [
ConversationHandler(
entry_points=self.entry_points,
states={
self.BREWING: [CommandHandler('pourCoffee', self.drink)],
},
fallbacks=self.fallbacks,
)
],
self.DRINKING: [
ConversationHandler(
entry_points=self.entry_points,
states={
self.CODING: [CommandHandler('startCoding', self.code)],
},
fallbacks=self.fallbacks,
)
],
},
fallbacks=self.fallbacks,
conversation_timeout=100,
)
assert len(recwarn) == 1
assert str(recwarn[0].message) == (
"Using `conversation_timeout` with nested conversations is currently not "
"supported. You can still try to use it, but it will likely behave "
"differently from what you expect."
)
def test_per_message_warning_is_only_shown_once(self, recwarn):
ConversationHandler(
entry_points=self.entry_points,

View file

@ -16,6 +16,7 @@
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import logging
import pytest
from telegram import TelegramError
@ -63,3 +64,66 @@ class TestPromise:
with pytest.raises(TelegramError, match='Error'):
promise.result()
def test_done_cb_after_run(self):
def callback():
return "done!"
def done_callback(_):
self.test_flag = True
promise = Promise(callback, [], {})
promise.run()
promise.add_done_callback(done_callback)
assert promise.result() == "done!"
assert self.test_flag is True
def test_done_cb_after_run_excp(self):
def callback():
return "done!"
def done_callback(_):
raise Exception("Error!")
promise = Promise(callback, [], {})
promise.run()
assert promise.result() == "done!"
with pytest.raises(Exception) as err:
promise.add_done_callback(done_callback)
assert str(err) == "Error!"
def test_done_cb_before_run(self):
def callback():
return "done!"
def done_callback(_):
self.test_flag = True
promise = Promise(callback, [], {})
promise.add_done_callback(done_callback)
assert promise.result(0) != "done!"
assert self.test_flag is False
promise.run()
assert promise.result() == "done!"
assert self.test_flag is True
def test_done_cb_before_run_excp(self, caplog):
def callback():
return "done!"
def done_callback(_):
raise Exception("Error!")
promise = Promise(callback, [], {})
promise.add_done_callback(done_callback)
assert promise.result(0) != "done!"
caplog.clear()
with caplog.at_level(logging.WARNING):
promise.run()
assert len(caplog.records) == 2
assert caplog.records[0].message == (
"`done_callback` of a Promise raised the following exception."
" The exception won't be handled by error handlers."
)
assert caplog.records[1].message.startswith("Full traceback:")
assert promise.result() == "done!"