Make Updater.stop Independent of CancelledError (#4126)

This commit is contained in:
Bibo-Joshi 2024-03-03 19:22:42 +01:00 committed by GitHub
parent bd9b0bd126
commit 20e0f87f6b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 75 additions and 37 deletions

View file

@ -104,6 +104,7 @@ class Updater(AsyncContextManager["Updater"]):
"__lock",
"__polling_cleanup_cb",
"__polling_task",
"__polling_task_stop_event",
"_httpd",
"_initialized",
"_last_update_id",
@ -126,6 +127,7 @@ class Updater(AsyncContextManager["Updater"]):
self._httpd: Optional[WebhookServer] = None
self.__lock = asyncio.Lock()
self.__polling_task: Optional[asyncio.Task] = None
self.__polling_task_stop_event: asyncio.Event = asyncio.Event()
self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
async def __aenter__(self: _UpdaterType) -> _UpdaterType: # noqa: PYI019
@ -417,6 +419,7 @@ class Updater(AsyncContextManager["Updater"]):
on_err_cb=error_callback or default_error_callback,
description="getting Updates",
interval=poll_interval,
stop_event=self.__polling_task_stop_event,
),
name="Updater:start_polling:polling_task",
)
@ -693,6 +696,7 @@ class Updater(AsyncContextManager["Updater"]):
on_err_cb: Callable[[TelegramError], None],
description: str,
interval: float,
stop_event: Optional[asyncio.Event],
) -> None:
"""Perform a loop calling `action_cb`, retrying after network errors.
@ -706,14 +710,36 @@ class Updater(AsyncContextManager["Updater"]):
description (:obj:`str`): Description text to use for logs and exception raised.
interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to
`action_cb`.
stop_event (:class:`asyncio.Event` | :obj:`None`): Event to wait on for stopping the
loop. Setting the event will make the loop exit even if `action_cb` is currently
running.
"""
async def do_action() -> bool:
if not stop_event:
return await action_cb()
action_cb_task = asyncio.create_task(action_cb())
stop_task = asyncio.create_task(stop_event.wait())
done, pending = await asyncio.wait(
(action_cb_task, stop_task), return_when=asyncio.FIRST_COMPLETED
)
with contextlib.suppress(asyncio.CancelledError):
for task in pending:
task.cancel()
if stop_task in done:
_LOGGER.debug("Network loop retry %s was cancelled", description)
return False
return action_cb_task.result()
_LOGGER.debug("Start network loop retry %s", description)
cur_interval = interval
try:
while self.running:
try:
if not await action_cb():
if not await do_action():
break
except RetryAfter as exc:
_LOGGER.info("%s", exc)
@ -737,9 +763,6 @@ class Updater(AsyncContextManager["Updater"]):
if cur_interval:
await asyncio.sleep(cur_interval)
except asyncio.CancelledError:
_LOGGER.debug("Network loop retry %s was cancelled", description)
async def _bootstrap(
self,
max_retries: int,
@ -804,6 +827,7 @@ class Updater(AsyncContextManager["Updater"]):
bootstrap_on_err_cb,
"bootstrap del webhook",
bootstrap_interval,
stop_event=None,
)
# Reset the retries counter for the next _network_loop_retry call
@ -817,6 +841,7 @@ class Updater(AsyncContextManager["Updater"]):
bootstrap_on_err_cb,
"bootstrap set webhook",
bootstrap_interval,
stop_event=None,
)
async def stop(self) -> None:
@ -852,7 +877,7 @@ class Updater(AsyncContextManager["Updater"]):
"""Stops the polling task by awaiting it."""
if self.__polling_task:
_LOGGER.debug("Waiting background polling task to finish up.")
self.__polling_task.cancel()
self.__polling_task_stop_event.set()
with contextlib.suppress(asyncio.CancelledError):
await self.__polling_task
@ -860,6 +885,7 @@ class Updater(AsyncContextManager["Updater"]):
# after start_polling(), but lets better be safe than sorry ...
self.__polling_task = None
self.__polling_task_stop_event.clear()
if self.__polling_cleanup_cb:
await self.__polling_cleanup_cb()

View file

@ -1432,6 +1432,7 @@ class TestApplication:
)
def test_run_polling_basic(self, app, monkeypatch, caplog):
exception_event = threading.Event()
exception_testing_done = threading.Event()
update_event = threading.Event()
exception = TelegramError("This is a test error")
assertions = {}
@ -1439,8 +1440,14 @@ class TestApplication:
async def get_updates(*args, **kwargs):
if exception_event.is_set():
raise exception
# This makes sure that other coroutines have a chance of running as well
await asyncio.sleep(0)
if exception_testing_done.is_set() and app.updater.running:
# the longer sleep makes sure that we can exit also while get_updates is running
await asyncio.sleep(20)
else:
await asyncio.sleep(0.01)
update_event.set()
return [self.message_update]
@ -1466,10 +1473,12 @@ class TestApplication:
exception_event.set()
time.sleep(0.05)
assertions["exception_handling"] = self.received == exception.message
exception_testing_done.set()
# So that the get_updates call on shutdown doesn't fail
exception_event.clear()
time.sleep(1)
os.kill(os.getpid(), signal.SIGINT)
time.sleep(0.1)

View file

@ -234,7 +234,7 @@ class TestUpdater:
updates.task_done()
return [next_update]
await asyncio.sleep(0)
await asyncio.sleep(0.1)
return []
orig_del_webhook = updater.bot.delete_webhook
@ -520,10 +520,13 @@ class TestUpdater:
):
raise_exception = True
get_updates_event = asyncio.Event()
second_get_updates_event = asyncio.Event()
async def get_updates(*args, **kwargs):
# So that the main task has a chance to be called
await asyncio.sleep(0)
if get_updates_event.is_set():
second_get_updates_event.set()
if not raise_exception:
return []
@ -548,6 +551,9 @@ class TestUpdater:
# Also makes sure that the error handler was called
await get_updates_event.wait()
# wait for get_updates to be called a second time - only now we can expect that
# all error handling for the previous call has finished
await second_get_updates_event.wait()
if callback_should_be_called:
# Make sure that the error handler was called
@ -588,16 +594,13 @@ class TestUpdater:
async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog):
update_queue = asyncio.Queue()
await update_queue.put(Update(update_id=1))
await update_queue.put(Update(update_id=2))
first_update_event = asyncio.Event()
second_update_event = asyncio.Event()
async def get_updates(*args, **kwargs):
self.message_count = kwargs.get("offset")
update = await update_queue.get()
if update.update_id == 1:
first_update_event.set()
else:
await second_update_event.wait()
return [update]
@ -611,8 +614,8 @@ class TestUpdater:
# Unfortunately we need to use the private attribute here to produce the problem
updater._running = False
second_update_event.set()
await asyncio.sleep(1)
await asyncio.sleep(0.1)
assert caplog.records
assert any(
"Updater stopped unexpectedly." in record.getMessage()
@ -621,7 +624,7 @@ class TestUpdater:
)
# Make sure that the update_id offset wasn't increased
assert self.message_count == 2
assert self.message_count < 1
async def test_start_polling_not_running_after_failure(self, updater, monkeypatch):
# Unfortunately we have to use some internal logic to trigger an exception