Handle SystemExit raised in Handlers (#4157)

This commit is contained in:
Bibo-Joshi 2024-05-20 15:27:08 +02:00 committed by GitHub
parent 912fe45d8c
commit 637b8e260b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 276 additions and 105 deletions

View file

@ -67,7 +67,9 @@ source_suffix = ".rst"
master_doc = "index"
# Global substitutions
rst_prolog = (Path.cwd() / "../substitutions/global.rst").read_text(encoding="utf-8")
rst_prolog = ""
for file in Path.cwd().glob("../substitutions/*.rst"):
rst_prolog += "\n" + file.read_text(encoding="utf-8")
# -- Extension settings ------------------------------------------------
napoleon_use_admonition_for_examples = True

View file

@ -0,0 +1 @@
.. |app_run_shutdown| replace:: The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. This also works from within handlers, error handlers and jobs. However, using :meth:`~telegram.ext.Application.stop_running` will give a somewhat cleaner shutdown behavior than manually raising those exceptions. On unix, the app will also shut down on receiving the signals specified by

View file

@ -365,6 +365,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
self.__update_persistence_event = asyncio.Event()
self.__update_persistence_lock = asyncio.Lock()
self.__create_task_tasks: Set[asyncio.Task] = set() # Used for awaiting tasks upon exit
self.__stop_running_marker = asyncio.Event()
async def __aenter__(self: _AppType) -> _AppType: # noqa: PYI019
"""|async_context_manager| :meth:`initializes <initialize>` the App.
@ -516,6 +517,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
await self._add_ch_to_persistence(handler)
self._initialized = True
self.__stop_running_marker.clear()
async def _add_ch_to_persistence(self, handler: "ConversationHandler") -> None:
self._conversation_handler_conversations.update(
@ -670,13 +672,25 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
raise RuntimeError("This Application is not running!")
self._running = False
self.__stop_running_marker.clear()
_LOGGER.info("Application is stopping. This might take a moment.")
# Stop listening for new updates and handle all pending ones
if self.__update_fetcher_task:
if self.__update_fetcher_task.done():
try:
self.__update_fetcher_task.result()
except BaseException as exc:
_LOGGER.critical(
"Fetching updates was aborted due to %r. Suppressing "
"exception to ensure graceful shutdown.",
exc,
exc_info=True,
)
else:
await self.update_queue.put(_STOP_SIGNAL)
_LOGGER.debug("Waiting for update_queue to join")
await self.update_queue.join()
if self.__update_fetcher_task:
await self.__update_fetcher_task
_LOGGER.debug("Application stopped fetching of updates.")
@ -703,17 +717,36 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
shutdown of the application, i.e. the methods listed in :attr:`run_polling` and
:attr:`run_webhook` will still be executed.
This method can also be called within :meth:`post_init`. This allows for a graceful,
early shutdown of the application if some condition is met (e.g., a database connection
could not be established).
Note:
If the application is not running, this method does nothing.
If the application is not running and this method is not called within
:meth:`post_init`, this method does nothing.
Warning:
This method is designed to for use in combination with :meth:`run_polling` or
:meth:`run_webhook`. Using this method in combination with a custom logic for starting
and stopping the application is not guaranteed to work as expected. Use at your own
risk.
.. versionadded:: 20.5
.. versionchanged:: NEXT.VERSION
Added support for calling within :meth:`post_init`.
"""
if self.running:
# This works because `__run` is using `loop.run_forever()`. If that changes, this
# method needs to be adapted.
asyncio.get_running_loop().stop()
else:
_LOGGER.debug("Application is not running, stop_running() does nothing.")
self.__stop_running_marker.set()
if not self._initialized:
_LOGGER.debug(
"Application is not running and not initialized. `stop_running()` likely has "
"no effect."
)
def run_polling(
self,
@ -733,9 +766,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and
a graceful shutdown of the app on exit.
The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised.
On unix, the app will also shut down on receiving the signals specified by
:paramref:`stop_signals`.
|app_run_shutdown| :paramref:`stop_signals`.
The order of execution by :meth:`run_polling` is roughly as follows:
@ -874,9 +905,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
a graceful shutdown of the app on exit.
The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised.
On unix, the app will also shut down on receiving the signals specified by
:paramref:`stop_signals`.
|app_run_shutdown| :paramref:`stop_signals`.
If :paramref:`cert`
and :paramref:`key` are not provided, the webhook will be started directly on
@ -1038,6 +1067,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
loop.run_until_complete(self.initialize())
if self.post_init:
loop.run_until_complete(self.post_init(self))
if self.__stop_running_marker.is_set():
_LOGGER.info("Application received stop signal via `stop_running`. Shutting down.")
return
loop.run_until_complete(updater_coroutine) # one of updater.start_webhook/polling
loop.run_until_complete(self.start())
loop.run_forever()
@ -1184,17 +1216,12 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
finally:
self._mark_for_persistence_update(update=update)
async def _update_fetcher(self) -> None:
async def __update_fetcher(self) -> None:
# Continuously fetch updates from the queue. Exit only once the signal object is found.
while True:
try:
update = await self.update_queue.get()
if update is _STOP_SIGNAL:
_LOGGER.debug("Dropping pending updates")
while not self.update_queue.empty():
self.update_queue.task_done()
# For the _STOP_SIGNAL
self.update_queue.task_done()
return
@ -1211,17 +1238,21 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
else:
await self.__process_update_wrapper(update)
except asyncio.CancelledError:
# This may happen if the application is manually run via application.start() and
# then a KeyboardInterrupt is sent. We must prevent this loop to die since
# application.stop() will wait for it's clean shutdown.
_LOGGER.warning(
"Fetching updates got a asyncio.CancelledError. Ignoring as this task may only"
"be closed via `Application.stop`."
)
async def _update_fetcher(self) -> None:
try:
await self.__update_fetcher()
finally:
while not self.update_queue.empty():
_LOGGER.debug("Dropping pending update: %s", self.update_queue.get_nowait())
with contextlib.suppress(ValueError):
# Since we're shutting down here, it's not too bad if we call task_done
# on an empty queue
self.update_queue.task_done()
async def __process_update_wrapper(self, update: object) -> None:
try:
await self._update_processor.process_update(update, self.process_update(update))
finally:
self.update_queue.task_done()
async def process_update(self, update: object) -> None:

View file

@ -19,6 +19,7 @@
"""The integration of persistence into the application is tested in test_basepersistence.
"""
import asyncio
import functools
import inspect
import logging
import os
@ -2083,76 +2084,176 @@ class TestApplication:
assert set(self.received.keys()) == set(expected.keys())
assert self.received == expected
@pytest.mark.skipif(
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
)
async def test_cancellation_error_does_not_stop_polling(
self, one_time_bot, monkeypatch, caplog
@pytest.mark.parametrize("exception", [SystemExit, KeyboardInterrupt])
def test_raise_system_exit_keyboard_interrupt_post_init(
self, one_time_bot, monkeypatch, exception
):
"""
Ensures that hitting CTRL+C while polling *without* run_polling doesn't kill
the update_fetcher loop such that a shutdown is still possible.
This test is far from perfect, but it's the closest we can come with sane effort.
"""
async def post_init(application):
raise exception
called_callbacks = set()
async def callback(*args, **kwargs):
called_callbacks.add(kwargs["name"])
for cls, method, entry in [
(Application, "initialize", "app_initialize"),
(Application, "start", "app_start"),
(Application, "stop", "app_stop"),
(Application, "shutdown", "app_shutdown"),
(Updater, "initialize", "updater_initialize"),
(Updater, "shutdown", "updater_shutdown"),
(Updater, "stop", "updater_stop"),
(Updater, "start_polling", "updater_start_polling"),
]:
def after(_, name):
called_callbacks.add(name)
monkeypatch.setattr(
cls,
method,
call_after(getattr(cls, method), functools.partial(after, name=entry)),
)
app = (
ApplicationBuilder()
.bot(one_time_bot)
.post_init(post_init)
.post_stop(functools.partial(callback, name="post_stop"))
.post_shutdown(functools.partial(callback, name="post_shutdown"))
.build()
)
app.run_polling(close_loop=False)
# This checks two things:
# 1. start/stop are *not* called!
# 2. we do have a graceful shutdown
assert called_callbacks == {
"app_initialize",
"updater_initialize",
"app_shutdown",
"post_stop",
"post_shutdown",
"updater_shutdown",
}
@pytest.mark.parametrize("exception", [SystemExit("PTBTest"), KeyboardInterrupt("PTBTest")])
@pytest.mark.parametrize("kind", ["handler", "error_handler", "job"])
# @pytest.mark.parametrize("block", [True, False])
# Testing with block=False would be nice but that doesn't work well with pytest for some reason
# in any case, block=False is the simpler behavior since it is roughly similar to what happens
# when you hit CTRL+C in the commandline.
def test_raise_system_exit_keyboard_jobs_handlers(
self, one_time_bot, monkeypatch, exception, kind, caplog
):
async def queue_and_raise(application):
await application.update_queue.put("will_not_be_processed")
raise exception
async def handler_callback(update, context):
if kind == "handler":
await queue_and_raise(context.application)
elif kind == "error_handler":
raise TelegramError("Triggering error callback")
async def error_callback(update, context):
await queue_and_raise(context.application)
async def job_callback(context):
await queue_and_raise(context.application)
async def enqueue_update():
await asyncio.sleep(0.5)
await app.update_queue.put(1)
async def post_init(application):
if kind == "job":
application.job_queue.run_once(when=0.5, callback=job_callback)
else:
app.create_task(enqueue_update())
async def update_logger_callback(update, context):
context.bot_data.setdefault("processed_updates", set()).add(update)
called_callbacks = set()
async def callback(*args, **kwargs):
called_callbacks.add(kwargs["name"])
async def get_updates(*args, **kwargs):
await asyncio.sleep(0)
return [None]
return []
monkeypatch.setattr(one_time_bot, "get_updates", get_updates)
app = ApplicationBuilder().bot(one_time_bot).build()
for cls, method, entry in [
(Application, "initialize", "app_initialize"),
(Application, "start", "app_start"),
(Application, "stop", "app_stop"),
(Application, "shutdown", "app_shutdown"),
(Updater, "initialize", "updater_initialize"),
(Updater, "shutdown", "updater_shutdown"),
(Updater, "stop", "updater_stop"),
(Updater, "start_polling", "updater_start_polling"),
]:
original_get = app.update_queue.get
raise_cancelled_error = threading.Event()
def after(_, name):
called_callbacks.add(name)
async def get(*arg, **kwargs):
await asyncio.sleep(0.05)
if raise_cancelled_error.is_set():
raise_cancelled_error.clear()
raise asyncio.CancelledError("Mocked CancelledError")
return await original_get(*arg, **kwargs)
monkeypatch.setattr(app.update_queue, "get", get)
def thread_target():
waited = 0
while not app.running:
time.sleep(0.05)
waited += 0.05
if waited > 5:
pytest.fail("App apparently won't start")
time.sleep(0.1)
raise_cancelled_error.set()
async with app:
with caplog.at_level(logging.WARNING):
thread = Thread(target=thread_target)
await app.start()
thread.start()
assert thread.is_alive()
raise_cancelled_error.wait()
# The exit should have been caught and the app should still be running
assert not thread.is_alive()
assert app.running
# Explicit shutdown is required
await app.stop()
thread.join()
assert not thread.is_alive()
assert not app.running
# Make sure that we were warned about the necessity of a manual shutdown
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.name == "telegram.ext.Application"
assert record.getMessage().startswith(
"Fetching updates got a asyncio.CancelledError. Ignoring"
monkeypatch.setattr(
cls,
method,
call_after(getattr(cls, method), functools.partial(after, name=entry)),
)
app = (
ApplicationBuilder()
.bot(one_time_bot)
.post_init(post_init)
.post_stop(functools.partial(callback, name="post_stop"))
.post_shutdown(functools.partial(callback, name="post_shutdown"))
.build()
)
monkeypatch.setattr(app.bot, "get_updates", get_updates)
app.add_handler(TypeHandler(object, update_logger_callback), group=-10)
app.add_handler(TypeHandler(object, handler_callback))
app.add_error_handler(error_callback)
with caplog.at_level(logging.DEBUG):
app.run_polling(close_loop=False)
# This checks that we have a clean shutdown even when the user raises SystemExit
# or KeyboardInterrupt in a handler/error handler/job callback
assert called_callbacks == {
"app_initialize",
"app_shutdown",
"app_start",
"app_stop",
"post_shutdown",
"post_stop",
"updater_initialize",
"updater_shutdown",
"updater_start_polling",
"updater_stop",
}
# These next checks make sure that the update queue is properly cleaned even if there are
# still pending updates in the queue
# Unfortunately this is apparently extremely hard to get right with jobs, so we're
# skipping that case for the sake of simplicity
if kind == "job":
return
found = False
for record in caplog.records:
if record.getMessage() != "Dropping pending update: will_not_be_processed":
continue
assert record.name == "telegram.ext.Application"
assert record.levelno == logging.DEBUG
found = True
assert found, "`Dropping pending updates` message not found in logs!"
assert "will_not_be_processed" not in app.bot_data.get("processed_updates", set())
def test_run_without_updater(self, one_time_bot):
app = ApplicationBuilder().bot(one_time_bot).updater(None).build()
@ -2311,7 +2412,43 @@ class TestApplication:
assert len(caplog.records) == 1
assert caplog.records[-1].name == "telegram.ext.Application"
assert caplog.records[-1].getMessage().endswith("stop_running() does nothing.")
assert caplog.records[-1].getMessage().endswith("`stop_running()` likely has no effect.")
def test_stop_running_post_init(self, app, monkeypatch, caplog, one_time_bot):
async def post_init(app):
app.stop_running()
called_callbacks = []
async def callback(*args, **kwargs):
called_callbacks.append(kwargs["name"])
monkeypatch.setattr(Application, "start", functools.partial(callback, name="start"))
monkeypatch.setattr(
Updater, "start_polling", functools.partial(callback, name="start_polling")
)
app = (
ApplicationBuilder()
.bot(one_time_bot)
.post_init(post_init)
.post_stop(functools.partial(callback, name="post_stop"))
.post_shutdown(functools.partial(callback, name="post_shutdown"))
.build()
)
with caplog.at_level(logging.INFO):
app.run_polling(close_loop=False)
# The important part here is that start(_polling) are *not* called!
assert called_callbacks == ["post_stop", "post_shutdown"]
assert len(caplog.records) == 1
assert caplog.records[-1].name == "telegram.ext.Application"
assert (
"Application received stop signal via `stop_running`"
in caplog.records[-1].getMessage()
)
@pytest.mark.parametrize("method", ["polling", "webhook"])
def test_stop_running(self, one_time_bot, monkeypatch, method):