mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-22 14:35:00 +01:00
Handle SystemExit
raised in Handlers (#4157)
This commit is contained in:
parent
912fe45d8c
commit
637b8e260b
4 changed files with 276 additions and 105 deletions
|
@ -67,7 +67,9 @@ source_suffix = ".rst"
|
|||
master_doc = "index"
|
||||
|
||||
# Global substitutions
|
||||
rst_prolog = (Path.cwd() / "../substitutions/global.rst").read_text(encoding="utf-8")
|
||||
rst_prolog = ""
|
||||
for file in Path.cwd().glob("../substitutions/*.rst"):
|
||||
rst_prolog += "\n" + file.read_text(encoding="utf-8")
|
||||
|
||||
# -- Extension settings ------------------------------------------------
|
||||
napoleon_use_admonition_for_examples = True
|
||||
|
|
1
docs/substitutions/application.rst
Normal file
1
docs/substitutions/application.rst
Normal file
|
@ -0,0 +1 @@
|
|||
.. |app_run_shutdown| replace:: The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. This also works from within handlers, error handlers and jobs. However, using :meth:`~telegram.ext.Application.stop_running` will give a somewhat cleaner shutdown behavior than manually raising those exceptions. On unix, the app will also shut down on receiving the signals specified by
|
|
@ -365,6 +365,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
self.__update_persistence_event = asyncio.Event()
|
||||
self.__update_persistence_lock = asyncio.Lock()
|
||||
self.__create_task_tasks: Set[asyncio.Task] = set() # Used for awaiting tasks upon exit
|
||||
self.__stop_running_marker = asyncio.Event()
|
||||
|
||||
async def __aenter__(self: _AppType) -> _AppType: # noqa: PYI019
|
||||
"""|async_context_manager| :meth:`initializes <initialize>` the App.
|
||||
|
@ -516,6 +517,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
await self._add_ch_to_persistence(handler)
|
||||
|
||||
self._initialized = True
|
||||
self.__stop_running_marker.clear()
|
||||
|
||||
async def _add_ch_to_persistence(self, handler: "ConversationHandler") -> None:
|
||||
self._conversation_handler_conversations.update(
|
||||
|
@ -670,14 +672,26 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
raise RuntimeError("This Application is not running!")
|
||||
|
||||
self._running = False
|
||||
self.__stop_running_marker.clear()
|
||||
_LOGGER.info("Application is stopping. This might take a moment.")
|
||||
|
||||
# Stop listening for new updates and handle all pending ones
|
||||
await self.update_queue.put(_STOP_SIGNAL)
|
||||
_LOGGER.debug("Waiting for update_queue to join")
|
||||
await self.update_queue.join()
|
||||
if self.__update_fetcher_task:
|
||||
await self.__update_fetcher_task
|
||||
if self.__update_fetcher_task.done():
|
||||
try:
|
||||
self.__update_fetcher_task.result()
|
||||
except BaseException as exc:
|
||||
_LOGGER.critical(
|
||||
"Fetching updates was aborted due to %r. Suppressing "
|
||||
"exception to ensure graceful shutdown.",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
await self.update_queue.put(_STOP_SIGNAL)
|
||||
_LOGGER.debug("Waiting for update_queue to join")
|
||||
await self.update_queue.join()
|
||||
await self.__update_fetcher_task
|
||||
_LOGGER.debug("Application stopped fetching of updates.")
|
||||
|
||||
if self._job_queue:
|
||||
|
@ -703,17 +717,36 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
shutdown of the application, i.e. the methods listed in :attr:`run_polling` and
|
||||
:attr:`run_webhook` will still be executed.
|
||||
|
||||
This method can also be called within :meth:`post_init`. This allows for a graceful,
|
||||
early shutdown of the application if some condition is met (e.g., a database connection
|
||||
could not be established).
|
||||
|
||||
Note:
|
||||
If the application is not running, this method does nothing.
|
||||
If the application is not running and this method is not called within
|
||||
:meth:`post_init`, this method does nothing.
|
||||
|
||||
Warning:
|
||||
This method is designed to for use in combination with :meth:`run_polling` or
|
||||
:meth:`run_webhook`. Using this method in combination with a custom logic for starting
|
||||
and stopping the application is not guaranteed to work as expected. Use at your own
|
||||
risk.
|
||||
|
||||
.. versionadded:: 20.5
|
||||
|
||||
.. versionchanged:: NEXT.VERSION
|
||||
Added support for calling within :meth:`post_init`.
|
||||
"""
|
||||
if self.running:
|
||||
# This works because `__run` is using `loop.run_forever()`. If that changes, this
|
||||
# method needs to be adapted.
|
||||
asyncio.get_running_loop().stop()
|
||||
else:
|
||||
_LOGGER.debug("Application is not running, stop_running() does nothing.")
|
||||
self.__stop_running_marker.set()
|
||||
if not self._initialized:
|
||||
_LOGGER.debug(
|
||||
"Application is not running and not initialized. `stop_running()` likely has "
|
||||
"no effect."
|
||||
)
|
||||
|
||||
def run_polling(
|
||||
self,
|
||||
|
@ -733,9 +766,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and
|
||||
a graceful shutdown of the app on exit.
|
||||
|
||||
The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised.
|
||||
On unix, the app will also shut down on receiving the signals specified by
|
||||
:paramref:`stop_signals`.
|
||||
|app_run_shutdown| :paramref:`stop_signals`.
|
||||
|
||||
The order of execution by :meth:`run_polling` is roughly as follows:
|
||||
|
||||
|
@ -874,9 +905,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
|
||||
a graceful shutdown of the app on exit.
|
||||
|
||||
The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised.
|
||||
On unix, the app will also shut down on receiving the signals specified by
|
||||
:paramref:`stop_signals`.
|
||||
|app_run_shutdown| :paramref:`stop_signals`.
|
||||
|
||||
If :paramref:`cert`
|
||||
and :paramref:`key` are not provided, the webhook will be started directly on
|
||||
|
@ -1038,6 +1067,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
loop.run_until_complete(self.initialize())
|
||||
if self.post_init:
|
||||
loop.run_until_complete(self.post_init(self))
|
||||
if self.__stop_running_marker.is_set():
|
||||
_LOGGER.info("Application received stop signal via `stop_running`. Shutting down.")
|
||||
return
|
||||
loop.run_until_complete(updater_coroutine) # one of updater.start_webhook/polling
|
||||
loop.run_until_complete(self.start())
|
||||
loop.run_forever()
|
||||
|
@ -1184,45 +1216,44 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
finally:
|
||||
self._mark_for_persistence_update(update=update)
|
||||
|
||||
async def _update_fetcher(self) -> None:
|
||||
async def __update_fetcher(self) -> None:
|
||||
# Continuously fetch updates from the queue. Exit only once the signal object is found.
|
||||
while True:
|
||||
try:
|
||||
update = await self.update_queue.get()
|
||||
update = await self.update_queue.get()
|
||||
|
||||
if update is _STOP_SIGNAL:
|
||||
_LOGGER.debug("Dropping pending updates")
|
||||
while not self.update_queue.empty():
|
||||
self.update_queue.task_done()
|
||||
if update is _STOP_SIGNAL:
|
||||
# For the _STOP_SIGNAL
|
||||
self.update_queue.task_done()
|
||||
return
|
||||
|
||||
# For the _STOP_SIGNAL
|
||||
self.update_queue.task_done()
|
||||
return
|
||||
_LOGGER.debug("Processing update %s", update)
|
||||
|
||||
_LOGGER.debug("Processing update %s", update)
|
||||
|
||||
if self._update_processor.max_concurrent_updates > 1:
|
||||
# We don't await the below because it has to be run concurrently
|
||||
self.create_task(
|
||||
self.__process_update_wrapper(update),
|
||||
update=update,
|
||||
name=f"Application:{self.bot.id}:process_concurrent_update",
|
||||
)
|
||||
else:
|
||||
await self.__process_update_wrapper(update)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# This may happen if the application is manually run via application.start() and
|
||||
# then a KeyboardInterrupt is sent. We must prevent this loop to die since
|
||||
# application.stop() will wait for it's clean shutdown.
|
||||
_LOGGER.warning(
|
||||
"Fetching updates got a asyncio.CancelledError. Ignoring as this task may only"
|
||||
"be closed via `Application.stop`."
|
||||
if self._update_processor.max_concurrent_updates > 1:
|
||||
# We don't await the below because it has to be run concurrently
|
||||
self.create_task(
|
||||
self.__process_update_wrapper(update),
|
||||
update=update,
|
||||
name=f"Application:{self.bot.id}:process_concurrent_update",
|
||||
)
|
||||
else:
|
||||
await self.__process_update_wrapper(update)
|
||||
|
||||
async def _update_fetcher(self) -> None:
|
||||
try:
|
||||
await self.__update_fetcher()
|
||||
finally:
|
||||
while not self.update_queue.empty():
|
||||
_LOGGER.debug("Dropping pending update: %s", self.update_queue.get_nowait())
|
||||
with contextlib.suppress(ValueError):
|
||||
# Since we're shutting down here, it's not too bad if we call task_done
|
||||
# on an empty queue
|
||||
self.update_queue.task_done()
|
||||
|
||||
async def __process_update_wrapper(self, update: object) -> None:
|
||||
await self._update_processor.process_update(update, self.process_update(update))
|
||||
self.update_queue.task_done()
|
||||
try:
|
||||
await self._update_processor.process_update(update, self.process_update(update))
|
||||
finally:
|
||||
self.update_queue.task_done()
|
||||
|
||||
async def process_update(self, update: object) -> None:
|
||||
"""Processes a single update and marks the update to be updated by the persistence later.
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
"""The integration of persistence into the application is tested in test_basepersistence.
|
||||
"""
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
|
@ -2083,75 +2084,175 @@ class TestApplication:
|
|||
assert set(self.received.keys()) == set(expected.keys())
|
||||
assert self.received == expected
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows",
|
||||
reason="Can't send signals without stopping whole process on windows",
|
||||
)
|
||||
async def test_cancellation_error_does_not_stop_polling(
|
||||
self, one_time_bot, monkeypatch, caplog
|
||||
@pytest.mark.parametrize("exception", [SystemExit, KeyboardInterrupt])
|
||||
def test_raise_system_exit_keyboard_interrupt_post_init(
|
||||
self, one_time_bot, monkeypatch, exception
|
||||
):
|
||||
"""
|
||||
Ensures that hitting CTRL+C while polling *without* run_polling doesn't kill
|
||||
the update_fetcher loop such that a shutdown is still possible.
|
||||
This test is far from perfect, but it's the closest we can come with sane effort.
|
||||
"""
|
||||
async def post_init(application):
|
||||
raise exception
|
||||
|
||||
called_callbacks = set()
|
||||
|
||||
async def callback(*args, **kwargs):
|
||||
called_callbacks.add(kwargs["name"])
|
||||
|
||||
for cls, method, entry in [
|
||||
(Application, "initialize", "app_initialize"),
|
||||
(Application, "start", "app_start"),
|
||||
(Application, "stop", "app_stop"),
|
||||
(Application, "shutdown", "app_shutdown"),
|
||||
(Updater, "initialize", "updater_initialize"),
|
||||
(Updater, "shutdown", "updater_shutdown"),
|
||||
(Updater, "stop", "updater_stop"),
|
||||
(Updater, "start_polling", "updater_start_polling"),
|
||||
]:
|
||||
|
||||
def after(_, name):
|
||||
called_callbacks.add(name)
|
||||
|
||||
monkeypatch.setattr(
|
||||
cls,
|
||||
method,
|
||||
call_after(getattr(cls, method), functools.partial(after, name=entry)),
|
||||
)
|
||||
|
||||
app = (
|
||||
ApplicationBuilder()
|
||||
.bot(one_time_bot)
|
||||
.post_init(post_init)
|
||||
.post_stop(functools.partial(callback, name="post_stop"))
|
||||
.post_shutdown(functools.partial(callback, name="post_shutdown"))
|
||||
.build()
|
||||
)
|
||||
|
||||
app.run_polling(close_loop=False)
|
||||
|
||||
# This checks two things:
|
||||
# 1. start/stop are *not* called!
|
||||
# 2. we do have a graceful shutdown
|
||||
assert called_callbacks == {
|
||||
"app_initialize",
|
||||
"updater_initialize",
|
||||
"app_shutdown",
|
||||
"post_stop",
|
||||
"post_shutdown",
|
||||
"updater_shutdown",
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize("exception", [SystemExit("PTBTest"), KeyboardInterrupt("PTBTest")])
|
||||
@pytest.mark.parametrize("kind", ["handler", "error_handler", "job"])
|
||||
# @pytest.mark.parametrize("block", [True, False])
|
||||
# Testing with block=False would be nice but that doesn't work well with pytest for some reason
|
||||
# in any case, block=False is the simpler behavior since it is roughly similar to what happens
|
||||
# when you hit CTRL+C in the commandline.
|
||||
def test_raise_system_exit_keyboard_jobs_handlers(
|
||||
self, one_time_bot, monkeypatch, exception, kind, caplog
|
||||
):
|
||||
async def queue_and_raise(application):
|
||||
await application.update_queue.put("will_not_be_processed")
|
||||
raise exception
|
||||
|
||||
async def handler_callback(update, context):
|
||||
if kind == "handler":
|
||||
await queue_and_raise(context.application)
|
||||
elif kind == "error_handler":
|
||||
raise TelegramError("Triggering error callback")
|
||||
|
||||
async def error_callback(update, context):
|
||||
await queue_and_raise(context.application)
|
||||
|
||||
async def job_callback(context):
|
||||
await queue_and_raise(context.application)
|
||||
|
||||
async def enqueue_update():
|
||||
await asyncio.sleep(0.5)
|
||||
await app.update_queue.put(1)
|
||||
|
||||
async def post_init(application):
|
||||
if kind == "job":
|
||||
application.job_queue.run_once(when=0.5, callback=job_callback)
|
||||
else:
|
||||
app.create_task(enqueue_update())
|
||||
|
||||
async def update_logger_callback(update, context):
|
||||
context.bot_data.setdefault("processed_updates", set()).add(update)
|
||||
|
||||
called_callbacks = set()
|
||||
|
||||
async def callback(*args, **kwargs):
|
||||
called_callbacks.add(kwargs["name"])
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
await asyncio.sleep(0)
|
||||
return [None]
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(one_time_bot, "get_updates", get_updates)
|
||||
app = ApplicationBuilder().bot(one_time_bot).build()
|
||||
for cls, method, entry in [
|
||||
(Application, "initialize", "app_initialize"),
|
||||
(Application, "start", "app_start"),
|
||||
(Application, "stop", "app_stop"),
|
||||
(Application, "shutdown", "app_shutdown"),
|
||||
(Updater, "initialize", "updater_initialize"),
|
||||
(Updater, "shutdown", "updater_shutdown"),
|
||||
(Updater, "stop", "updater_stop"),
|
||||
(Updater, "start_polling", "updater_start_polling"),
|
||||
]:
|
||||
|
||||
original_get = app.update_queue.get
|
||||
raise_cancelled_error = threading.Event()
|
||||
def after(_, name):
|
||||
called_callbacks.add(name)
|
||||
|
||||
async def get(*arg, **kwargs):
|
||||
await asyncio.sleep(0.05)
|
||||
if raise_cancelled_error.is_set():
|
||||
raise_cancelled_error.clear()
|
||||
raise asyncio.CancelledError("Mocked CancelledError")
|
||||
return await original_get(*arg, **kwargs)
|
||||
monkeypatch.setattr(
|
||||
cls,
|
||||
method,
|
||||
call_after(getattr(cls, method), functools.partial(after, name=entry)),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(app.update_queue, "get", get)
|
||||
|
||||
def thread_target():
|
||||
waited = 0
|
||||
while not app.running:
|
||||
time.sleep(0.05)
|
||||
waited += 0.05
|
||||
if waited > 5:
|
||||
pytest.fail("App apparently won't start")
|
||||
|
||||
time.sleep(0.1)
|
||||
raise_cancelled_error.set()
|
||||
|
||||
async with app:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
thread = Thread(target=thread_target)
|
||||
await app.start()
|
||||
thread.start()
|
||||
assert thread.is_alive()
|
||||
raise_cancelled_error.wait()
|
||||
|
||||
# The exit should have been caught and the app should still be running
|
||||
assert not thread.is_alive()
|
||||
assert app.running
|
||||
|
||||
# Explicit shutdown is required
|
||||
await app.stop()
|
||||
thread.join()
|
||||
|
||||
assert not thread.is_alive()
|
||||
assert not app.running
|
||||
|
||||
# Make sure that we were warned about the necessity of a manual shutdown
|
||||
assert len(caplog.records) == 1
|
||||
record = caplog.records[0]
|
||||
assert record.name == "telegram.ext.Application"
|
||||
assert record.getMessage().startswith(
|
||||
"Fetching updates got a asyncio.CancelledError. Ignoring"
|
||||
app = (
|
||||
ApplicationBuilder()
|
||||
.bot(one_time_bot)
|
||||
.post_init(post_init)
|
||||
.post_stop(functools.partial(callback, name="post_stop"))
|
||||
.post_shutdown(functools.partial(callback, name="post_shutdown"))
|
||||
.build()
|
||||
)
|
||||
monkeypatch.setattr(app.bot, "get_updates", get_updates)
|
||||
|
||||
app.add_handler(TypeHandler(object, update_logger_callback), group=-10)
|
||||
app.add_handler(TypeHandler(object, handler_callback))
|
||||
app.add_error_handler(error_callback)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
app.run_polling(close_loop=False)
|
||||
|
||||
# This checks that we have a clean shutdown even when the user raises SystemExit
|
||||
# or KeyboardInterrupt in a handler/error handler/job callback
|
||||
assert called_callbacks == {
|
||||
"app_initialize",
|
||||
"app_shutdown",
|
||||
"app_start",
|
||||
"app_stop",
|
||||
"post_shutdown",
|
||||
"post_stop",
|
||||
"updater_initialize",
|
||||
"updater_shutdown",
|
||||
"updater_start_polling",
|
||||
"updater_stop",
|
||||
}
|
||||
|
||||
# These next checks make sure that the update queue is properly cleaned even if there are
|
||||
# still pending updates in the queue
|
||||
# Unfortunately this is apparently extremely hard to get right with jobs, so we're
|
||||
# skipping that case for the sake of simplicity
|
||||
if kind == "job":
|
||||
return
|
||||
|
||||
found = False
|
||||
for record in caplog.records:
|
||||
if record.getMessage() != "Dropping pending update: will_not_be_processed":
|
||||
continue
|
||||
assert record.name == "telegram.ext.Application"
|
||||
assert record.levelno == logging.DEBUG
|
||||
found = True
|
||||
assert found, "`Dropping pending updates` message not found in logs!"
|
||||
assert "will_not_be_processed" not in app.bot_data.get("processed_updates", set())
|
||||
|
||||
def test_run_without_updater(self, one_time_bot):
|
||||
app = ApplicationBuilder().bot(one_time_bot).updater(None).build()
|
||||
|
@ -2311,7 +2412,43 @@ class TestApplication:
|
|||
|
||||
assert len(caplog.records) == 1
|
||||
assert caplog.records[-1].name == "telegram.ext.Application"
|
||||
assert caplog.records[-1].getMessage().endswith("stop_running() does nothing.")
|
||||
assert caplog.records[-1].getMessage().endswith("`stop_running()` likely has no effect.")
|
||||
|
||||
def test_stop_running_post_init(self, app, monkeypatch, caplog, one_time_bot):
|
||||
async def post_init(app):
|
||||
app.stop_running()
|
||||
|
||||
called_callbacks = []
|
||||
|
||||
async def callback(*args, **kwargs):
|
||||
called_callbacks.append(kwargs["name"])
|
||||
|
||||
monkeypatch.setattr(Application, "start", functools.partial(callback, name="start"))
|
||||
monkeypatch.setattr(
|
||||
Updater, "start_polling", functools.partial(callback, name="start_polling")
|
||||
)
|
||||
|
||||
app = (
|
||||
ApplicationBuilder()
|
||||
.bot(one_time_bot)
|
||||
.post_init(post_init)
|
||||
.post_stop(functools.partial(callback, name="post_stop"))
|
||||
.post_shutdown(functools.partial(callback, name="post_shutdown"))
|
||||
.build()
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
app.run_polling(close_loop=False)
|
||||
|
||||
# The important part here is that start(_polling) are *not* called!
|
||||
assert called_callbacks == ["post_stop", "post_shutdown"]
|
||||
|
||||
assert len(caplog.records) == 1
|
||||
assert caplog.records[-1].name == "telegram.ext.Application"
|
||||
assert (
|
||||
"Application received stop signal via `stop_running`"
|
||||
in caplog.records[-1].getMessage()
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("method", ["polling", "webhook"])
|
||||
def test_stop_running(self, one_time_bot, monkeypatch, method):
|
||||
|
|
Loading…
Reference in a new issue