Handle SystemExit raised in Handlers (#4157)

This commit is contained in:
Bibo-Joshi 2024-05-20 15:27:08 +02:00 committed by GitHub
parent 912fe45d8c
commit 637b8e260b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 276 additions and 105 deletions

View file

@ -67,7 +67,9 @@ source_suffix = ".rst"
master_doc = "index" master_doc = "index"
# Global substitutions # Global substitutions
rst_prolog = (Path.cwd() / "../substitutions/global.rst").read_text(encoding="utf-8") rst_prolog = ""
for file in Path.cwd().glob("../substitutions/*.rst"):
rst_prolog += "\n" + file.read_text(encoding="utf-8")
# -- Extension settings ------------------------------------------------ # -- Extension settings ------------------------------------------------
napoleon_use_admonition_for_examples = True napoleon_use_admonition_for_examples = True

View file

@ -0,0 +1 @@
.. |app_run_shutdown| replace:: The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. This also works from within handlers, error handlers and jobs. However, using :meth:`~telegram.ext.Application.stop_running` will give a somewhat cleaner shutdown behavior than manually raising those exceptions. On unix, the app will also shut down on receiving the signals specified by

View file

@ -365,6 +365,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
self.__update_persistence_event = asyncio.Event() self.__update_persistence_event = asyncio.Event()
self.__update_persistence_lock = asyncio.Lock() self.__update_persistence_lock = asyncio.Lock()
self.__create_task_tasks: Set[asyncio.Task] = set() # Used for awaiting tasks upon exit self.__create_task_tasks: Set[asyncio.Task] = set() # Used for awaiting tasks upon exit
self.__stop_running_marker = asyncio.Event()
async def __aenter__(self: _AppType) -> _AppType: # noqa: PYI019 async def __aenter__(self: _AppType) -> _AppType: # noqa: PYI019
"""|async_context_manager| :meth:`initializes <initialize>` the App. """|async_context_manager| :meth:`initializes <initialize>` the App.
@ -516,6 +517,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
await self._add_ch_to_persistence(handler) await self._add_ch_to_persistence(handler)
self._initialized = True self._initialized = True
self.__stop_running_marker.clear()
async def _add_ch_to_persistence(self, handler: "ConversationHandler") -> None: async def _add_ch_to_persistence(self, handler: "ConversationHandler") -> None:
self._conversation_handler_conversations.update( self._conversation_handler_conversations.update(
@ -670,13 +672,25 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
raise RuntimeError("This Application is not running!") raise RuntimeError("This Application is not running!")
self._running = False self._running = False
self.__stop_running_marker.clear()
_LOGGER.info("Application is stopping. This might take a moment.") _LOGGER.info("Application is stopping. This might take a moment.")
# Stop listening for new updates and handle all pending ones # Stop listening for new updates and handle all pending ones
if self.__update_fetcher_task:
if self.__update_fetcher_task.done():
try:
self.__update_fetcher_task.result()
except BaseException as exc:
_LOGGER.critical(
"Fetching updates was aborted due to %r. Suppressing "
"exception to ensure graceful shutdown.",
exc,
exc_info=True,
)
else:
await self.update_queue.put(_STOP_SIGNAL) await self.update_queue.put(_STOP_SIGNAL)
_LOGGER.debug("Waiting for update_queue to join") _LOGGER.debug("Waiting for update_queue to join")
await self.update_queue.join() await self.update_queue.join()
if self.__update_fetcher_task:
await self.__update_fetcher_task await self.__update_fetcher_task
_LOGGER.debug("Application stopped fetching of updates.") _LOGGER.debug("Application stopped fetching of updates.")
@ -703,17 +717,36 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
shutdown of the application, i.e. the methods listed in :attr:`run_polling` and shutdown of the application, i.e. the methods listed in :attr:`run_polling` and
:attr:`run_webhook` will still be executed. :attr:`run_webhook` will still be executed.
This method can also be called within :meth:`post_init`. This allows for a graceful,
early shutdown of the application if some condition is met (e.g., a database connection
could not be established).
Note: Note:
If the application is not running, this method does nothing. If the application is not running and this method is not called within
:meth:`post_init`, this method does nothing.
Warning:
This method is designed to for use in combination with :meth:`run_polling` or
:meth:`run_webhook`. Using this method in combination with a custom logic for starting
and stopping the application is not guaranteed to work as expected. Use at your own
risk.
.. versionadded:: 20.5 .. versionadded:: 20.5
.. versionchanged:: NEXT.VERSION
Added support for calling within :meth:`post_init`.
""" """
if self.running: if self.running:
# This works because `__run` is using `loop.run_forever()`. If that changes, this # This works because `__run` is using `loop.run_forever()`. If that changes, this
# method needs to be adapted. # method needs to be adapted.
asyncio.get_running_loop().stop() asyncio.get_running_loop().stop()
else: else:
_LOGGER.debug("Application is not running, stop_running() does nothing.") self.__stop_running_marker.set()
if not self._initialized:
_LOGGER.debug(
"Application is not running and not initialized. `stop_running()` likely has "
"no effect."
)
def run_polling( def run_polling(
self, self,
@ -733,9 +766,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and
a graceful shutdown of the app on exit. a graceful shutdown of the app on exit.
The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. |app_run_shutdown| :paramref:`stop_signals`.
On unix, the app will also shut down on receiving the signals specified by
:paramref:`stop_signals`.
The order of execution by :meth:`run_polling` is roughly as follows: The order of execution by :meth:`run_polling` is roughly as follows:
@ -874,9 +905,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
a graceful shutdown of the app on exit. a graceful shutdown of the app on exit.
The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. |app_run_shutdown| :paramref:`stop_signals`.
On unix, the app will also shut down on receiving the signals specified by
:paramref:`stop_signals`.
If :paramref:`cert` If :paramref:`cert`
and :paramref:`key` are not provided, the webhook will be started directly on and :paramref:`key` are not provided, the webhook will be started directly on
@ -1038,6 +1067,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
loop.run_until_complete(self.initialize()) loop.run_until_complete(self.initialize())
if self.post_init: if self.post_init:
loop.run_until_complete(self.post_init(self)) loop.run_until_complete(self.post_init(self))
if self.__stop_running_marker.is_set():
_LOGGER.info("Application received stop signal via `stop_running`. Shutting down.")
return
loop.run_until_complete(updater_coroutine) # one of updater.start_webhook/polling loop.run_until_complete(updater_coroutine) # one of updater.start_webhook/polling
loop.run_until_complete(self.start()) loop.run_until_complete(self.start())
loop.run_forever() loop.run_forever()
@ -1184,17 +1216,12 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
finally: finally:
self._mark_for_persistence_update(update=update) self._mark_for_persistence_update(update=update)
async def _update_fetcher(self) -> None: async def __update_fetcher(self) -> None:
# Continuously fetch updates from the queue. Exit only once the signal object is found. # Continuously fetch updates from the queue. Exit only once the signal object is found.
while True: while True:
try:
update = await self.update_queue.get() update = await self.update_queue.get()
if update is _STOP_SIGNAL: if update is _STOP_SIGNAL:
_LOGGER.debug("Dropping pending updates")
while not self.update_queue.empty():
self.update_queue.task_done()
# For the _STOP_SIGNAL # For the _STOP_SIGNAL
self.update_queue.task_done() self.update_queue.task_done()
return return
@ -1211,17 +1238,21 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
else: else:
await self.__process_update_wrapper(update) await self.__process_update_wrapper(update)
except asyncio.CancelledError: async def _update_fetcher(self) -> None:
# This may happen if the application is manually run via application.start() and try:
# then a KeyboardInterrupt is sent. We must prevent this loop to die since await self.__update_fetcher()
# application.stop() will wait for it's clean shutdown. finally:
_LOGGER.warning( while not self.update_queue.empty():
"Fetching updates got a asyncio.CancelledError. Ignoring as this task may only" _LOGGER.debug("Dropping pending update: %s", self.update_queue.get_nowait())
"be closed via `Application.stop`." with contextlib.suppress(ValueError):
) # Since we're shutting down here, it's not too bad if we call task_done
# on an empty queue
self.update_queue.task_done()
async def __process_update_wrapper(self, update: object) -> None: async def __process_update_wrapper(self, update: object) -> None:
try:
await self._update_processor.process_update(update, self.process_update(update)) await self._update_processor.process_update(update, self.process_update(update))
finally:
self.update_queue.task_done() self.update_queue.task_done()
async def process_update(self, update: object) -> None: async def process_update(self, update: object) -> None:

View file

@ -19,6 +19,7 @@
"""The integration of persistence into the application is tested in test_basepersistence. """The integration of persistence into the application is tested in test_basepersistence.
""" """
import asyncio import asyncio
import functools
import inspect import inspect
import logging import logging
import os import os
@ -2083,76 +2084,176 @@ class TestApplication:
assert set(self.received.keys()) == set(expected.keys()) assert set(self.received.keys()) == set(expected.keys())
assert self.received == expected assert self.received == expected
@pytest.mark.skipif( @pytest.mark.parametrize("exception", [SystemExit, KeyboardInterrupt])
platform.system() == "Windows", def test_raise_system_exit_keyboard_interrupt_post_init(
reason="Can't send signals without stopping whole process on windows", self, one_time_bot, monkeypatch, exception
)
async def test_cancellation_error_does_not_stop_polling(
self, one_time_bot, monkeypatch, caplog
): ):
""" async def post_init(application):
Ensures that hitting CTRL+C while polling *without* run_polling doesn't kill raise exception
the update_fetcher loop such that a shutdown is still possible.
This test is far from perfect, but it's the closest we can come with sane effort. called_callbacks = set()
"""
async def callback(*args, **kwargs):
called_callbacks.add(kwargs["name"])
for cls, method, entry in [
(Application, "initialize", "app_initialize"),
(Application, "start", "app_start"),
(Application, "stop", "app_stop"),
(Application, "shutdown", "app_shutdown"),
(Updater, "initialize", "updater_initialize"),
(Updater, "shutdown", "updater_shutdown"),
(Updater, "stop", "updater_stop"),
(Updater, "start_polling", "updater_start_polling"),
]:
def after(_, name):
called_callbacks.add(name)
monkeypatch.setattr(
cls,
method,
call_after(getattr(cls, method), functools.partial(after, name=entry)),
)
app = (
ApplicationBuilder()
.bot(one_time_bot)
.post_init(post_init)
.post_stop(functools.partial(callback, name="post_stop"))
.post_shutdown(functools.partial(callback, name="post_shutdown"))
.build()
)
app.run_polling(close_loop=False)
# This checks two things:
# 1. start/stop are *not* called!
# 2. we do have a graceful shutdown
assert called_callbacks == {
"app_initialize",
"updater_initialize",
"app_shutdown",
"post_stop",
"post_shutdown",
"updater_shutdown",
}
@pytest.mark.parametrize("exception", [SystemExit("PTBTest"), KeyboardInterrupt("PTBTest")])
@pytest.mark.parametrize("kind", ["handler", "error_handler", "job"])
# @pytest.mark.parametrize("block", [True, False])
# Testing with block=False would be nice but that doesn't work well with pytest for some reason
# in any case, block=False is the simpler behavior since it is roughly similar to what happens
# when you hit CTRL+C in the commandline.
def test_raise_system_exit_keyboard_jobs_handlers(
self, one_time_bot, monkeypatch, exception, kind, caplog
):
async def queue_and_raise(application):
await application.update_queue.put("will_not_be_processed")
raise exception
async def handler_callback(update, context):
if kind == "handler":
await queue_and_raise(context.application)
elif kind == "error_handler":
raise TelegramError("Triggering error callback")
async def error_callback(update, context):
await queue_and_raise(context.application)
async def job_callback(context):
await queue_and_raise(context.application)
async def enqueue_update():
await asyncio.sleep(0.5)
await app.update_queue.put(1)
async def post_init(application):
if kind == "job":
application.job_queue.run_once(when=0.5, callback=job_callback)
else:
app.create_task(enqueue_update())
async def update_logger_callback(update, context):
context.bot_data.setdefault("processed_updates", set()).add(update)
called_callbacks = set()
async def callback(*args, **kwargs):
called_callbacks.add(kwargs["name"])
async def get_updates(*args, **kwargs): async def get_updates(*args, **kwargs):
await asyncio.sleep(0) await asyncio.sleep(0)
return [None] return []
monkeypatch.setattr(one_time_bot, "get_updates", get_updates) for cls, method, entry in [
app = ApplicationBuilder().bot(one_time_bot).build() (Application, "initialize", "app_initialize"),
(Application, "start", "app_start"),
(Application, "stop", "app_stop"),
(Application, "shutdown", "app_shutdown"),
(Updater, "initialize", "updater_initialize"),
(Updater, "shutdown", "updater_shutdown"),
(Updater, "stop", "updater_stop"),
(Updater, "start_polling", "updater_start_polling"),
]:
original_get = app.update_queue.get def after(_, name):
raise_cancelled_error = threading.Event() called_callbacks.add(name)
async def get(*arg, **kwargs): monkeypatch.setattr(
await asyncio.sleep(0.05) cls,
if raise_cancelled_error.is_set(): method,
raise_cancelled_error.clear() call_after(getattr(cls, method), functools.partial(after, name=entry)),
raise asyncio.CancelledError("Mocked CancelledError")
return await original_get(*arg, **kwargs)
monkeypatch.setattr(app.update_queue, "get", get)
def thread_target():
waited = 0
while not app.running:
time.sleep(0.05)
waited += 0.05
if waited > 5:
pytest.fail("App apparently won't start")
time.sleep(0.1)
raise_cancelled_error.set()
async with app:
with caplog.at_level(logging.WARNING):
thread = Thread(target=thread_target)
await app.start()
thread.start()
assert thread.is_alive()
raise_cancelled_error.wait()
# The exit should have been caught and the app should still be running
assert not thread.is_alive()
assert app.running
# Explicit shutdown is required
await app.stop()
thread.join()
assert not thread.is_alive()
assert not app.running
# Make sure that we were warned about the necessity of a manual shutdown
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.name == "telegram.ext.Application"
assert record.getMessage().startswith(
"Fetching updates got a asyncio.CancelledError. Ignoring"
) )
app = (
ApplicationBuilder()
.bot(one_time_bot)
.post_init(post_init)
.post_stop(functools.partial(callback, name="post_stop"))
.post_shutdown(functools.partial(callback, name="post_shutdown"))
.build()
)
monkeypatch.setattr(app.bot, "get_updates", get_updates)
app.add_handler(TypeHandler(object, update_logger_callback), group=-10)
app.add_handler(TypeHandler(object, handler_callback))
app.add_error_handler(error_callback)
with caplog.at_level(logging.DEBUG):
app.run_polling(close_loop=False)
# This checks that we have a clean shutdown even when the user raises SystemExit
# or KeyboardInterrupt in a handler/error handler/job callback
assert called_callbacks == {
"app_initialize",
"app_shutdown",
"app_start",
"app_stop",
"post_shutdown",
"post_stop",
"updater_initialize",
"updater_shutdown",
"updater_start_polling",
"updater_stop",
}
# These next checks make sure that the update queue is properly cleaned even if there are
# still pending updates in the queue
# Unfortunately this is apparently extremely hard to get right with jobs, so we're
# skipping that case for the sake of simplicity
if kind == "job":
return
found = False
for record in caplog.records:
if record.getMessage() != "Dropping pending update: will_not_be_processed":
continue
assert record.name == "telegram.ext.Application"
assert record.levelno == logging.DEBUG
found = True
assert found, "`Dropping pending updates` message not found in logs!"
assert "will_not_be_processed" not in app.bot_data.get("processed_updates", set())
def test_run_without_updater(self, one_time_bot): def test_run_without_updater(self, one_time_bot):
app = ApplicationBuilder().bot(one_time_bot).updater(None).build() app = ApplicationBuilder().bot(one_time_bot).updater(None).build()
@ -2311,7 +2412,43 @@ class TestApplication:
assert len(caplog.records) == 1 assert len(caplog.records) == 1
assert caplog.records[-1].name == "telegram.ext.Application" assert caplog.records[-1].name == "telegram.ext.Application"
assert caplog.records[-1].getMessage().endswith("stop_running() does nothing.") assert caplog.records[-1].getMessage().endswith("`stop_running()` likely has no effect.")
def test_stop_running_post_init(self, app, monkeypatch, caplog, one_time_bot):
async def post_init(app):
app.stop_running()
called_callbacks = []
async def callback(*args, **kwargs):
called_callbacks.append(kwargs["name"])
monkeypatch.setattr(Application, "start", functools.partial(callback, name="start"))
monkeypatch.setattr(
Updater, "start_polling", functools.partial(callback, name="start_polling")
)
app = (
ApplicationBuilder()
.bot(one_time_bot)
.post_init(post_init)
.post_stop(functools.partial(callback, name="post_stop"))
.post_shutdown(functools.partial(callback, name="post_shutdown"))
.build()
)
with caplog.at_level(logging.INFO):
app.run_polling(close_loop=False)
# The important part here is that start(_polling) are *not* called!
assert called_callbacks == ["post_stop", "post_shutdown"]
assert len(caplog.records) == 1
assert caplog.records[-1].name == "telegram.ext.Application"
assert (
"Application received stop signal via `stop_running`"
in caplog.records[-1].getMessage()
)
@pytest.mark.parametrize("method", ["polling", "webhook"]) @pytest.mark.parametrize("method", ["polling", "webhook"])
def test_stop_running(self, one_time_bot, monkeypatch, method): def test_stop_running(self, one_time_bot, monkeypatch, method):