From 637b8e260ba00a238f91ee770850bac2006a3951 Mon Sep 17 00:00:00 2001 From: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 20 May 2024 15:27:08 +0200 Subject: [PATCH] Handle `SystemExit` raised in Handlers (#4157) --- docs/source/conf.py | 4 +- docs/substitutions/application.rst | 1 + telegram/ext/_application.py | 117 ++++++++----- tests/ext/test_application.py | 259 ++++++++++++++++++++++------- 4 files changed, 276 insertions(+), 105 deletions(-) create mode 100644 docs/substitutions/application.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index 923888059..66dfdcac0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -67,7 +67,9 @@ source_suffix = ".rst" master_doc = "index" # Global substitutions -rst_prolog = (Path.cwd() / "../substitutions/global.rst").read_text(encoding="utf-8") +rst_prolog = "" +for file in Path.cwd().glob("../substitutions/*.rst"): + rst_prolog += "\n" + file.read_text(encoding="utf-8") # -- Extension settings ------------------------------------------------ napoleon_use_admonition_for_examples = True diff --git a/docs/substitutions/application.rst b/docs/substitutions/application.rst new file mode 100644 index 000000000..456433044 --- /dev/null +++ b/docs/substitutions/application.rst @@ -0,0 +1 @@ +.. |app_run_shutdown| replace:: The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. This also works from within handlers, error handlers and jobs. However, using :meth:`~telegram.ext.Application.stop_running` will give a somewhat cleaner shutdown behavior than manually raising those exceptions. 
On unix, the app will also shut down on receiving the signals specified by \ No newline at end of file diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 0c2884dc9..79467b2b4 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -365,6 +365,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica self.__update_persistence_event = asyncio.Event() self.__update_persistence_lock = asyncio.Lock() self.__create_task_tasks: Set[asyncio.Task] = set() # Used for awaiting tasks upon exit + self.__stop_running_marker = asyncio.Event() async def __aenter__(self: _AppType) -> _AppType: # noqa: PYI019 """|async_context_manager| :meth:`initializes ` the App. @@ -516,6 +517,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica await self._add_ch_to_persistence(handler) self._initialized = True + self.__stop_running_marker.clear() async def _add_ch_to_persistence(self, handler: "ConversationHandler") -> None: self._conversation_handler_conversations.update( @@ -670,14 +672,26 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica raise RuntimeError("This Application is not running!") self._running = False + self.__stop_running_marker.clear() _LOGGER.info("Application is stopping. This might take a moment.") # Stop listening for new updates and handle all pending ones - await self.update_queue.put(_STOP_SIGNAL) - _LOGGER.debug("Waiting for update_queue to join") - await self.update_queue.join() if self.__update_fetcher_task: - await self.__update_fetcher_task + if self.__update_fetcher_task.done(): + try: + self.__update_fetcher_task.result() + except BaseException as exc: + _LOGGER.critical( + "Fetching updates was aborted due to %r. 
Suppressing " + "exception to ensure graceful shutdown.", + exc, + exc_info=True, + ) + else: + await self.update_queue.put(_STOP_SIGNAL) + _LOGGER.debug("Waiting for update_queue to join") + await self.update_queue.join() + await self.__update_fetcher_task _LOGGER.debug("Application stopped fetching of updates.") if self._job_queue: @@ -703,17 +717,36 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica shutdown of the application, i.e. the methods listed in :attr:`run_polling` and :attr:`run_webhook` will still be executed. + This method can also be called within :meth:`post_init`. This allows for a graceful, + early shutdown of the application if some condition is met (e.g., a database connection + could not be established). + Note: - If the application is not running, this method does nothing. + If the application is not running and this method is not called within + :meth:`post_init`, this method does nothing. + + Warning: + This method is designed for use in combination with :meth:`run_polling` or + :meth:`run_webhook`. Using this method in combination with custom logic for starting + and stopping the application is not guaranteed to work as expected. Use at your own + risk. .. versionadded:: 20.5 + + .. versionchanged:: NEXT.VERSION + Added support for calling within :meth:`post_init`. """ if self.running: # This works because `__run` is using `loop.run_forever()`. If that changes, this # method needs to be adapted. asyncio.get_running_loop().stop() else: - _LOGGER.debug("Application is not running, stop_running() does nothing.") + self.__stop_running_marker.set() + if not self._initialized: + _LOGGER.debug( + "Application is not running and not initialized. `stop_running()` likely has " + "no effect." 
+ ) def run_polling( self, @@ -733,9 +766,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and a graceful shutdown of the app on exit. - The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. - On unix, the app will also shut down on receiving the signals specified by - :paramref:`stop_signals`. + |app_run_shutdown| :paramref:`stop_signals`. The order of execution by :meth:`run_polling` is roughly as follows: @@ -874,9 +905,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and a graceful shutdown of the app on exit. - The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. - On unix, the app will also shut down on receiving the signals specified by - :paramref:`stop_signals`. + |app_run_shutdown| :paramref:`stop_signals`. If :paramref:`cert` and :paramref:`key` are not provided, the webhook will be started directly on @@ -1038,6 +1067,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica loop.run_until_complete(self.initialize()) if self.post_init: loop.run_until_complete(self.post_init(self)) + if self.__stop_running_marker.is_set(): + _LOGGER.info("Application received stop signal via `stop_running`. Shutting down.") + return loop.run_until_complete(updater_coroutine) # one of updater.start_webhook/polling loop.run_until_complete(self.start()) loop.run_forever() @@ -1184,45 +1216,44 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica finally: self._mark_for_persistence_update(update=update) - async def _update_fetcher(self) -> None: + async def __update_fetcher(self) -> None: # Continuously fetch updates from the queue. Exit only once the signal object is found. 
while True: - try: - update = await self.update_queue.get() + update = await self.update_queue.get() - if update is _STOP_SIGNAL: - _LOGGER.debug("Dropping pending updates") - while not self.update_queue.empty(): - self.update_queue.task_done() + if update is _STOP_SIGNAL: + # For the _STOP_SIGNAL + self.update_queue.task_done() + return - # For the _STOP_SIGNAL - self.update_queue.task_done() - return + _LOGGER.debug("Processing update %s", update) - _LOGGER.debug("Processing update %s", update) - - if self._update_processor.max_concurrent_updates > 1: - # We don't await the below because it has to be run concurrently - self.create_task( - self.__process_update_wrapper(update), - update=update, - name=f"Application:{self.bot.id}:process_concurrent_update", - ) - else: - await self.__process_update_wrapper(update) - - except asyncio.CancelledError: - # This may happen if the application is manually run via application.start() and - # then a KeyboardInterrupt is sent. We must prevent this loop to die since - # application.stop() will wait for it's clean shutdown. - _LOGGER.warning( - "Fetching updates got a asyncio.CancelledError. Ignoring as this task may only" - "be closed via `Application.stop`." 
+ if self._update_processor.max_concurrent_updates > 1: + # We don't await the below because it has to be run concurrently + self.create_task( + self.__process_update_wrapper(update), + update=update, + name=f"Application:{self.bot.id}:process_concurrent_update", ) + else: + await self.__process_update_wrapper(update) + + async def _update_fetcher(self) -> None: + try: + await self.__update_fetcher() + finally: + while not self.update_queue.empty(): + _LOGGER.debug("Dropping pending update: %s", self.update_queue.get_nowait()) + with contextlib.suppress(ValueError): + # Since we're shutting down here, it's not too bad if we call task_done + # on an empty queue + self.update_queue.task_done() async def __process_update_wrapper(self, update: object) -> None: - await self._update_processor.process_update(update, self.process_update(update)) - self.update_queue.task_done() + try: + await self._update_processor.process_update(update, self.process_update(update)) + finally: + self.update_queue.task_done() async def process_update(self, update: object) -> None: """Processes a single update and marks the update to be updated by the persistence later. diff --git a/tests/ext/test_application.py b/tests/ext/test_application.py index 78997b2c5..6d1827ae4 100644 --- a/tests/ext/test_application.py +++ b/tests/ext/test_application.py @@ -19,6 +19,7 @@ """The integration of persistence into the application is tested in test_basepersistence. 
""" import asyncio +import functools import inspect import logging import os @@ -2083,75 +2084,175 @@ class TestApplication: assert set(self.received.keys()) == set(expected.keys()) assert self.received == expected - @pytest.mark.skipif( - platform.system() == "Windows", - reason="Can't send signals without stopping whole process on windows", - ) - async def test_cancellation_error_does_not_stop_polling( - self, one_time_bot, monkeypatch, caplog + @pytest.mark.parametrize("exception", [SystemExit, KeyboardInterrupt]) + def test_raise_system_exit_keyboard_interrupt_post_init( + self, one_time_bot, monkeypatch, exception ): - """ - Ensures that hitting CTRL+C while polling *without* run_polling doesn't kill - the update_fetcher loop such that a shutdown is still possible. - This test is far from perfect, but it's the closest we can come with sane effort. - """ + async def post_init(application): + raise exception + + called_callbacks = set() + + async def callback(*args, **kwargs): + called_callbacks.add(kwargs["name"]) + + for cls, method, entry in [ + (Application, "initialize", "app_initialize"), + (Application, "start", "app_start"), + (Application, "stop", "app_stop"), + (Application, "shutdown", "app_shutdown"), + (Updater, "initialize", "updater_initialize"), + (Updater, "shutdown", "updater_shutdown"), + (Updater, "stop", "updater_stop"), + (Updater, "start_polling", "updater_start_polling"), + ]: + + def after(_, name): + called_callbacks.add(name) + + monkeypatch.setattr( + cls, + method, + call_after(getattr(cls, method), functools.partial(after, name=entry)), + ) + + app = ( + ApplicationBuilder() + .bot(one_time_bot) + .post_init(post_init) + .post_stop(functools.partial(callback, name="post_stop")) + .post_shutdown(functools.partial(callback, name="post_shutdown")) + .build() + ) + + app.run_polling(close_loop=False) + + # This checks two things: + # 1. start/stop are *not* called! + # 2. 
we do have a graceful shutdown + assert called_callbacks == { + "app_initialize", + "updater_initialize", + "app_shutdown", + "post_stop", + "post_shutdown", + "updater_shutdown", + } + + @pytest.mark.parametrize("exception", [SystemExit("PTBTest"), KeyboardInterrupt("PTBTest")]) + @pytest.mark.parametrize("kind", ["handler", "error_handler", "job"]) + # @pytest.mark.parametrize("block", [True, False]) + # Testing with block=False would be nice but that doesn't work well with pytest for some reason + # in any case, block=False is the simpler behavior since it is roughly similar to what happens + # when you hit CTRL+C in the commandline. + def test_raise_system_exit_keyboard_jobs_handlers( + self, one_time_bot, monkeypatch, exception, kind, caplog + ): + async def queue_and_raise(application): + await application.update_queue.put("will_not_be_processed") + raise exception + + async def handler_callback(update, context): + if kind == "handler": + await queue_and_raise(context.application) + elif kind == "error_handler": + raise TelegramError("Triggering error callback") + + async def error_callback(update, context): + await queue_and_raise(context.application) + + async def job_callback(context): + await queue_and_raise(context.application) + + async def enqueue_update(): + await asyncio.sleep(0.5) + await app.update_queue.put(1) + + async def post_init(application): + if kind == "job": + application.job_queue.run_once(when=0.5, callback=job_callback) + else: + app.create_task(enqueue_update()) + + async def update_logger_callback(update, context): + context.bot_data.setdefault("processed_updates", set()).add(update) + + called_callbacks = set() + + async def callback(*args, **kwargs): + called_callbacks.add(kwargs["name"]) async def get_updates(*args, **kwargs): await asyncio.sleep(0) - return [None] + return [] - monkeypatch.setattr(one_time_bot, "get_updates", get_updates) - app = ApplicationBuilder().bot(one_time_bot).build() + for cls, method, entry in [ + 
(Application, "initialize", "app_initialize"), + (Application, "start", "app_start"), + (Application, "stop", "app_stop"), + (Application, "shutdown", "app_shutdown"), + (Updater, "initialize", "updater_initialize"), + (Updater, "shutdown", "updater_shutdown"), + (Updater, "stop", "updater_stop"), + (Updater, "start_polling", "updater_start_polling"), + ]: - original_get = app.update_queue.get - raise_cancelled_error = threading.Event() + def after(_, name): + called_callbacks.add(name) - async def get(*arg, **kwargs): - await asyncio.sleep(0.05) - if raise_cancelled_error.is_set(): - raise_cancelled_error.clear() - raise asyncio.CancelledError("Mocked CancelledError") - return await original_get(*arg, **kwargs) + monkeypatch.setattr( + cls, + method, + call_after(getattr(cls, method), functools.partial(after, name=entry)), + ) - monkeypatch.setattr(app.update_queue, "get", get) - - def thread_target(): - waited = 0 - while not app.running: - time.sleep(0.05) - waited += 0.05 - if waited > 5: - pytest.fail("App apparently won't start") - - time.sleep(0.1) - raise_cancelled_error.set() - - async with app: - with caplog.at_level(logging.WARNING): - thread = Thread(target=thread_target) - await app.start() - thread.start() - assert thread.is_alive() - raise_cancelled_error.wait() - - # The exit should have been caught and the app should still be running - assert not thread.is_alive() - assert app.running - - # Explicit shutdown is required - await app.stop() - thread.join() - - assert not thread.is_alive() - assert not app.running - - # Make sure that we were warned about the necessity of a manual shutdown - assert len(caplog.records) == 1 - record = caplog.records[0] - assert record.name == "telegram.ext.Application" - assert record.getMessage().startswith( - "Fetching updates got a asyncio.CancelledError. 
Ignoring" + app = ( + ApplicationBuilder() + .bot(one_time_bot) + .post_init(post_init) + .post_stop(functools.partial(callback, name="post_stop")) + .post_shutdown(functools.partial(callback, name="post_shutdown")) + .build() ) + monkeypatch.setattr(app.bot, "get_updates", get_updates) + + app.add_handler(TypeHandler(object, update_logger_callback), group=-10) + app.add_handler(TypeHandler(object, handler_callback)) + app.add_error_handler(error_callback) + with caplog.at_level(logging.DEBUG): + app.run_polling(close_loop=False) + + # This checks that we have a clean shutdown even when the user raises SystemExit + # or KeyboardInterrupt in a handler/error handler/job callback + assert called_callbacks == { + "app_initialize", + "app_shutdown", + "app_start", + "app_stop", + "post_shutdown", + "post_stop", + "updater_initialize", + "updater_shutdown", + "updater_start_polling", + "updater_stop", + } + + # These next checks make sure that the update queue is properly cleaned even if there are + # still pending updates in the queue + # Unfortunately this is apparently extremely hard to get right with jobs, so we're + # skipping that case for the sake of simplicity + if kind == "job": + return + + found = False + for record in caplog.records: + if record.getMessage() != "Dropping pending update: will_not_be_processed": + continue + assert record.name == "telegram.ext.Application" + assert record.levelno == logging.DEBUG + found = True + assert found, "`Dropping pending updates` message not found in logs!" 
+ assert "will_not_be_processed" not in app.bot_data.get("processed_updates", set()) def test_run_without_updater(self, one_time_bot): app = ApplicationBuilder().bot(one_time_bot).updater(None).build() @@ -2311,7 +2412,43 @@ class TestApplication: assert len(caplog.records) == 1 assert caplog.records[-1].name == "telegram.ext.Application" - assert caplog.records[-1].getMessage().endswith("stop_running() does nothing.") + assert caplog.records[-1].getMessage().endswith("`stop_running()` likely has no effect.") + + def test_stop_running_post_init(self, app, monkeypatch, caplog, one_time_bot): + async def post_init(app): + app.stop_running() + + called_callbacks = [] + + async def callback(*args, **kwargs): + called_callbacks.append(kwargs["name"]) + + monkeypatch.setattr(Application, "start", functools.partial(callback, name="start")) + monkeypatch.setattr( + Updater, "start_polling", functools.partial(callback, name="start_polling") + ) + + app = ( + ApplicationBuilder() + .bot(one_time_bot) + .post_init(post_init) + .post_stop(functools.partial(callback, name="post_stop")) + .post_shutdown(functools.partial(callback, name="post_shutdown")) + .build() + ) + + with caplog.at_level(logging.INFO): + app.run_polling(close_loop=False) + + # The important part here is that start(_polling) are *not* called! + assert called_callbacks == ["post_stop", "post_shutdown"] + + assert len(caplog.records) == 1 + assert caplog.records[-1].name == "telegram.ext.Application" + assert ( + "Application received stop signal via `stop_running`" + in caplog.records[-1].getMessage() + ) @pytest.mark.parametrize("method", ["polling", "webhook"]) def test_stop_running(self, one_time_bot, monkeypatch, method):