mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-22 14:35:00 +01:00
Make Updater.stop
Independent of CancelledError
(#4126)
This commit is contained in:
parent
bd9b0bd126
commit
20e0f87f6b
3 changed files with 75 additions and 37 deletions
|
@ -104,6 +104,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
"__lock",
|
||||
"__polling_cleanup_cb",
|
||||
"__polling_task",
|
||||
"__polling_task_stop_event",
|
||||
"_httpd",
|
||||
"_initialized",
|
||||
"_last_update_id",
|
||||
|
@ -126,6 +127,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
self._httpd: Optional[WebhookServer] = None
|
||||
self.__lock = asyncio.Lock()
|
||||
self.__polling_task: Optional[asyncio.Task] = None
|
||||
self.__polling_task_stop_event: asyncio.Event = asyncio.Event()
|
||||
self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
|
||||
|
||||
async def __aenter__(self: _UpdaterType) -> _UpdaterType: # noqa: PYI019
|
||||
|
@ -417,6 +419,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
on_err_cb=error_callback or default_error_callback,
|
||||
description="getting Updates",
|
||||
interval=poll_interval,
|
||||
stop_event=self.__polling_task_stop_event,
|
||||
),
|
||||
name="Updater:start_polling:polling_task",
|
||||
)
|
||||
|
@ -693,6 +696,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
on_err_cb: Callable[[TelegramError], None],
|
||||
description: str,
|
||||
interval: float,
|
||||
stop_event: Optional[asyncio.Event],
|
||||
) -> None:
|
||||
"""Perform a loop calling `action_cb`, retrying after network errors.
|
||||
|
||||
|
@ -706,14 +710,36 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
description (:obj:`str`): Description text to use for logs and exception raised.
|
||||
interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to
|
||||
`action_cb`.
|
||||
stop_event (:class:`asyncio.Event` | :obj:`None`): Event to wait on for stopping the
|
||||
loop. Setting the event will make the loop exit even if `action_cb` is currently
|
||||
running.
|
||||
|
||||
"""
|
||||
|
||||
async def do_action() -> bool:
|
||||
if not stop_event:
|
||||
return await action_cb()
|
||||
|
||||
action_cb_task = asyncio.create_task(action_cb())
|
||||
stop_task = asyncio.create_task(stop_event.wait())
|
||||
done, pending = await asyncio.wait(
|
||||
(action_cb_task, stop_task), return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
if stop_task in done:
|
||||
_LOGGER.debug("Network loop retry %s was cancelled", description)
|
||||
return False
|
||||
|
||||
return action_cb_task.result()
|
||||
|
||||
_LOGGER.debug("Start network loop retry %s", description)
|
||||
cur_interval = interval
|
||||
try:
|
||||
while self.running:
|
||||
try:
|
||||
if not await action_cb():
|
||||
if not await do_action():
|
||||
break
|
||||
except RetryAfter as exc:
|
||||
_LOGGER.info("%s", exc)
|
||||
|
@ -737,9 +763,6 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
if cur_interval:
|
||||
await asyncio.sleep(cur_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
_LOGGER.debug("Network loop retry %s was cancelled", description)
|
||||
|
||||
async def _bootstrap(
|
||||
self,
|
||||
max_retries: int,
|
||||
|
@ -804,6 +827,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
bootstrap_on_err_cb,
|
||||
"bootstrap del webhook",
|
||||
bootstrap_interval,
|
||||
stop_event=None,
|
||||
)
|
||||
|
||||
# Reset the retries counter for the next _network_loop_retry call
|
||||
|
@ -817,6 +841,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
bootstrap_on_err_cb,
|
||||
"bootstrap set webhook",
|
||||
bootstrap_interval,
|
||||
stop_event=None,
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
|
@ -852,7 +877,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
"""Stops the polling task by awaiting it."""
|
||||
if self.__polling_task:
|
||||
_LOGGER.debug("Waiting background polling task to finish up.")
|
||||
self.__polling_task.cancel()
|
||||
self.__polling_task_stop_event.set()
|
||||
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self.__polling_task
|
||||
|
@ -860,6 +885,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
# after start_polling(), but lets better be safe than sorry ...
|
||||
|
||||
self.__polling_task = None
|
||||
self.__polling_task_stop_event.clear()
|
||||
|
||||
if self.__polling_cleanup_cb:
|
||||
await self.__polling_cleanup_cb()
|
||||
|
|
|
@ -1432,6 +1432,7 @@ class TestApplication:
|
|||
)
|
||||
def test_run_polling_basic(self, app, monkeypatch, caplog):
|
||||
exception_event = threading.Event()
|
||||
exception_testing_done = threading.Event()
|
||||
update_event = threading.Event()
|
||||
exception = TelegramError("This is a test error")
|
||||
assertions = {}
|
||||
|
@ -1439,8 +1440,14 @@ class TestApplication:
|
|||
async def get_updates(*args, **kwargs):
|
||||
if exception_event.is_set():
|
||||
raise exception
|
||||
|
||||
# This makes sure that other coroutines have a chance of running as well
|
||||
await asyncio.sleep(0)
|
||||
if exception_testing_done.is_set() and app.updater.running:
|
||||
# the longer sleep makes sure that we can exit also while get_updates is running
|
||||
await asyncio.sleep(20)
|
||||
else:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
update_event.set()
|
||||
return [self.message_update]
|
||||
|
||||
|
@ -1466,10 +1473,12 @@ class TestApplication:
|
|||
exception_event.set()
|
||||
time.sleep(0.05)
|
||||
assertions["exception_handling"] = self.received == exception.message
|
||||
exception_testing_done.set()
|
||||
|
||||
# So that the get_updates call on shutdown doesn't fail
|
||||
exception_event.clear()
|
||||
|
||||
time.sleep(1)
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
time.sleep(0.1)
|
||||
|
||||
|
|
|
@ -234,7 +234,7 @@ class TestUpdater:
|
|||
updates.task_done()
|
||||
return [next_update]
|
||||
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0.1)
|
||||
return []
|
||||
|
||||
orig_del_webhook = updater.bot.delete_webhook
|
||||
|
@ -520,10 +520,13 @@ class TestUpdater:
|
|||
):
|
||||
raise_exception = True
|
||||
get_updates_event = asyncio.Event()
|
||||
second_get_updates_event = asyncio.Event()
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
# So that the main task has a chance to be called
|
||||
await asyncio.sleep(0)
|
||||
if get_updates_event.is_set():
|
||||
second_get_updates_event.set()
|
||||
|
||||
if not raise_exception:
|
||||
return []
|
||||
|
@ -548,6 +551,9 @@ class TestUpdater:
|
|||
|
||||
# Also makes sure that the error handler was called
|
||||
await get_updates_event.wait()
|
||||
# wait for get_updates to be called a second time - only now we can expect that
|
||||
# all error handling for the previous call has finished
|
||||
await second_get_updates_event.wait()
|
||||
|
||||
if callback_should_be_called:
|
||||
# Make sure that the error handler was called
|
||||
|
@ -588,16 +594,13 @@ class TestUpdater:
|
|||
async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog):
|
||||
update_queue = asyncio.Queue()
|
||||
await update_queue.put(Update(update_id=1))
|
||||
await update_queue.put(Update(update_id=2))
|
||||
first_update_event = asyncio.Event()
|
||||
second_update_event = asyncio.Event()
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
self.message_count = kwargs.get("offset")
|
||||
update = await update_queue.get()
|
||||
if update.update_id == 1:
|
||||
first_update_event.set()
|
||||
else:
|
||||
await second_update_event.wait()
|
||||
return [update]
|
||||
|
||||
|
@ -611,8 +614,8 @@ class TestUpdater:
|
|||
# Unfortunately we need to use the private attribute here to produce the problem
|
||||
updater._running = False
|
||||
second_update_event.set()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
assert caplog.records
|
||||
assert any(
|
||||
"Updater stopped unexpectedly." in record.getMessage()
|
||||
|
@ -621,7 +624,7 @@ class TestUpdater:
|
|||
)
|
||||
|
||||
# Make sure that the update_id offset wasn't increased
|
||||
assert self.message_count == 2
|
||||
assert self.message_count < 1
|
||||
|
||||
async def test_start_polling_not_running_after_failure(self, updater, monkeypatch):
|
||||
# Unfortunately we have to use some internal logic to trigger an exception
|
||||
|
|
Loading…
Reference in a new issue