Adjust Calling of Dispatcher.update_persistence (#2285)

* Adjust calling of update_persistence

(cherry picked from commit 89c522d883)

* Fix tests and stuff
This commit is contained in:
Bibo-Joshi 2021-01-07 21:31:00 +01:00 committed by GitHub
parent 6a831f926b
commit ffd675daec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 131 additions and 18 deletions

View file

@ -173,7 +173,7 @@ class Defaults:
)
@property
def run_async(self) -> Optional[bool]:
def run_async(self) -> bool:
return self._run_async
@run_async.setter

View file

@ -402,11 +402,17 @@ class Dispatcher:
def has_running_threads(self) -> bool:
return self.running or bool(self.__async_threads)
def process_update(self, update: Union[str, Update, TelegramError]) -> None:
"""Processes a single update.
def process_update(self, update: Any) -> None:
"""Processes a single update and updates the persistence.
Note:
If the update is handled by least one synchronously running handlers (i.e.
``run_async=False`), :meth:`update_persistence` is called *once* after all handlers
synchronous handlers are done. Each asynchronously running handler will trigger
:meth:`update_persistence` on its own.
Args:
update (:obj:`str` | :class:`telegram.Update` | :class:`telegram.TelegramError`):
update (:class:`telegram.Update` | :obj:`object` | :class:`telegram.TelegramError`):
The update to process.
"""
@ -420,6 +426,8 @@ class Dispatcher:
return
context = None
handled = False
sync_modes = []
for group in self.groups:
try:
@ -428,11 +436,9 @@ class Dispatcher:
if check is not None and check is not False:
if not context and self.use_context:
context = CallbackContext.from_update(update, self)
handled = True
sync_modes.append(handler.run_async)
handler.handle_update(update, self, check, context)
# If handler runs async updating immediately doesn't make sense
if not handler.run_async:
self.update_persistence(update=update)
break
# Stop processing with any other handler.
@ -452,6 +458,16 @@ class Dispatcher:
except Exception:
self.logger.exception('An uncaught error was raised while handling the error.')
# Update persistence, if handled
handled_only_async = all(sync_modes)
if handled:
# Respect default settings
if all(mode is DEFAULT_FALSE for mode in sync_modes) and self.bot.defaults:
handled_only_async = self.bot.defaults.run_async
# If update was only handled by async handlers, we don't need to update here
if not handled_only_async:
self.update_persistence(update=update)
def add_handler(self, handler: Handler, group: int = DEFAULT_GROUP) -> None:
"""Register a handler.

View file

@ -35,6 +35,7 @@ from telegram.ext import (
)
from telegram.ext.dispatcher import run_async, Dispatcher, DispatcherHandlerStop
from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.helpers import DEFAULT_FALSE
from tests.conftest import create_dp
from collections import defaultdict
@ -804,3 +805,96 @@ class TestDispatcher:
assert cdp.persistence.test_flag_bot_data
assert not cdp.persistence.test_flag_user_data
assert cdp.persistence.test_flag_chat_data
def test_update_persistence_once_per_update(self, monkeypatch, dp):
def update_persistence(*args, **kwargs):
self.count += 1
def dummy_callback(*args):
pass
monkeypatch.setattr(dp, 'update_persistence', update_persistence)
for group in range(5):
dp.add_handler(MessageHandler(Filters.text, dummy_callback), group=group)
update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text=None))
dp.process_update(update)
assert self.count == 0
update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='text'))
dp.process_update(update)
assert self.count == 1
def test_update_persistence_all_async(self, monkeypatch, dp):
def update_persistence(*args, **kwargs):
self.count += 1
def dummy_callback(*args, **kwargs):
pass
monkeypatch.setattr(dp, 'update_persistence', update_persistence)
monkeypatch.setattr(dp, 'run_async', dummy_callback)
for group in range(5):
dp.add_handler(
MessageHandler(Filters.text, dummy_callback, run_async=True), group=group
)
update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
dp.process_update(update)
assert self.count == 0
dp.bot.defaults = Defaults(run_async=True)
try:
for group in range(5):
dp.add_handler(MessageHandler(Filters.text, dummy_callback), group=group)
update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
dp.process_update(update)
assert self.count == 0
finally:
dp.bot.defaults = None
@pytest.mark.parametrize('run_async', [DEFAULT_FALSE, False])
def test_update_persistence_one_sync(self, monkeypatch, dp, run_async):
def update_persistence(*args, **kwargs):
self.count += 1
def dummy_callback(*args, **kwargs):
pass
monkeypatch.setattr(dp, 'update_persistence', update_persistence)
monkeypatch.setattr(dp, 'run_async', dummy_callback)
for group in range(5):
dp.add_handler(
MessageHandler(Filters.text, dummy_callback, run_async=True), group=group
)
dp.add_handler(MessageHandler(Filters.text, dummy_callback, run_async=run_async), group=5)
update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
dp.process_update(update)
assert self.count == 1
@pytest.mark.parametrize('run_async,expected', [(DEFAULT_FALSE, 1), (False, 1), (True, 0)])
def test_update_persistence_defaults_async(self, monkeypatch, dp, run_async, expected):
def update_persistence(*args, **kwargs):
self.count += 1
def dummy_callback(*args, **kwargs):
pass
monkeypatch.setattr(dp, 'update_persistence', update_persistence)
monkeypatch.setattr(dp, 'run_async', dummy_callback)
dp.bot.defaults = Defaults(run_async=run_async)
try:
for group in range(5):
dp.add_handler(MessageHandler(Filters.text, dummy_callback), group=group)
update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text'))
dp.process_update(update)
assert self.count == expected
finally:
dp.bot.defaults = None

View file

@ -314,7 +314,7 @@ class TestJobQueue:
next_months_days = calendar.monthrange(now.year, now.month + 1)[1]
expected_reschedule_time += dtm.timedelta(this_months_days)
if next_months_days < this_months_days:
if day > next_months_days:
expected_reschedule_time += dtm.timedelta(next_months_days)
expected_reschedule_time = timezone.normalize(expected_reschedule_time)

View file

@ -1406,18 +1406,21 @@ class TestMessage:
def test_default_quote(self, message):
message.bot.defaults = Defaults()
message.bot.defaults._quote = False
assert message._quote(None, None) is None
try:
message.bot.defaults._quote = False
assert message._quote(None, None) is None
message.bot.defaults._quote = True
assert message._quote(None, None) == message.message_id
message.bot.defaults._quote = True
assert message._quote(None, None) == message.message_id
message.bot.defaults._quote = None
message.chat.type = Chat.PRIVATE
assert message._quote(None, None) is None
message.bot.defaults._quote = None
message.chat.type = Chat.PRIVATE
assert message._quote(None, None) is None
message.chat.type = Chat.GROUP
assert message._quote(None, None)
message.chat.type = Chat.GROUP
assert message._quote(None, None)
finally:
message.bot.defaults = None
def test_equality(self):
id_ = 1