mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-11-21 22:56:38 +01:00
Add Application.stop_running()
and Improve Marking Updates as Read on Updater.stop()
(#3804)
This commit is contained in:
parent
03f87750d4
commit
5128748092
5 changed files with 327 additions and 29 deletions
|
@ -1,7 +1,9 @@
|
|||
.. tip::
|
||||
When combining ``python-telegram-bot`` with other :mod:`asyncio` based frameworks, using this
|
||||
* When combining ``python-telegram-bot`` with other :mod:`asyncio` based frameworks, using this
|
||||
method is likely not the best choice, as it blocks the event loop until it receives a stop
|
||||
signal as described above.
|
||||
Instead, you can manually call the methods listed below to start and shut down the application
|
||||
and the :attr:`~telegram.ext.Application.updater`.
|
||||
Keeping the event loop running and listening for a stop signal is then up to you.
|
||||
* To gracefully stop the execution of this method from within a handler, job or error callback,
|
||||
use :meth:`~telegram.ext.Application.stop_running`.
|
|
@ -653,6 +653,24 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
|
||||
_LOGGER.info("Application.stop() complete")
|
||||
|
||||
def stop_running(self) -> None:
|
||||
"""This method can be used to stop the execution of :meth:`run_polling` or
|
||||
:meth:`run_webhook` from within a handler, job or error callback. This allows a graceful
|
||||
shutdown of the application, i.e. the methods listed in :attr:`run_polling` and
|
||||
:attr:`run_webhook` will still be executed.
|
||||
|
||||
Note:
|
||||
If the application is not running, this method does nothing.
|
||||
|
||||
.. versionadded:: NEXT.VERSION
|
||||
"""
|
||||
if self.running:
|
||||
# This works because `__run` is using `loop.run_forever()`. If that changes, this
|
||||
# method needs to be adapted.
|
||||
asyncio.get_running_loop().stop()
|
||||
else:
|
||||
_LOGGER.debug("Application is not running, stop_running() does nothing.")
|
||||
|
||||
def run_polling(
|
||||
self,
|
||||
poll_interval: float = 0.0,
|
||||
|
@ -939,7 +957,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
loop.run_until_complete(self.start())
|
||||
loop.run_forever()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
_LOGGER.debug("Application received stop signal. Shutting down.")
|
||||
except Exception as exc:
|
||||
# In case the coroutine wasn't awaited, we don't need to bother the user with a warning
|
||||
updater_coroutine.close()
|
||||
|
|
|
@ -24,6 +24,7 @@ from pathlib import Path
|
|||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Callable,
|
||||
Coroutine,
|
||||
|
@ -121,6 +122,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
self._httpd: Optional[WebhookServer] = None
|
||||
self.__lock = asyncio.Lock()
|
||||
self.__polling_task: Optional[asyncio.Task] = None
|
||||
self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
|
@ -367,6 +369,28 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
name="Updater:start_polling:polling_task",
|
||||
)
|
||||
|
||||
# Prepare a cleanup callback to await on _stop_polling
|
||||
# Calling get_updates one more time with the latest `offset` parameter ensures that
|
||||
# all updates that where put into the update queue are also marked as "read" to TG,
|
||||
# so we do not receive them again on the next startup
|
||||
# We define this here so that we can use the same parameters as in the polling task
|
||||
async def _get_updates_cleanup() -> None:
|
||||
_LOGGER.debug(
|
||||
"Calling `get_updates` one more time to mark all fetched updates as read."
|
||||
)
|
||||
await self.bot.get_updates(
|
||||
offset=self._last_update_id,
|
||||
# We don't want to do long polling here!
|
||||
timeout=0,
|
||||
read_timeout=read_timeout,
|
||||
connect_timeout=connect_timeout,
|
||||
write_timeout=write_timeout,
|
||||
pool_timeout=pool_timeout,
|
||||
allowed_updates=allowed_updates,
|
||||
)
|
||||
|
||||
self.__polling_cleanup_cb = _get_updates_cleanup
|
||||
|
||||
if ready is not None:
|
||||
ready.set()
|
||||
|
||||
|
@ -748,3 +772,12 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
# after start_polling(), but lets better be safe than sorry ...
|
||||
|
||||
self.__polling_task = None
|
||||
|
||||
if self.__polling_cleanup_cb:
|
||||
await self.__polling_cleanup_cb()
|
||||
self.__polling_cleanup_cb = None
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"No polling cleanup callback defined. The last fetched updates may be "
|
||||
"fetched again on the next polling start."
|
||||
)
|
||||
|
|
|
@ -1427,7 +1427,7 @@ class TestApplication:
|
|||
platform.system() == "Windows",
|
||||
reason="Can't send signals without stopping whole process on windows",
|
||||
)
|
||||
def test_run_polling_basic(self, app, monkeypatch):
|
||||
def test_run_polling_basic(self, app, monkeypatch, caplog):
|
||||
exception_event = threading.Event()
|
||||
update_event = threading.Event()
|
||||
exception = TelegramError("This is a test error")
|
||||
|
@ -1464,6 +1464,9 @@ class TestApplication:
|
|||
time.sleep(0.05)
|
||||
assertions["exception_handling"] = self.received == exception.message
|
||||
|
||||
# So that the get_updates call on shutdown doesn't fail
|
||||
exception_event.clear()
|
||||
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
time.sleep(0.1)
|
||||
|
||||
|
@ -1478,6 +1481,7 @@ class TestApplication:
|
|||
|
||||
thread = Thread(target=thread_target)
|
||||
thread.start()
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
app.run_polling(drop_pending_updates=True, close_loop=False)
|
||||
thread.join()
|
||||
|
||||
|
@ -1485,6 +1489,12 @@ class TestApplication:
|
|||
for key, value in assertions.items():
|
||||
assert value, f"assertion '{key}' failed!"
|
||||
|
||||
found_log = False
|
||||
for record in caplog.records:
|
||||
if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG:
|
||||
found_log = True
|
||||
assert found_log
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows",
|
||||
reason="Can't send signals without stopping whole process on windows",
|
||||
|
@ -1692,7 +1702,7 @@ class TestApplication:
|
|||
platform.system() == "Windows",
|
||||
reason="Can't send signals without stopping whole process on windows",
|
||||
)
|
||||
def test_run_webhook_basic(self, app, monkeypatch):
|
||||
def test_run_webhook_basic(self, app, monkeypatch, caplog):
|
||||
assertions = {}
|
||||
|
||||
async def delete_webhook(*args, **kwargs):
|
||||
|
@ -1741,6 +1751,7 @@ class TestApplication:
|
|||
ip = "127.0.0.1"
|
||||
port = randrange(1024, 49152)
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
app.run_webhook(
|
||||
ip_address=ip,
|
||||
port=port,
|
||||
|
@ -1754,6 +1765,12 @@ class TestApplication:
|
|||
for key, value in assertions.items():
|
||||
assert value, f"assertion '{key}' failed!"
|
||||
|
||||
found_log = False
|
||||
for record in caplog.records:
|
||||
if "received stop signal" in record.getMessage() and record.levelno == logging.DEBUG:
|
||||
found_log = True
|
||||
assert found_log
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows",
|
||||
reason="Can't send signals without stopping whole process on windows",
|
||||
|
@ -2226,3 +2243,120 @@ class TestApplication:
|
|||
assert received_signals == []
|
||||
else:
|
||||
assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]
|
||||
|
||||
def test_stop_running_not_running(self, app, caplog):
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
app.stop_running()
|
||||
|
||||
assert len(caplog.records) == 1
|
||||
assert caplog.records[-1].name == "telegram.ext.Application"
|
||||
assert caplog.records[-1].getMessage().endswith("stop_running() does nothing.")
|
||||
|
||||
@pytest.mark.parametrize("method", ["polling", "webhook"])
|
||||
def test_stop_running(self, one_time_bot, monkeypatch, method):
|
||||
# asyncio.Event() seems to be hard to use across different threads (awaiting in main
|
||||
# thread, setting in another thread), so we use threading.Event() instead.
|
||||
# This requires the use of run_in_executor, but that's fine.
|
||||
put_update_event = threading.Event()
|
||||
callback_done_event = threading.Event()
|
||||
called_stop_running = threading.Event()
|
||||
assertions = {}
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
await asyncio.sleep(0)
|
||||
return []
|
||||
|
||||
async def delete_webhook(*args, **kwargs):
|
||||
return True
|
||||
|
||||
async def set_webhook(*args, **kwargs):
|
||||
return True
|
||||
|
||||
async def post_init(app):
|
||||
# Simply calling app.update_queue.put_nowait(method) in the thread_target doesn't work
|
||||
# for some reason (probably threading magic), so we use an event from the thread_target
|
||||
# to put the update into the queue in the main thread.
|
||||
async def task(app):
|
||||
await asyncio.get_running_loop().run_in_executor(None, put_update_event.wait)
|
||||
await app.update_queue.put(method)
|
||||
|
||||
app.create_task(task(app))
|
||||
|
||||
app = ApplicationBuilder().bot(one_time_bot).post_init(post_init).build()
|
||||
monkeypatch.setattr(app.bot, "get_updates", get_updates)
|
||||
monkeypatch.setattr(app.bot, "set_webhook", set_webhook)
|
||||
monkeypatch.setattr(app.bot, "delete_webhook", delete_webhook)
|
||||
|
||||
events = []
|
||||
monkeypatch.setattr(
|
||||
app.updater,
|
||||
"stop",
|
||||
call_after(app.updater.stop, lambda _: events.append("updater.stop")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app,
|
||||
"stop",
|
||||
call_after(app.stop, lambda _: events.append("app.stop")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app,
|
||||
"shutdown",
|
||||
call_after(app.shutdown, lambda _: events.append("app.shutdown")),
|
||||
)
|
||||
|
||||
def thread_target():
|
||||
waited = 0
|
||||
while not app.running:
|
||||
time.sleep(0.05)
|
||||
waited += 0.05
|
||||
if waited > 5:
|
||||
pytest.fail("App apparently won't start")
|
||||
|
||||
time.sleep(0.1)
|
||||
assertions["called_stop_running_not_set"] = not called_stop_running.is_set()
|
||||
|
||||
put_update_event.set()
|
||||
time.sleep(0.1)
|
||||
|
||||
assertions["called_stop_running_set"] = called_stop_running.is_set()
|
||||
|
||||
# App should have entered `stop` now but not finished it yet because the callback
|
||||
# is still running
|
||||
assertions["updater.stop_event"] = events == ["updater.stop"]
|
||||
assertions["app.running_False"] = not app.running
|
||||
|
||||
callback_done_event.set()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Now that the update is fully handled, we expect the full shutdown
|
||||
assertions["events"] = events == ["updater.stop", "app.stop", "app.shutdown"]
|
||||
|
||||
async def callback(update, context):
|
||||
context.application.stop_running()
|
||||
called_stop_running.set()
|
||||
await asyncio.get_running_loop().run_in_executor(None, callback_done_event.wait)
|
||||
|
||||
app.add_handler(TypeHandler(object, callback))
|
||||
|
||||
thread = Thread(target=thread_target)
|
||||
thread.start()
|
||||
|
||||
if method == "polling":
|
||||
app.run_polling(close_loop=False, drop_pending_updates=True)
|
||||
else:
|
||||
ip = "127.0.0.1"
|
||||
port = randrange(1024, 49152)
|
||||
|
||||
app.run_webhook(
|
||||
ip_address=ip,
|
||||
port=port,
|
||||
url_path="TOKEN",
|
||||
drop_pending_updates=False,
|
||||
close_loop=False,
|
||||
)
|
||||
|
||||
thread.join()
|
||||
|
||||
assert len(assertions) == 5
|
||||
for key, value in assertions.items():
|
||||
assert value, f"assertion '{key}' failed!"
|
||||
|
|
|
@ -214,10 +214,14 @@ class TestUpdater:
|
|||
await updates.put(Update(update_id=2))
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
if not updates.empty():
|
||||
next_update = await updates.get()
|
||||
updates.task_done()
|
||||
return [next_update]
|
||||
|
||||
await asyncio.sleep(0)
|
||||
return []
|
||||
|
||||
orig_del_webhook = updater.bot.delete_webhook
|
||||
|
||||
async def delete_webhook(*args, **kwargs):
|
||||
|
@ -265,6 +269,91 @@ class TestUpdater:
|
|||
assert self.message_count == 4
|
||||
assert self.received == [1, 2, 3, 4]
|
||||
|
||||
async def test_polling_mark_updates_as_read(self, monkeypatch, updater, caplog):
|
||||
updates = asyncio.Queue()
|
||||
max_update_id = 3
|
||||
for i in range(1, max_update_id + 1):
|
||||
await updates.put(Update(update_id=i))
|
||||
tracking_flag = False
|
||||
received_kwargs = {}
|
||||
expected_kwargs = {
|
||||
"timeout": 0,
|
||||
"read_timeout": "read_timeout",
|
||||
"connect_timeout": "connect_timeout",
|
||||
"write_timeout": "write_timeout",
|
||||
"pool_timeout": "pool_timeout",
|
||||
"allowed_updates": "allowed_updates",
|
||||
}
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
if tracking_flag:
|
||||
received_kwargs.update(kwargs)
|
||||
if not updates.empty():
|
||||
next_update = await updates.get()
|
||||
updates.task_done()
|
||||
return [next_update]
|
||||
await asyncio.sleep(0)
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(updater.bot, "get_updates", get_updates)
|
||||
|
||||
async with updater:
|
||||
await updater.start_polling(**expected_kwargs)
|
||||
await updates.join()
|
||||
assert not received_kwargs
|
||||
# Set the flag only now since we want to make sure that the get_updates
|
||||
# is called one last time by updater.stop()
|
||||
tracking_flag = True
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await updater.stop()
|
||||
|
||||
# ensure that the last fetched update was still marked as read
|
||||
assert received_kwargs["offset"] == max_update_id + 1
|
||||
# ensure that the correct arguments where passed to the last `get_updates` call
|
||||
for name, value in expected_kwargs.items():
|
||||
assert received_kwargs[name] == value
|
||||
|
||||
assert len(caplog.records) >= 1
|
||||
log_found = False
|
||||
for record in caplog.records:
|
||||
if not record.getMessage().startswith("Calling `get_updates` one more time"):
|
||||
continue
|
||||
|
||||
assert record.name == "telegram.ext.Updater"
|
||||
assert record.levelno == logging.DEBUG
|
||||
log_found = True
|
||||
break
|
||||
|
||||
assert log_found
|
||||
|
||||
async def test_polling_mark_updates_as_read_failure(self, monkeypatch, updater, caplog):
|
||||
async def get_updates(*args, **kwargs):
|
||||
await asyncio.sleep(0)
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(updater.bot, "get_updates", get_updates)
|
||||
|
||||
async with updater:
|
||||
await updater.start_polling()
|
||||
# Unfortunately, there is no clean way to test this scenario as it should in fact
|
||||
# never happen
|
||||
updater._Updater__polling_cleanup_cb = None
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await updater.stop()
|
||||
|
||||
assert len(caplog.records) >= 1
|
||||
log_found = False
|
||||
for record in caplog.records:
|
||||
if not record.getMessage().startswith("No polling cleanup callback defined"):
|
||||
continue
|
||||
|
||||
assert record.name == "telegram.ext.Updater"
|
||||
assert record.levelno == logging.WARNING
|
||||
log_found = True
|
||||
break
|
||||
|
||||
assert log_found
|
||||
|
||||
async def test_start_polling_already_running(self, updater):
|
||||
async with updater:
|
||||
await updater.start_polling()
|
||||
|
@ -278,6 +367,7 @@ class TestUpdater:
|
|||
async def test_start_polling_get_updates_parameters(self, updater, monkeypatch):
|
||||
update_queue = asyncio.Queue()
|
||||
await update_queue.put(Update(update_id=1))
|
||||
on_stop_flag = False
|
||||
|
||||
expected = {
|
||||
"timeout": 10,
|
||||
|
@ -290,6 +380,11 @@ class TestUpdater:
|
|||
}
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
if on_stop_flag:
|
||||
# This is tested in test_polling_mark_updates_as_read
|
||||
await asyncio.sleep(0)
|
||||
return []
|
||||
|
||||
for key, value in expected.items():
|
||||
assert kwargs.pop(key, None) == value
|
||||
|
||||
|
@ -300,17 +395,23 @@ class TestUpdater:
|
|||
if offset is not None and self.message_count != 0:
|
||||
assert offset == self.message_count + 1, "get_updates got wrong `offset` parameter"
|
||||
|
||||
if not update_queue.empty():
|
||||
update = await update_queue.get()
|
||||
self.message_count = update.update_id
|
||||
update_queue.task_done()
|
||||
return [update]
|
||||
|
||||
await asyncio.sleep(0)
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(updater.bot, "get_updates", get_updates)
|
||||
|
||||
async with updater:
|
||||
await updater.start_polling()
|
||||
await update_queue.join()
|
||||
on_stop_flag = True
|
||||
await updater.stop()
|
||||
on_stop_flag = False
|
||||
|
||||
expected = {
|
||||
"timeout": 42,
|
||||
|
@ -332,6 +433,7 @@ class TestUpdater:
|
|||
allowed_updates=["message"],
|
||||
)
|
||||
await update_queue.join()
|
||||
on_stop_flag = True
|
||||
await updater.stop()
|
||||
|
||||
@pytest.mark.parametrize("exception_class", [InvalidToken, TelegramError])
|
||||
|
@ -368,12 +470,16 @@ class TestUpdater:
|
|||
async def test_start_polling_exceptions_and_error_callback(
|
||||
self, monkeypatch, updater, error, callback_should_be_called, custom_error_callback, caplog
|
||||
):
|
||||
raise_exception = True
|
||||
get_updates_event = asyncio.Event()
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
# So that the main task has a chance to be called
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if not raise_exception:
|
||||
return []
|
||||
|
||||
get_updates_event.set()
|
||||
raise error
|
||||
|
||||
|
@ -428,6 +534,7 @@ class TestUpdater:
|
|||
and record.name == "telegram.ext.Updater"
|
||||
for record in caplog.records
|
||||
)
|
||||
raise_exception = False
|
||||
await updater.stop()
|
||||
|
||||
async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog):
|
||||
|
@ -490,10 +597,14 @@ class TestUpdater:
|
|||
await asyncio.sleep(0.01)
|
||||
raise TypeError("Invalid Data")
|
||||
|
||||
if not updates.empty():
|
||||
next_update = await updates.get()
|
||||
updates.task_done()
|
||||
return [next_update]
|
||||
|
||||
await asyncio.sleep(0)
|
||||
return []
|
||||
|
||||
orig_del_webhook = updater.bot.delete_webhook
|
||||
|
||||
async def delete_webhook(*args, **kwargs):
|
||||
|
|
Loading…
Reference in a new issue