Make Updater.stop Independent of CancelledError (#4126)

This commit is contained in:
Bibo-Joshi 2024-03-03 19:22:42 +01:00 committed by GitHub
parent bd9b0bd126
commit 20e0f87f6b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 75 additions and 37 deletions

View file

@ -104,6 +104,7 @@ class Updater(AsyncContextManager["Updater"]):
"__lock", "__lock",
"__polling_cleanup_cb", "__polling_cleanup_cb",
"__polling_task", "__polling_task",
"__polling_task_stop_event",
"_httpd", "_httpd",
"_initialized", "_initialized",
"_last_update_id", "_last_update_id",
@ -126,6 +127,7 @@ class Updater(AsyncContextManager["Updater"]):
self._httpd: Optional[WebhookServer] = None self._httpd: Optional[WebhookServer] = None
self.__lock = asyncio.Lock() self.__lock = asyncio.Lock()
self.__polling_task: Optional[asyncio.Task] = None self.__polling_task: Optional[asyncio.Task] = None
self.__polling_task_stop_event: asyncio.Event = asyncio.Event()
self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
async def __aenter__(self: _UpdaterType) -> _UpdaterType: # noqa: PYI019 async def __aenter__(self: _UpdaterType) -> _UpdaterType: # noqa: PYI019
@ -417,6 +419,7 @@ class Updater(AsyncContextManager["Updater"]):
on_err_cb=error_callback or default_error_callback, on_err_cb=error_callback or default_error_callback,
description="getting Updates", description="getting Updates",
interval=poll_interval, interval=poll_interval,
stop_event=self.__polling_task_stop_event,
), ),
name="Updater:start_polling:polling_task", name="Updater:start_polling:polling_task",
) )
@ -693,6 +696,7 @@ class Updater(AsyncContextManager["Updater"]):
on_err_cb: Callable[[TelegramError], None], on_err_cb: Callable[[TelegramError], None],
description: str, description: str,
interval: float, interval: float,
stop_event: Optional[asyncio.Event],
) -> None: ) -> None:
"""Perform a loop calling `action_cb`, retrying after network errors. """Perform a loop calling `action_cb`, retrying after network errors.
@ -706,39 +710,58 @@ class Updater(AsyncContextManager["Updater"]):
description (:obj:`str`): Description text to use for logs and exception raised. description (:obj:`str`): Description text to use for logs and exception raised.
interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to
`action_cb`. `action_cb`.
stop_event (:class:`asyncio.Event` | :obj:`None`): Event to wait on for stopping the
loop. Setting the event will make the loop exit even if `action_cb` is currently
running.
""" """
async def do_action() -> bool:
    """Run ``action_cb`` once, racing it against ``stop_event`` if one was given.

    Returns:
        :obj:`bool`: The result of ``action_cb``, or :obj:`False` if the stop event
        was set before (or at the same time as) ``action_cb`` finished.
    """
    # Without a stop event there is nothing to race against.
    if not stop_event:
        return await action_cb()

    action_task = asyncio.create_task(action_cb())
    stop_waiter = asyncio.create_task(stop_event.wait())
    finished, unfinished = await asyncio.wait(
        (action_task, stop_waiter), return_when=asyncio.FIRST_COMPLETED
    )

    # Whichever task lost the race is only asked to cancel; it is not awaited here.
    with contextlib.suppress(asyncio.CancelledError):
        for leftover in unfinished:
            leftover.cancel()

    if stop_waiter not in finished:
        # The action completed first: hand back its result (or re-raise its exception).
        return action_task.result()

    _LOGGER.debug("Network loop retry %s was cancelled", description)
    return False
_LOGGER.debug("Start network loop retry %s", description) _LOGGER.debug("Start network loop retry %s", description)
cur_interval = interval cur_interval = interval
try: while self.running:
while self.running: try:
try: if not await do_action():
if not await action_cb(): break
break except RetryAfter as exc:
except RetryAfter as exc: _LOGGER.info("%s", exc)
_LOGGER.info("%s", exc) cur_interval = 0.5 + exc.retry_after
cur_interval = 0.5 + exc.retry_after except TimedOut as toe:
except TimedOut as toe: _LOGGER.debug("Timed out %s: %s", description, toe)
_LOGGER.debug("Timed out %s: %s", description, toe) # If failure is due to timeout, we should retry asap.
# If failure is due to timeout, we should retry asap. cur_interval = 0
cur_interval = 0 except InvalidToken as pex:
except InvalidToken as pex: _LOGGER.error("Invalid token; aborting")
_LOGGER.error("Invalid token; aborting") raise pex
raise pex except TelegramError as telegram_exc:
except TelegramError as telegram_exc: _LOGGER.error("Error while %s: %s", description, telegram_exc)
_LOGGER.error("Error while %s: %s", description, telegram_exc) on_err_cb(telegram_exc)
on_err_cb(telegram_exc)
# increase waiting times on subsequent errors up to 30secs # increase waiting times on subsequent errors up to 30secs
cur_interval = 1 if cur_interval == 0 else min(30, 1.5 * cur_interval) cur_interval = 1 if cur_interval == 0 else min(30, 1.5 * cur_interval)
else: else:
cur_interval = interval cur_interval = interval
if cur_interval: if cur_interval:
await asyncio.sleep(cur_interval) await asyncio.sleep(cur_interval)
except asyncio.CancelledError:
_LOGGER.debug("Network loop retry %s was cancelled", description)
async def _bootstrap( async def _bootstrap(
self, self,
@ -804,6 +827,7 @@ class Updater(AsyncContextManager["Updater"]):
bootstrap_on_err_cb, bootstrap_on_err_cb,
"bootstrap del webhook", "bootstrap del webhook",
bootstrap_interval, bootstrap_interval,
stop_event=None,
) )
# Reset the retries counter for the next _network_loop_retry call # Reset the retries counter for the next _network_loop_retry call
@ -817,6 +841,7 @@ class Updater(AsyncContextManager["Updater"]):
bootstrap_on_err_cb, bootstrap_on_err_cb,
"bootstrap set webhook", "bootstrap set webhook",
bootstrap_interval, bootstrap_interval,
stop_event=None,
) )
async def stop(self) -> None: async def stop(self) -> None:
@ -852,7 +877,7 @@ class Updater(AsyncContextManager["Updater"]):
"""Stops the polling task by awaiting it.""" """Stops the polling task by awaiting it."""
if self.__polling_task: if self.__polling_task:
_LOGGER.debug("Waiting background polling task to finish up.") _LOGGER.debug("Waiting background polling task to finish up.")
self.__polling_task.cancel() self.__polling_task_stop_event.set()
with contextlib.suppress(asyncio.CancelledError): with contextlib.suppress(asyncio.CancelledError):
await self.__polling_task await self.__polling_task
@ -860,6 +885,7 @@ class Updater(AsyncContextManager["Updater"]):
# after start_polling(), but lets better be safe than sorry ... # after start_polling(), but lets better be safe than sorry ...
self.__polling_task = None self.__polling_task = None
self.__polling_task_stop_event.clear()
if self.__polling_cleanup_cb: if self.__polling_cleanup_cb:
await self.__polling_cleanup_cb() await self.__polling_cleanup_cb()

View file

@ -1432,6 +1432,7 @@ class TestApplication:
) )
def test_run_polling_basic(self, app, monkeypatch, caplog): def test_run_polling_basic(self, app, monkeypatch, caplog):
exception_event = threading.Event() exception_event = threading.Event()
exception_testing_done = threading.Event()
update_event = threading.Event() update_event = threading.Event()
exception = TelegramError("This is a test error") exception = TelegramError("This is a test error")
assertions = {} assertions = {}
@ -1439,8 +1440,14 @@ class TestApplication:
async def get_updates(*args, **kwargs): async def get_updates(*args, **kwargs):
if exception_event.is_set(): if exception_event.is_set():
raise exception raise exception
# This makes sure that other coroutines have a chance of running as well # This makes sure that other coroutines have a chance of running as well
await asyncio.sleep(0) if exception_testing_done.is_set() and app.updater.running:
# the longer sleep makes sure that we can exit also while get_updates is running
await asyncio.sleep(20)
else:
await asyncio.sleep(0.01)
update_event.set() update_event.set()
return [self.message_update] return [self.message_update]
@ -1466,10 +1473,12 @@ class TestApplication:
exception_event.set() exception_event.set()
time.sleep(0.05) time.sleep(0.05)
assertions["exception_handling"] = self.received == exception.message assertions["exception_handling"] = self.received == exception.message
exception_testing_done.set()
# So that the get_updates call on shutdown doesn't fail # So that the get_updates call on shutdown doesn't fail
exception_event.clear() exception_event.clear()
time.sleep(1)
os.kill(os.getpid(), signal.SIGINT) os.kill(os.getpid(), signal.SIGINT)
time.sleep(0.1) time.sleep(0.1)

View file

@ -234,7 +234,7 @@ class TestUpdater:
updates.task_done() updates.task_done()
return [next_update] return [next_update]
await asyncio.sleep(0) await asyncio.sleep(0.1)
return [] return []
orig_del_webhook = updater.bot.delete_webhook orig_del_webhook = updater.bot.delete_webhook
@ -520,10 +520,13 @@ class TestUpdater:
): ):
raise_exception = True raise_exception = True
get_updates_event = asyncio.Event() get_updates_event = asyncio.Event()
second_get_updates_event = asyncio.Event()
async def get_updates(*args, **kwargs): async def get_updates(*args, **kwargs):
# So that the main task has a chance to be called # So that the main task has a chance to be called
await asyncio.sleep(0) await asyncio.sleep(0)
if get_updates_event.is_set():
second_get_updates_event.set()
if not raise_exception: if not raise_exception:
return [] return []
@ -548,6 +551,9 @@ class TestUpdater:
# Also makes sure that the error handler was called # Also makes sure that the error handler was called
await get_updates_event.wait() await get_updates_event.wait()
# wait for get_updates to be called a second time - only now we can expect that
# all error handling for the previous call has finished
await second_get_updates_event.wait()
if callback_should_be_called: if callback_should_be_called:
# Make sure that the error handler was called # Make sure that the error handler was called
@ -588,17 +594,14 @@ class TestUpdater:
async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog): async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog):
update_queue = asyncio.Queue() update_queue = asyncio.Queue()
await update_queue.put(Update(update_id=1)) await update_queue.put(Update(update_id=1))
await update_queue.put(Update(update_id=2))
first_update_event = asyncio.Event() first_update_event = asyncio.Event()
second_update_event = asyncio.Event() second_update_event = asyncio.Event()
async def get_updates(*args, **kwargs): async def get_updates(*args, **kwargs):
self.message_count = kwargs.get("offset") self.message_count = kwargs.get("offset")
update = await update_queue.get() update = await update_queue.get()
if update.update_id == 1: first_update_event.set()
first_update_event.set() await second_update_event.wait()
else:
await second_update_event.wait()
return [update] return [update]
monkeypatch.setattr(updater.bot, "get_updates", get_updates) monkeypatch.setattr(updater.bot, "get_updates", get_updates)
@ -611,8 +614,8 @@ class TestUpdater:
# Unfortunately we need to use the private attribute here to produce the problem # Unfortunately we need to use the private attribute here to produce the problem
updater._running = False updater._running = False
second_update_event.set() second_update_event.set()
await asyncio.sleep(1)
await asyncio.sleep(0.1)
assert caplog.records assert caplog.records
assert any( assert any(
"Updater stopped unexpectedly." in record.getMessage() "Updater stopped unexpectedly." in record.getMessage()
@ -621,7 +624,7 @@ class TestUpdater:
) )
# Make sure that the update_id offset wasn't increased # Make sure that the update_id offset wasn't increased
assert self.message_count == 2 assert self.message_count < 1
async def test_start_polling_not_running_after_failure(self, updater, monkeypatch): async def test_start_polling_not_running_after_failure(self, updater, monkeypatch):
# Unfortunately we have to use some internal logic to trigger an exception # Unfortunately we have to use some internal logic to trigger an exception