From 5128748092da3b6f7d827e01019327d9baf63fa5 Mon Sep 17 00:00:00 2001 From: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 17 Aug 2023 11:50:26 +0200 Subject: [PATCH] Add `Application.stop_running()` and Improve Marking Updates as Read on `Updater.stop()` (#3804) --- .../source/inclusions/application_run_tip.rst | 14 +- telegram/ext/_application.py | 20 ++- telegram/ext/_updater.py | 33 ++++ tests/ext/test_application.py | 158 ++++++++++++++++-- tests/ext/test_updater.py | 131 +++++++++++++-- 5 files changed, 327 insertions(+), 29 deletions(-) diff --git a/docs/source/inclusions/application_run_tip.rst b/docs/source/inclusions/application_run_tip.rst index f6b1261b2..91dcaa997 100644 --- a/docs/source/inclusions/application_run_tip.rst +++ b/docs/source/inclusions/application_run_tip.rst @@ -1,7 +1,9 @@ .. tip:: - When combining ``python-telegram-bot`` with other :mod:`asyncio` based frameworks, using this - method is likely not the best choice, as it blocks the event loop until it receives a stop - signal as described above. - Instead, you can manually call the methods listed below to start and shut down the application - and the :attr:`~telegram.ext.Application.updater`. - Keeping the event loop running and listening for a stop signal is then up to you. \ No newline at end of file + * When combining ``python-telegram-bot`` with other :mod:`asyncio` based frameworks, using this + method is likely not the best choice, as it blocks the event loop until it receives a stop + signal as described above. + Instead, you can manually call the methods listed below to start and shut down the application + and the :attr:`~telegram.ext.Application.updater`. + Keeping the event loop running and listening for a stop signal is then up to you. + * To gracefully stop the execution of this method from within a handler, job or error callback, + use :meth:`~telegram.ext.Application.stop_running`. \ No newline at end of file diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 3567f8ac3..36ee63316 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -653,6 +653,24 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica _LOGGER.info("Application.stop() complete") + def stop_running(self) -> None: + """This method can be used to stop the execution of :meth:`run_polling` or + :meth:`run_webhook` from within a handler, job or error callback. This allows a graceful + shutdown of the application, i.e. the methods listed in :attr:`run_polling` and + :attr:`run_webhook` will still be executed. + + Note: + If the application is not running, this method does nothing. + + .. versionadded:: NEXT.VERSION + """ + if self.running: + # This works because `__run` is using `loop.run_forever()`. If that changes, this + # method needs to be adapted. + asyncio.get_running_loop().stop() + else: + _LOGGER.debug("Application is not running, stop_running() does nothing.") + def run_polling( self, poll_interval: float = 0.0, @@ -939,7 +957,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica loop.run_until_complete(self.start()) loop.run_forever() except (KeyboardInterrupt, SystemExit): - pass + _LOGGER.debug("Application received stop signal. Shutting down.") except Exception as exc: # In case the coroutine wasn't awaited, we don't need to bother the user with a warning updater_coroutine.close() diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 5c0da22a7..f89498e06 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -24,6 +24,7 @@ from pathlib import Path from types import TracebackType from typing import ( TYPE_CHECKING, + Any, AsyncContextManager, Callable, Coroutine, @@ -121,6 +122,7 @@ class Updater(AsyncContextManager["Updater"]): self._httpd: Optional[WebhookServer] = None self.__lock = asyncio.Lock() self.__polling_task: Optional[asyncio.Task] = None + self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None @property def running(self) -> bool: @@ -367,6 +369,28 @@ class Updater(AsyncContextManager["Updater"]): name="Updater:start_polling:polling_task", ) + # Prepare a cleanup callback to await on _stop_polling + # Calling get_updates one more time with the latest `offset` parameter ensures that + # all updates that where put into the update queue are also marked as "read" to TG, + # so we do not receive them again on the next startup + # We define this here so that we can use the same parameters as in the polling task + async def _get_updates_cleanup() -> None: + _LOGGER.debug( + "Calling `get_updates` one more time to mark all fetched updates as read." + ) + await self.bot.get_updates( + offset=self._last_update_id, + # We don't want to do long polling here! + timeout=0, + read_timeout=read_timeout, + connect_timeout=connect_timeout, + write_timeout=write_timeout, + pool_timeout=pool_timeout, + allowed_updates=allowed_updates, + ) + + self.__polling_cleanup_cb = _get_updates_cleanup + if ready is not None: ready.set() @@ -748,3 +772,12 @@ class Updater(AsyncContextManager["Updater"]): # after start_polling(), but lets better be safe than sorry ... self.__polling_task = None + + if self.__polling_cleanup_cb: + await self.__polling_cleanup_cb() + self.__polling_cleanup_cb = None + else: + _LOGGER.warning( + "No polling cleanup callback defined. The last fetched updates may be " + "fetched again on the next polling start." + ) diff --git a/tests/ext/test_application.py b/tests/ext/test_application.py index 54cd97133..80642be08 100644 --- a/tests/ext/test_application.py +++ b/tests/ext/test_application.py @@ -1427,7 +1427,7 @@ class TestApplication: platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) - def test_run_polling_basic(self, app, monkeypatch): + def test_run_polling_basic(self, app, monkeypatch, caplog): exception_event = threading.Event() update_event = threading.Event() exception = TelegramError("This is a test error") @@ -1464,6 +1464,9 @@ class TestApplication: time.sleep(0.05) assertions["exception_handling"] = self.received == exception.message + # So that the get_updates call on shutdown doesn't fail + exception_event.clear() + os.kill(os.getpid(), signal.SIGINT) time.sleep(0.1) @@ -1478,13 +1481,20 @@ class TestApplication: thread = Thread(target=thread_target) thread.start() - app.run_polling(drop_pending_updates=True, close_loop=False) - thread.join() + with caplog.at_level(logging.DEBUG): + app.run_polling(drop_pending_updates=True, close_loop=False) + thread.join() assert len(assertions) == 8 for key, value in assertions.items(): assert value, f"assertion '{key}' failed!" + found_log = False + for record in caplog.records: + if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG: + found_log = True + assert found_log + @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", @@ -1692,7 +1702,7 @@ class TestApplication: platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", ) - def test_run_webhook_basic(self, app, monkeypatch): + def test_run_webhook_basic(self, app, monkeypatch, caplog): assertions = {} async def delete_webhook(*args, **kwargs): @@ -1741,19 +1751,26 @@ class TestApplication: ip = "127.0.0.1" port = randrange(1024, 49152) - app.run_webhook( - ip_address=ip, - port=port, - url_path="TOKEN", - drop_pending_updates=True, - close_loop=False, - ) - thread.join() + with caplog.at_level(logging.DEBUG): + app.run_webhook( + ip_address=ip, + port=port, + url_path="TOKEN", + drop_pending_updates=True, + close_loop=False, + ) + thread.join() assert len(assertions) == 7 for key, value in assertions.items(): assert value, f"assertion '{key}' failed!" + found_log = False + for record in caplog.records: + if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG: + found_log = True + assert found_log + @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", @@ -2226,3 +2243,120 @@ class TestApplication: assert received_signals == [] else: assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT] + + def test_stop_running_not_running(self, app, caplog): + with caplog.at_level(logging.DEBUG): + app.stop_running() + + assert len(caplog.records) == 1 + assert caplog.records[-1].name == "telegram.ext.Application" + assert caplog.records[-1].getMessage().endswith("stop_running() does nothing.") + + @pytest.mark.parametrize("method", ["polling", "webhook"]) + def test_stop_running(self, one_time_bot, monkeypatch, method): + # asyncio.Event() seems to be hard to use across different threads (awaiting in main + # thread, setting in another thread), so we use threading.Event() instead. + # This requires the use of run_in_executor, but that's fine. + put_update_event = threading.Event() + callback_done_event = threading.Event() + called_stop_running = threading.Event() + assertions = {} + + async def get_updates(*args, **kwargs): + await asyncio.sleep(0) + return [] + + async def delete_webhook(*args, **kwargs): + return True + + async def set_webhook(*args, **kwargs): + return True + + async def post_init(app): + # Simply calling app.update_queue.put_nowait(method) in the thread_target doesn't work + # for some reason (probably threading magic), so we use an event from the thread_target + # to put the update into the queue in the main thread. + async def task(app): + await asyncio.get_running_loop().run_in_executor(None, put_update_event.wait) + await app.update_queue.put(method) + + app.create_task(task(app)) + + app = ApplicationBuilder().bot(one_time_bot).post_init(post_init).build() + monkeypatch.setattr(app.bot, "get_updates", get_updates) + monkeypatch.setattr(app.bot, "set_webhook", set_webhook) + monkeypatch.setattr(app.bot, "delete_webhook", delete_webhook) + + events = [] + monkeypatch.setattr( + app.updater, + "stop", + call_after(app.updater.stop, lambda _: events.append("updater.stop")), + ) + monkeypatch.setattr( + app, + "stop", + call_after(app.stop, lambda _: events.append("app.stop")), + ) + monkeypatch.setattr( + app, + "shutdown", + call_after(app.shutdown, lambda _: events.append("app.shutdown")), + ) + + def thread_target(): + waited = 0 + while not app.running: + time.sleep(0.05) + waited += 0.05 + if waited > 5: + pytest.fail("App apparently won't start") + + time.sleep(0.1) + assertions["called_stop_running_not_set"] = not called_stop_running.is_set() + + put_update_event.set() + time.sleep(0.1) + + assertions["called_stop_running_set"] = called_stop_running.is_set() + + # App should have entered `stop` now but not finished it yet because the callback + # is still running + assertions["updater.stop_event"] = events == ["updater.stop"] + assertions["app.running_False"] = not app.running + + callback_done_event.set() + time.sleep(0.1) + + # Now that the update is fully handled, we expect the full shutdown + assertions["events"] = events == ["updater.stop", "app.stop", "app.shutdown"] + + async def callback(update, context): + context.application.stop_running() + called_stop_running.set() + await asyncio.get_running_loop().run_in_executor(None, callback_done_event.wait) + + app.add_handler(TypeHandler(object, callback)) + + thread = Thread(target=thread_target) + thread.start() + + if method == "polling": + app.run_polling(close_loop=False, drop_pending_updates=True) + else: + ip = "127.0.0.1" + port = randrange(1024, 49152) + + app.run_webhook( + ip_address=ip, + port=port, + url_path="TOKEN", + drop_pending_updates=False, + close_loop=False, + ) + + thread.join() + + assert len(assertions) == 5 + for key, value in assertions.items(): + assert value, f"assertion '{key}' failed!" diff --git a/tests/ext/test_updater.py b/tests/ext/test_updater.py index 803b449a3..68835c176 100644 --- a/tests/ext/test_updater.py +++ b/tests/ext/test_updater.py @@ -214,9 +214,13 @@ class TestUpdater: await updates.put(Update(update_id=2)) async def get_updates(*args, **kwargs): - next_update = await updates.get() - updates.task_done() - return [next_update] + if not updates.empty(): + next_update = await updates.get() + updates.task_done() + return [next_update] + + await asyncio.sleep(0) + return [] orig_del_webhook = updater.bot.delete_webhook @@ -265,6 +269,91 @@ class TestUpdater: assert self.message_count == 4 assert self.received == [1, 2, 3, 4] + async def test_polling_mark_updates_as_read(self, monkeypatch, updater, caplog): + updates = asyncio.Queue() + max_update_id = 3 + for i in range(1, max_update_id + 1): + await updates.put(Update(update_id=i)) + tracking_flag = False + received_kwargs = {} + expected_kwargs = { + "timeout": 0, + "read_timeout": "read_timeout", + "connect_timeout": "connect_timeout", + "write_timeout": "write_timeout", + "pool_timeout": "pool_timeout", + "allowed_updates": "allowed_updates", + } + + async def get_updates(*args, **kwargs): + if tracking_flag: + received_kwargs.update(kwargs) + if not updates.empty(): + next_update = await updates.get() + updates.task_done() + return [next_update] + await asyncio.sleep(0) + return [] + + monkeypatch.setattr(updater.bot, "get_updates", get_updates) + + async with updater: + await updater.start_polling(**expected_kwargs) + await updates.join() + assert not received_kwargs + # Set the flag only now since we want to make sure that the get_updates + # is called one last time by updater.stop() + tracking_flag = True + with caplog.at_level(logging.DEBUG): + await updater.stop() + + # ensure that the last fetched update was still marked as read + assert received_kwargs["offset"] == max_update_id + 1 + # ensure that the correct arguments where passed to the last `get_updates` call + for name, value in expected_kwargs.items(): + assert received_kwargs[name] == value + + assert len(caplog.records) >= 1 + log_found = False + for record in caplog.records: + if not record.getMessage().startswith("Calling `get_updates` one more time"): + continue + + assert record.name == "telegram.ext.Updater" + assert record.levelno == logging.DEBUG + log_found = True + break + + assert log_found + + async def test_polling_mark_updates_as_read_failure(self, monkeypatch, updater, caplog): + async def get_updates(*args, **kwargs): + await asyncio.sleep(0) + return [] + + monkeypatch.setattr(updater.bot, "get_updates", get_updates) + + async with updater: + await updater.start_polling() + # Unfortunately, there is no clean way to test this scenario as it should in fact + # never happen + updater._Updater__polling_cleanup_cb = None + with caplog.at_level(logging.DEBUG): + await updater.stop() + + assert len(caplog.records) >= 1 + log_found = False + for record in caplog.records: + if not record.getMessage().startswith("No polling cleanup callback defined"): + continue + + assert record.name == "telegram.ext.Updater" + assert record.levelno == logging.WARNING + log_found = True + break + + assert log_found + async def test_start_polling_already_running(self, updater): async with updater: await updater.start_polling() @@ -278,6 +367,7 @@ class TestUpdater: async def test_start_polling_get_updates_parameters(self, updater, monkeypatch): update_queue = asyncio.Queue() await update_queue.put(Update(update_id=1)) + on_stop_flag = False expected = { "timeout": 10, @@ -290,6 +380,11 @@ class TestUpdater: } async def get_updates(*args, **kwargs): + if on_stop_flag: + # This is tested in test_polling_mark_updates_as_read + await asyncio.sleep(0) + return [] + for key, value in expected.items(): assert kwargs.pop(key, None) == value @@ -300,17 +395,23 @@ class TestUpdater: if offset is not None and self.message_count != 0: assert offset == self.message_count + 1, "get_updates got wrong `offset` parameter" - update = await update_queue.get() - self.message_count = update.update_id - update_queue.task_done() - return [update] + if not update_queue.empty(): + update = await update_queue.get() + self.message_count = update.update_id + update_queue.task_done() + return [update] + + await asyncio.sleep(0) + return [] monkeypatch.setattr(updater.bot, "get_updates", get_updates) async with updater: await updater.start_polling() await update_queue.join() + on_stop_flag = True await updater.stop() + on_stop_flag = False expected = { "timeout": 42, @@ -332,6 +433,7 @@ class TestUpdater: allowed_updates=["message"], ) await update_queue.join() + on_stop_flag = True await updater.stop() @pytest.mark.parametrize("exception_class", [InvalidToken, TelegramError]) @@ -368,12 +470,16 @@ class TestUpdater: async def test_start_polling_exceptions_and_error_callback( self, monkeypatch, updater, error, callback_should_be_called, custom_error_callback, caplog ): + raise_exception = True get_updates_event = asyncio.Event() async def get_updates(*args, **kwargs): # So that the main task has a chance to be called await asyncio.sleep(0) + if not raise_exception: + return [] + get_updates_event.set() raise error @@ -428,6 +534,7 @@ class TestUpdater: and record.name == "telegram.ext.Updater" for record in caplog.records ) + raise_exception = False await updater.stop() async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog): @@ -490,9 +597,13 @@ class TestUpdater: await asyncio.sleep(0.01) raise TypeError("Invalid Data") - next_update = await updates.get() - updates.task_done() - return [next_update] + if not updates.empty(): + next_update = await updates.get() + updates.task_done() + return [next_update] + + await asyncio.sleep(0) + return [] orig_del_webhook = updater.bot.delete_webhook