Add Application.stop_running() and Improve Marking Updates as Read on Updater.stop() (#3804)

This commit is contained in:
Bibo-Joshi 2023-08-17 11:50:26 +02:00 committed by GitHub
parent 03f87750d4
commit 5128748092
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 327 additions and 29 deletions

View file

@ -1,7 +1,9 @@
.. tip::
When combining ``python-telegram-bot`` with other :mod:`asyncio` based frameworks, using this
method is likely not the best choice, as it blocks the event loop until it receives a stop
signal as described above.
Instead, you can manually call the methods listed below to start and shut down the application
and the :attr:`~telegram.ext.Application.updater`.
Keeping the event loop running and listening for a stop signal is then up to you.
* When combining ``python-telegram-bot`` with other :mod:`asyncio` based frameworks, using this
method is likely not the best choice, as it blocks the event loop until it receives a stop
signal as described above.
Instead, you can manually call the methods listed below to start and shut down the application
and the :attr:`~telegram.ext.Application.updater`.
Keeping the event loop running and listening for a stop signal is then up to you.
* To gracefully stop the execution of this method from within a handler, job or error callback,
use :meth:`~telegram.ext.Application.stop_running`.

View file

@ -653,6 +653,24 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
_LOGGER.info("Application.stop() complete")
def stop_running(self) -> None:
    """Request a graceful stop of :meth:`run_polling` or :meth:`run_webhook`.

    This method is meant to be called from within a handler, job or error callback.
    The shutdown stays graceful, i.e. the methods listed in :meth:`run_polling` and
    :meth:`run_webhook` will still be executed.

    Note:
        If the application is not running, this method only logs a debug message and
        does nothing else.

    .. versionadded:: NEXT.VERSION
    """
    if not self.running:
        _LOGGER.debug("Application is not running, stop_running() does nothing.")
        return

    # Stopping the event loop is sufficient only because `__run` keeps the application
    # alive via `loop.run_forever()`. If that changes, this method needs to be adapted.
    current_loop = asyncio.get_running_loop()
    current_loop.stop()
def run_polling(
self,
poll_interval: float = 0.0,
@ -939,7 +957,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
loop.run_until_complete(self.start())
loop.run_forever()
except (KeyboardInterrupt, SystemExit):
pass
_LOGGER.debug("Application received stop signal. Shutting down.")
except Exception as exc:
# In case the coroutine wasn't awaited, we don't need to bother the user with a warning
updater_coroutine.close()

View file

@ -24,6 +24,7 @@ from pathlib import Path
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Callable,
Coroutine,
@ -121,6 +122,7 @@ class Updater(AsyncContextManager["Updater"]):
self._httpd: Optional[WebhookServer] = None
self.__lock = asyncio.Lock()
self.__polling_task: Optional[asyncio.Task] = None
self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
@property
def running(self) -> bool:
@ -367,6 +369,28 @@ class Updater(AsyncContextManager["Updater"]):
name="Updater:start_polling:polling_task",
)
# Prepare a cleanup callback to await on _stop_polling
# Calling get_updates one more time with the latest `offset` parameter ensures that
# all updates that were put into the update queue are also marked as "read" to TG,
# so we do not receive them again on the next startup
# We define this here so that we can use the same parameters as in the polling task
async def _get_updates_cleanup() -> None:
    """Call ``get_updates`` once more with the latest offset to mark updates as read."""
    _LOGGER.debug(
        "Calling `get_updates` one more time to mark all fetched updates as read."
    )
    # The request settings are taken from the enclosing scope, i.e. they are the ones
    # the polling task uses. Only `timeout` is fixed to 0: we don't want to do long
    # polling here!
    final_call_kwargs = {
        "offset": self._last_update_id,
        "timeout": 0,
        "read_timeout": read_timeout,
        "connect_timeout": connect_timeout,
        "write_timeout": write_timeout,
        "pool_timeout": pool_timeout,
        "allowed_updates": allowed_updates,
    }
    await self.bot.get_updates(**final_call_kwargs)
self.__polling_cleanup_cb = _get_updates_cleanup
if ready is not None:
ready.set()
@ -748,3 +772,12 @@ class Updater(AsyncContextManager["Updater"]):
# after start_polling(), but better safe than sorry ...
self.__polling_task = None
if self.__polling_cleanup_cb:
await self.__polling_cleanup_cb()
self.__polling_cleanup_cb = None
else:
_LOGGER.warning(
"No polling cleanup callback defined. The last fetched updates may be "
"fetched again on the next polling start."
)

View file

@ -1427,7 +1427,7 @@ class TestApplication:
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
)
def test_run_polling_basic(self, app, monkeypatch):
def test_run_polling_basic(self, app, monkeypatch, caplog):
exception_event = threading.Event()
update_event = threading.Event()
exception = TelegramError("This is a test error")
@ -1464,6 +1464,9 @@ class TestApplication:
time.sleep(0.05)
assertions["exception_handling"] = self.received == exception.message
# So that the get_updates call on shutdown doesn't fail
exception_event.clear()
os.kill(os.getpid(), signal.SIGINT)
time.sleep(0.1)
@ -1478,13 +1481,20 @@ class TestApplication:
thread = Thread(target=thread_target)
thread.start()
app.run_polling(drop_pending_updates=True, close_loop=False)
thread.join()
with caplog.at_level(logging.DEBUG):
app.run_polling(drop_pending_updates=True, close_loop=False)
thread.join()
assert len(assertions) == 8
for key, value in assertions.items():
assert value, f"assertion '{key}' failed!"
found_log = False
for record in caplog.records:
if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG:
found_log = True
assert found_log
@pytest.mark.skipif(
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
@ -1692,7 +1702,7 @@ class TestApplication:
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
)
def test_run_webhook_basic(self, app, monkeypatch):
def test_run_webhook_basic(self, app, monkeypatch, caplog):
assertions = {}
async def delete_webhook(*args, **kwargs):
@ -1741,19 +1751,26 @@ class TestApplication:
ip = "127.0.0.1"
port = randrange(1024, 49152)
app.run_webhook(
ip_address=ip,
port=port,
url_path="TOKEN",
drop_pending_updates=True,
close_loop=False,
)
thread.join()
with caplog.at_level(logging.DEBUG):
app.run_webhook(
ip_address=ip,
port=port,
url_path="TOKEN",
drop_pending_updates=True,
close_loop=False,
)
thread.join()
assert len(assertions) == 7
for key, value in assertions.items():
assert value, f"assertion '{key}' failed!"
found_log = False
for record in caplog.records:
if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG:
found_log = True
assert found_log
@pytest.mark.skipif(
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
@ -2226,3 +2243,120 @@ class TestApplication:
assert received_signals == []
else:
assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]
def test_stop_running_not_running(self, app, caplog):
    """``stop_running`` on an application that is not running only logs a debug message."""
    with caplog.at_level(logging.DEBUG):
        app.stop_running()

    captured = caplog.records
    assert len(captured) == 1

    last_record = captured[-1]
    assert last_record.name == "telegram.ext.Application"
    assert last_record.getMessage().endswith("stop_running() does nothing.")
@pytest.mark.parametrize("method", ["polling", "webhook"])
def test_stop_running(self, one_time_bot, monkeypatch, method):
    """``stop_running`` called from a handler callback leads to a graceful shutdown.

    A helper thread watches the application from the outside. It checks that
    ``updater.stop`` has already run while the callback is still blocked, and that
    ``app.stop`` and ``app.shutdown`` follow once the callback is released.
    """
    # asyncio.Event() seems to be hard to use across different threads (awaiting in
    # the main thread, setting in another thread), so threading.Event() is used instead.
    # This requires the use of run_in_executor, but that's fine.
    update_requested = threading.Event()
    callback_may_finish = threading.Event()
    stop_running_called = threading.Event()
    assertions = {}
    events = []

    async def no_updates(*args, **kwargs):
        await asyncio.sleep(0)
        return []

    async def succeed(*args, **kwargs):
        return True

    async def post_init(application):
        # Simply calling update_queue.put_nowait(method) in the helper thread doesn't
        # work for some reason (probably threading magic), so the helper thread only
        # sets an event and the update is put into the queue from the main thread.
        async def enqueue_update(application):
            running_loop = asyncio.get_running_loop()
            await running_loop.run_in_executor(None, update_requested.wait)
            await application.update_queue.put(method)

        application.create_task(enqueue_update(application))

    builder = ApplicationBuilder().bot(one_time_bot).post_init(post_init)
    app = builder.build()

    monkeypatch.setattr(app.bot, "get_updates", no_updates)
    monkeypatch.setattr(app.bot, "set_webhook", succeed)
    monkeypatch.setattr(app.bot, "delete_webhook", succeed)

    # Record the order in which the shutdown steps complete
    monkeypatch.setattr(
        app.updater,
        "stop",
        call_after(app.updater.stop, lambda _: events.append("updater.stop")),
    )
    monkeypatch.setattr(
        app,
        "stop",
        call_after(app.stop, lambda _: events.append("app.stop")),
    )
    monkeypatch.setattr(
        app,
        "shutdown",
        call_after(app.shutdown, lambda _: events.append("app.shutdown")),
    )

    def observe_shutdown():
        seconds_waited = 0
        while not app.running:
            time.sleep(0.05)
            seconds_waited += 0.05
            if seconds_waited > 5:
                pytest.fail("App apparently won't start")

        time.sleep(0.1)
        assertions["called_stop_running_not_set"] = not stop_running_called.is_set()

        update_requested.set()
        time.sleep(0.1)
        assertions["called_stop_running_set"] = stop_running_called.is_set()

        # App should have entered `stop` now but not finished it yet because the
        # callback is still running
        assertions["updater.stop_event"] = events == ["updater.stop"]
        assertions["app.running_False"] = not app.running

        callback_may_finish.set()
        time.sleep(0.1)

        # Now that the update is fully handled, the full shutdown is expected
        assertions["events"] = events == ["updater.stop", "app.stop", "app.shutdown"]

    async def callback(update, context):
        context.application.stop_running()
        stop_running_called.set()
        # Keep the callback busy until the helper thread releases it
        running_loop = asyncio.get_running_loop()
        await running_loop.run_in_executor(None, callback_may_finish.wait)

    app.add_handler(TypeHandler(object, callback))

    observer = Thread(target=observe_shutdown)
    observer.start()

    if method == "polling":
        app.run_polling(close_loop=False, drop_pending_updates=True)
    else:
        app.run_webhook(
            ip_address="127.0.0.1",
            port=randrange(1024, 49152),
            url_path="TOKEN",
            drop_pending_updates=False,
            close_loop=False,
        )

    observer.join()

    assert len(assertions) == 5
    for label, passed in assertions.items():
        assert passed, f"assertion '{label}' failed!"

View file

@ -214,9 +214,13 @@ class TestUpdater:
await updates.put(Update(update_id=2))
async def get_updates(*args, **kwargs):
next_update = await updates.get()
updates.task_done()
return [next_update]
if not updates.empty():
next_update = await updates.get()
updates.task_done()
return [next_update]
await asyncio.sleep(0)
return []
orig_del_webhook = updater.bot.delete_webhook
@ -265,6 +269,91 @@ class TestUpdater:
assert self.message_count == 4
assert self.received == [1, 2, 3, 4]
async def test_polling_mark_updates_as_read(self, monkeypatch, updater, caplog):
    """``Updater.stop`` must call ``get_updates`` one last time to mark updates as read."""
    max_update_id = 3
    pending = asyncio.Queue()
    for update_id in range(1, max_update_id + 1):
        await pending.put(Update(update_id=update_id))

    # Placeholder strings are sufficient: the test only checks that the values given to
    # `start_polling` show up unchanged in the last `get_updates` call.
    polling_kwargs = {
        "timeout": 0,
        "read_timeout": "read_timeout",
        "connect_timeout": "connect_timeout",
        "write_timeout": "write_timeout",
        "pool_timeout": "pool_timeout",
        "allowed_updates": "allowed_updates",
    }
    record_kwargs = False
    seen_kwargs = {}

    async def get_updates(*args, **kwargs):
        if record_kwargs:
            seen_kwargs.update(kwargs)

        if pending.empty():
            await asyncio.sleep(0)
            return []

        update = await pending.get()
        pending.task_done()
        return [update]

    monkeypatch.setattr(updater.bot, "get_updates", get_updates)

    async with updater:
        await updater.start_polling(**polling_kwargs)
        await pending.join()
        assert not seen_kwargs

        # Record kwargs only from now on, since we want to make sure that `get_updates`
        # is called one last time by updater.stop()
        record_kwargs = True
        with caplog.at_level(logging.DEBUG):
            await updater.stop()

        # ensure that the last fetched update was still marked as read
        assert seen_kwargs["offset"] == max_update_id + 1
        # ensure that the correct arguments were passed to the last `get_updates` call
        for name, value in polling_kwargs.items():
            assert seen_kwargs[name] == value

    assert len(caplog.records) >= 1
    cleanup_records = [
        record
        for record in caplog.records
        if record.getMessage().startswith("Calling `get_updates` one more time")
    ]
    assert cleanup_records
    first_match = cleanup_records[0]
    assert first_match.name == "telegram.ext.Updater"
    assert first_match.levelno == logging.DEBUG
async def test_polling_mark_updates_as_read_failure(self, monkeypatch, updater, caplog):
    """``Updater.stop`` logs a warning when no polling cleanup callback is defined."""

    async def get_updates(*args, **kwargs):
        await asyncio.sleep(0)
        return []

    monkeypatch.setattr(updater.bot, "get_updates", get_updates)

    async with updater:
        await updater.start_polling()
        # Unfortunately, there is no clean way to test this scenario as it should in
        # fact never happen - so the name-mangled private attribute is reset by hand.
        updater._Updater__polling_cleanup_cb = None
        with caplog.at_level(logging.DEBUG):
            await updater.stop()

    assert len(caplog.records) >= 1
    warning_records = [
        record
        for record in caplog.records
        if record.getMessage().startswith("No polling cleanup callback defined")
    ]
    assert warning_records
    first_match = warning_records[0]
    assert first_match.name == "telegram.ext.Updater"
    assert first_match.levelno == logging.WARNING
async def test_start_polling_already_running(self, updater):
async with updater:
await updater.start_polling()
@ -278,6 +367,7 @@ class TestUpdater:
async def test_start_polling_get_updates_parameters(self, updater, monkeypatch):
update_queue = asyncio.Queue()
await update_queue.put(Update(update_id=1))
on_stop_flag = False
expected = {
"timeout": 10,
@ -290,6 +380,11 @@ class TestUpdater:
}
async def get_updates(*args, **kwargs):
if on_stop_flag:
# This is tested in test_polling_mark_updates_as_read
await asyncio.sleep(0)
return []
for key, value in expected.items():
assert kwargs.pop(key, None) == value
@ -300,17 +395,23 @@ class TestUpdater:
if offset is not None and self.message_count != 0:
assert offset == self.message_count + 1, "get_updates got wrong `offset` parameter"
update = await update_queue.get()
self.message_count = update.update_id
update_queue.task_done()
return [update]
if not update_queue.empty():
update = await update_queue.get()
self.message_count = update.update_id
update_queue.task_done()
return [update]
await asyncio.sleep(0)
return []
monkeypatch.setattr(updater.bot, "get_updates", get_updates)
async with updater:
await updater.start_polling()
await update_queue.join()
on_stop_flag = True
await updater.stop()
on_stop_flag = False
expected = {
"timeout": 42,
@ -332,6 +433,7 @@ class TestUpdater:
allowed_updates=["message"],
)
await update_queue.join()
on_stop_flag = True
await updater.stop()
@pytest.mark.parametrize("exception_class", [InvalidToken, TelegramError])
@ -368,12 +470,16 @@ class TestUpdater:
async def test_start_polling_exceptions_and_error_callback(
self, monkeypatch, updater, error, callback_should_be_called, custom_error_callback, caplog
):
raise_exception = True
get_updates_event = asyncio.Event()
async def get_updates(*args, **kwargs):
# So that the main task has a chance to be called
await asyncio.sleep(0)
if not raise_exception:
return []
get_updates_event.set()
raise error
@ -428,6 +534,7 @@ class TestUpdater:
and record.name == "telegram.ext.Updater"
for record in caplog.records
)
raise_exception = False
await updater.stop()
async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog):
@ -490,9 +597,13 @@ class TestUpdater:
await asyncio.sleep(0.01)
raise TypeError("Invalid Data")
next_update = await updates.get()
updates.task_done()
return [next_update]
if not updates.empty():
next_update = await updates.get()
updates.task_done()
return [next_update]
await asyncio.sleep(0)
return []
orig_del_webhook = updater.bot.delete_webhook