Add Application.stop_running() and Improve Marking Updates as Read on Updater.stop() (#3804)

This commit is contained in:
Bibo-Joshi 2023-08-17 11:50:26 +02:00 committed by GitHub
parent 03f87750d4
commit 5128748092
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 327 additions and 29 deletions

View file

@ -1,7 +1,9 @@
.. tip:: .. tip::
When combining ``python-telegram-bot`` with other :mod:`asyncio` based frameworks, using this * When combining ``python-telegram-bot`` with other :mod:`asyncio` based frameworks, using this
method is likely not the best choice, as it blocks the event loop until it receives a stop method is likely not the best choice, as it blocks the event loop until it receives a stop
signal as described above. signal as described above.
Instead, you can manually call the methods listed below to start and shut down the application Instead, you can manually call the methods listed below to start and shut down the application
and the :attr:`~telegram.ext.Application.updater`. and the :attr:`~telegram.ext.Application.updater`.
Keeping the event loop running and listening for a stop signal is then up to you. Keeping the event loop running and listening for a stop signal is then up to you.
* To gracefully stop the execution of this method from within a handler, job or error callback,
use :meth:`~telegram.ext.Application.stop_running`.

View file

@ -653,6 +653,24 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
_LOGGER.info("Application.stop() complete") _LOGGER.info("Application.stop() complete")
def stop_running(self) -> None:
"""This method can be used to stop the execution of :meth:`run_polling` or
:meth:`run_webhook` from within a handler, job or error callback. This allows a graceful
shutdown of the application, i.e. the methods listed in :attr:`run_polling` and
:attr:`run_webhook` will still be executed.
Note:
If the application is not running, this method does nothing.
.. versionadded:: NEXT.VERSION
"""
if self.running:
# This works because `__run` is using `loop.run_forever()`. If that changes, this
# method needs to be adapted.
asyncio.get_running_loop().stop()
else:
_LOGGER.debug("Application is not running, stop_running() does nothing.")
def run_polling( def run_polling(
self, self,
poll_interval: float = 0.0, poll_interval: float = 0.0,
@ -939,7 +957,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
loop.run_until_complete(self.start()) loop.run_until_complete(self.start())
loop.run_forever() loop.run_forever()
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
pass _LOGGER.debug("Application received stop signal. Shutting down.")
except Exception as exc: except Exception as exc:
# In case the coroutine wasn't awaited, we don't need to bother the user with a warning # In case the coroutine wasn't awaited, we don't need to bother the user with a warning
updater_coroutine.close() updater_coroutine.close()

View file

@ -24,6 +24,7 @@ from pathlib import Path
from types import TracebackType from types import TracebackType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any,
AsyncContextManager, AsyncContextManager,
Callable, Callable,
Coroutine, Coroutine,
@ -121,6 +122,7 @@ class Updater(AsyncContextManager["Updater"]):
self._httpd: Optional[WebhookServer] = None self._httpd: Optional[WebhookServer] = None
self.__lock = asyncio.Lock() self.__lock = asyncio.Lock()
self.__polling_task: Optional[asyncio.Task] = None self.__polling_task: Optional[asyncio.Task] = None
self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
@property @property
def running(self) -> bool: def running(self) -> bool:
@ -367,6 +369,28 @@ class Updater(AsyncContextManager["Updater"]):
name="Updater:start_polling:polling_task", name="Updater:start_polling:polling_task",
) )
# Prepare a cleanup callback to await on _stop_polling
# Calling get_updates one more time with the latest `offset` parameter ensures that
# all updates that where put into the update queue are also marked as "read" to TG,
# so we do not receive them again on the next startup
# We define this here so that we can use the same parameters as in the polling task
async def _get_updates_cleanup() -> None:
_LOGGER.debug(
"Calling `get_updates` one more time to mark all fetched updates as read."
)
await self.bot.get_updates(
offset=self._last_update_id,
# We don't want to do long polling here!
timeout=0,
read_timeout=read_timeout,
connect_timeout=connect_timeout,
write_timeout=write_timeout,
pool_timeout=pool_timeout,
allowed_updates=allowed_updates,
)
self.__polling_cleanup_cb = _get_updates_cleanup
if ready is not None: if ready is not None:
ready.set() ready.set()
@ -748,3 +772,12 @@ class Updater(AsyncContextManager["Updater"]):
# after start_polling(), but lets better be safe than sorry ... # after start_polling(), but lets better be safe than sorry ...
self.__polling_task = None self.__polling_task = None
if self.__polling_cleanup_cb:
await self.__polling_cleanup_cb()
self.__polling_cleanup_cb = None
else:
_LOGGER.warning(
"No polling cleanup callback defined. The last fetched updates may be "
"fetched again on the next polling start."
)

View file

@ -1427,7 +1427,7 @@ class TestApplication:
platform.system() == "Windows", platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows", reason="Can't send signals without stopping whole process on windows",
) )
def test_run_polling_basic(self, app, monkeypatch): def test_run_polling_basic(self, app, monkeypatch, caplog):
exception_event = threading.Event() exception_event = threading.Event()
update_event = threading.Event() update_event = threading.Event()
exception = TelegramError("This is a test error") exception = TelegramError("This is a test error")
@ -1464,6 +1464,9 @@ class TestApplication:
time.sleep(0.05) time.sleep(0.05)
assertions["exception_handling"] = self.received == exception.message assertions["exception_handling"] = self.received == exception.message
# So that the get_updates call on shutdown doesn't fail
exception_event.clear()
os.kill(os.getpid(), signal.SIGINT) os.kill(os.getpid(), signal.SIGINT)
time.sleep(0.1) time.sleep(0.1)
@ -1478,13 +1481,20 @@ class TestApplication:
thread = Thread(target=thread_target) thread = Thread(target=thread_target)
thread.start() thread.start()
app.run_polling(drop_pending_updates=True, close_loop=False) with caplog.at_level(logging.DEBUG):
thread.join() app.run_polling(drop_pending_updates=True, close_loop=False)
thread.join()
assert len(assertions) == 8 assert len(assertions) == 8
for key, value in assertions.items(): for key, value in assertions.items():
assert value, f"assertion '{key}' failed!" assert value, f"assertion '{key}' failed!"
found_log = False
for record in caplog.records:
if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG:
found_log = True
assert found_log
@pytest.mark.skipif( @pytest.mark.skipif(
platform.system() == "Windows", platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows", reason="Can't send signals without stopping whole process on windows",
@ -1692,7 +1702,7 @@ class TestApplication:
platform.system() == "Windows", platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows", reason="Can't send signals without stopping whole process on windows",
) )
def test_run_webhook_basic(self, app, monkeypatch): def test_run_webhook_basic(self, app, monkeypatch, caplog):
assertions = {} assertions = {}
async def delete_webhook(*args, **kwargs): async def delete_webhook(*args, **kwargs):
@ -1741,19 +1751,26 @@ class TestApplication:
ip = "127.0.0.1" ip = "127.0.0.1"
port = randrange(1024, 49152) port = randrange(1024, 49152)
app.run_webhook( with caplog.at_level(logging.DEBUG):
ip_address=ip, app.run_webhook(
port=port, ip_address=ip,
url_path="TOKEN", port=port,
drop_pending_updates=True, url_path="TOKEN",
close_loop=False, drop_pending_updates=True,
) close_loop=False,
thread.join() )
thread.join()
assert len(assertions) == 7 assert len(assertions) == 7
for key, value in assertions.items(): for key, value in assertions.items():
assert value, f"assertion '{key}' failed!" assert value, f"assertion '{key}' failed!"
found_log = False
for record in caplog.records:
if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG:
found_log = True
assert found_log
@pytest.mark.skipif( @pytest.mark.skipif(
platform.system() == "Windows", platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows", reason="Can't send signals without stopping whole process on windows",
@ -2226,3 +2243,120 @@ class TestApplication:
assert received_signals == [] assert received_signals == []
else: else:
assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT] assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]
def test_stop_running_not_running(self, app, caplog):
with caplog.at_level(logging.DEBUG):
app.stop_running()
assert len(caplog.records) == 1
assert caplog.records[-1].name == "telegram.ext.Application"
assert caplog.records[-1].getMessage().endswith("stop_running() does nothing.")
@pytest.mark.parametrize("method", ["polling", "webhook"])
def test_stop_running(self, one_time_bot, monkeypatch, method):
# asyncio.Event() seems to be hard to use across different threads (awaiting in main
# thread, setting in another thread), so we use threading.Event() instead.
# This requires the use of run_in_executor, but that's fine.
put_update_event = threading.Event()
callback_done_event = threading.Event()
called_stop_running = threading.Event()
assertions = {}
async def get_updates(*args, **kwargs):
await asyncio.sleep(0)
return []
async def delete_webhook(*args, **kwargs):
return True
async def set_webhook(*args, **kwargs):
return True
async def post_init(app):
# Simply calling app.update_queue.put_nowait(method) in the thread_target doesn't work
# for some reason (probably threading magic), so we use an event from the thread_target
# to put the update into the queue in the main thread.
async def task(app):
await asyncio.get_running_loop().run_in_executor(None, put_update_event.wait)
await app.update_queue.put(method)
app.create_task(task(app))
app = ApplicationBuilder().bot(one_time_bot).post_init(post_init).build()
monkeypatch.setattr(app.bot, "get_updates", get_updates)
monkeypatch.setattr(app.bot, "set_webhook", set_webhook)
monkeypatch.setattr(app.bot, "delete_webhook", delete_webhook)
events = []
monkeypatch.setattr(
app.updater,
"stop",
call_after(app.updater.stop, lambda _: events.append("updater.stop")),
)
monkeypatch.setattr(
app,
"stop",
call_after(app.stop, lambda _: events.append("app.stop")),
)
monkeypatch.setattr(
app,
"shutdown",
call_after(app.shutdown, lambda _: events.append("app.shutdown")),
)
def thread_target():
waited = 0
while not app.running:
time.sleep(0.05)
waited += 0.05
if waited > 5:
pytest.fail("App apparently won't start")
time.sleep(0.1)
assertions["called_stop_running_not_set"] = not called_stop_running.is_set()
put_update_event.set()
time.sleep(0.1)
assertions["called_stop_running_set"] = called_stop_running.is_set()
# App should have entered `stop` now but not finished it yet because the callback
# is still running
assertions["updater.stop_event"] = events == ["updater.stop"]
assertions["app.running_False"] = not app.running
callback_done_event.set()
time.sleep(0.1)
# Now that the update is fully handled, we expect the full shutdown
assertions["events"] = events == ["updater.stop", "app.stop", "app.shutdown"]
async def callback(update, context):
context.application.stop_running()
called_stop_running.set()
await asyncio.get_running_loop().run_in_executor(None, callback_done_event.wait)
app.add_handler(TypeHandler(object, callback))
thread = Thread(target=thread_target)
thread.start()
if method == "polling":
app.run_polling(close_loop=False, drop_pending_updates=True)
else:
ip = "127.0.0.1"
port = randrange(1024, 49152)
app.run_webhook(
ip_address=ip,
port=port,
url_path="TOKEN",
drop_pending_updates=False,
close_loop=False,
)
thread.join()
assert len(assertions) == 5
for key, value in assertions.items():
assert value, f"assertion '{key}' failed!"

View file

@ -214,9 +214,13 @@ class TestUpdater:
await updates.put(Update(update_id=2)) await updates.put(Update(update_id=2))
async def get_updates(*args, **kwargs): async def get_updates(*args, **kwargs):
next_update = await updates.get() if not updates.empty():
updates.task_done() next_update = await updates.get()
return [next_update] updates.task_done()
return [next_update]
await asyncio.sleep(0)
return []
orig_del_webhook = updater.bot.delete_webhook orig_del_webhook = updater.bot.delete_webhook
@ -265,6 +269,91 @@ class TestUpdater:
assert self.message_count == 4 assert self.message_count == 4
assert self.received == [1, 2, 3, 4] assert self.received == [1, 2, 3, 4]
async def test_polling_mark_updates_as_read(self, monkeypatch, updater, caplog):
updates = asyncio.Queue()
max_update_id = 3
for i in range(1, max_update_id + 1):
await updates.put(Update(update_id=i))
tracking_flag = False
received_kwargs = {}
expected_kwargs = {
"timeout": 0,
"read_timeout": "read_timeout",
"connect_timeout": "connect_timeout",
"write_timeout": "write_timeout",
"pool_timeout": "pool_timeout",
"allowed_updates": "allowed_updates",
}
async def get_updates(*args, **kwargs):
if tracking_flag:
received_kwargs.update(kwargs)
if not updates.empty():
next_update = await updates.get()
updates.task_done()
return [next_update]
await asyncio.sleep(0)
return []
monkeypatch.setattr(updater.bot, "get_updates", get_updates)
async with updater:
await updater.start_polling(**expected_kwargs)
await updates.join()
assert not received_kwargs
# Set the flag only now since we want to make sure that the get_updates
# is called one last time by updater.stop()
tracking_flag = True
with caplog.at_level(logging.DEBUG):
await updater.stop()
# ensure that the last fetched update was still marked as read
assert received_kwargs["offset"] == max_update_id + 1
# ensure that the correct arguments where passed to the last `get_updates` call
for name, value in expected_kwargs.items():
assert received_kwargs[name] == value
assert len(caplog.records) >= 1
log_found = False
for record in caplog.records:
if not record.getMessage().startswith("Calling `get_updates` one more time"):
continue
assert record.name == "telegram.ext.Updater"
assert record.levelno == logging.DEBUG
log_found = True
break
assert log_found
async def test_polling_mark_updates_as_read_failure(self, monkeypatch, updater, caplog):
async def get_updates(*args, **kwargs):
await asyncio.sleep(0)
return []
monkeypatch.setattr(updater.bot, "get_updates", get_updates)
async with updater:
await updater.start_polling()
# Unfortunately, there is no clean way to test this scenario as it should in fact
# never happen
updater._Updater__polling_cleanup_cb = None
with caplog.at_level(logging.DEBUG):
await updater.stop()
assert len(caplog.records) >= 1
log_found = False
for record in caplog.records:
if not record.getMessage().startswith("No polling cleanup callback defined"):
continue
assert record.name == "telegram.ext.Updater"
assert record.levelno == logging.WARNING
log_found = True
break
assert log_found
async def test_start_polling_already_running(self, updater): async def test_start_polling_already_running(self, updater):
async with updater: async with updater:
await updater.start_polling() await updater.start_polling()
@ -278,6 +367,7 @@ class TestUpdater:
async def test_start_polling_get_updates_parameters(self, updater, monkeypatch): async def test_start_polling_get_updates_parameters(self, updater, monkeypatch):
update_queue = asyncio.Queue() update_queue = asyncio.Queue()
await update_queue.put(Update(update_id=1)) await update_queue.put(Update(update_id=1))
on_stop_flag = False
expected = { expected = {
"timeout": 10, "timeout": 10,
@ -290,6 +380,11 @@ class TestUpdater:
} }
async def get_updates(*args, **kwargs): async def get_updates(*args, **kwargs):
if on_stop_flag:
# This is tested in test_polling_mark_updates_as_read
await asyncio.sleep(0)
return []
for key, value in expected.items(): for key, value in expected.items():
assert kwargs.pop(key, None) == value assert kwargs.pop(key, None) == value
@ -300,17 +395,23 @@ class TestUpdater:
if offset is not None and self.message_count != 0: if offset is not None and self.message_count != 0:
assert offset == self.message_count + 1, "get_updates got wrong `offset` parameter" assert offset == self.message_count + 1, "get_updates got wrong `offset` parameter"
update = await update_queue.get() if not update_queue.empty():
self.message_count = update.update_id update = await update_queue.get()
update_queue.task_done() self.message_count = update.update_id
return [update] update_queue.task_done()
return [update]
await asyncio.sleep(0)
return []
monkeypatch.setattr(updater.bot, "get_updates", get_updates) monkeypatch.setattr(updater.bot, "get_updates", get_updates)
async with updater: async with updater:
await updater.start_polling() await updater.start_polling()
await update_queue.join() await update_queue.join()
on_stop_flag = True
await updater.stop() await updater.stop()
on_stop_flag = False
expected = { expected = {
"timeout": 42, "timeout": 42,
@ -332,6 +433,7 @@ class TestUpdater:
allowed_updates=["message"], allowed_updates=["message"],
) )
await update_queue.join() await update_queue.join()
on_stop_flag = True
await updater.stop() await updater.stop()
@pytest.mark.parametrize("exception_class", [InvalidToken, TelegramError]) @pytest.mark.parametrize("exception_class", [InvalidToken, TelegramError])
@ -368,12 +470,16 @@ class TestUpdater:
async def test_start_polling_exceptions_and_error_callback( async def test_start_polling_exceptions_and_error_callback(
self, monkeypatch, updater, error, callback_should_be_called, custom_error_callback, caplog self, monkeypatch, updater, error, callback_should_be_called, custom_error_callback, caplog
): ):
raise_exception = True
get_updates_event = asyncio.Event() get_updates_event = asyncio.Event()
async def get_updates(*args, **kwargs): async def get_updates(*args, **kwargs):
# So that the main task has a chance to be called # So that the main task has a chance to be called
await asyncio.sleep(0) await asyncio.sleep(0)
if not raise_exception:
return []
get_updates_event.set() get_updates_event.set()
raise error raise error
@ -428,6 +534,7 @@ class TestUpdater:
and record.name == "telegram.ext.Updater" and record.name == "telegram.ext.Updater"
for record in caplog.records for record in caplog.records
) )
raise_exception = False
await updater.stop() await updater.stop()
async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog): async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog):
@ -490,9 +597,13 @@ class TestUpdater:
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
raise TypeError("Invalid Data") raise TypeError("Invalid Data")
next_update = await updates.get() if not updates.empty():
updates.task_done() next_update = await updates.get()
return [next_update] updates.task_done()
return [next_update]
await asyncio.sleep(0)
return []
orig_del_webhook = updater.bot.delete_webhook orig_del_webhook = updater.bot.delete_webhook