mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-22 14:35:00 +01:00
Make Updater.stop
Independent of CancelledError
(#4126)
This commit is contained in:
parent
bd9b0bd126
commit
20e0f87f6b
3 changed files with 75 additions and 37 deletions
|
@ -104,6 +104,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
"__lock",
|
"__lock",
|
||||||
"__polling_cleanup_cb",
|
"__polling_cleanup_cb",
|
||||||
"__polling_task",
|
"__polling_task",
|
||||||
|
"__polling_task_stop_event",
|
||||||
"_httpd",
|
"_httpd",
|
||||||
"_initialized",
|
"_initialized",
|
||||||
"_last_update_id",
|
"_last_update_id",
|
||||||
|
@ -126,6 +127,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
self._httpd: Optional[WebhookServer] = None
|
self._httpd: Optional[WebhookServer] = None
|
||||||
self.__lock = asyncio.Lock()
|
self.__lock = asyncio.Lock()
|
||||||
self.__polling_task: Optional[asyncio.Task] = None
|
self.__polling_task: Optional[asyncio.Task] = None
|
||||||
|
self.__polling_task_stop_event: asyncio.Event = asyncio.Event()
|
||||||
self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
|
self.__polling_cleanup_cb: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
|
||||||
|
|
||||||
async def __aenter__(self: _UpdaterType) -> _UpdaterType: # noqa: PYI019
|
async def __aenter__(self: _UpdaterType) -> _UpdaterType: # noqa: PYI019
|
||||||
|
@ -417,6 +419,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
on_err_cb=error_callback or default_error_callback,
|
on_err_cb=error_callback or default_error_callback,
|
||||||
description="getting Updates",
|
description="getting Updates",
|
||||||
interval=poll_interval,
|
interval=poll_interval,
|
||||||
|
stop_event=self.__polling_task_stop_event,
|
||||||
),
|
),
|
||||||
name="Updater:start_polling:polling_task",
|
name="Updater:start_polling:polling_task",
|
||||||
)
|
)
|
||||||
|
@ -693,6 +696,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
on_err_cb: Callable[[TelegramError], None],
|
on_err_cb: Callable[[TelegramError], None],
|
||||||
description: str,
|
description: str,
|
||||||
interval: float,
|
interval: float,
|
||||||
|
stop_event: Optional[asyncio.Event],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Perform a loop calling `action_cb`, retrying after network errors.
|
"""Perform a loop calling `action_cb`, retrying after network errors.
|
||||||
|
|
||||||
|
@ -706,39 +710,58 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
description (:obj:`str`): Description text to use for logs and exception raised.
|
description (:obj:`str`): Description text to use for logs and exception raised.
|
||||||
interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to
|
interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to
|
||||||
`action_cb`.
|
`action_cb`.
|
||||||
|
stop_event (:class:`asyncio.Event` | :obj:`None`): Event to wait on for stopping the
|
||||||
|
loop. Setting the event will make the loop exit even if `action_cb` is currently
|
||||||
|
running.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def do_action() -> bool:
|
||||||
|
if not stop_event:
|
||||||
|
return await action_cb()
|
||||||
|
|
||||||
|
action_cb_task = asyncio.create_task(action_cb())
|
||||||
|
stop_task = asyncio.create_task(stop_event.wait())
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
(action_cb_task, stop_task), return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
if stop_task in done:
|
||||||
|
_LOGGER.debug("Network loop retry %s was cancelled", description)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return action_cb_task.result()
|
||||||
|
|
||||||
_LOGGER.debug("Start network loop retry %s", description)
|
_LOGGER.debug("Start network loop retry %s", description)
|
||||||
cur_interval = interval
|
cur_interval = interval
|
||||||
try:
|
while self.running:
|
||||||
while self.running:
|
try:
|
||||||
try:
|
if not await do_action():
|
||||||
if not await action_cb():
|
break
|
||||||
break
|
except RetryAfter as exc:
|
||||||
except RetryAfter as exc:
|
_LOGGER.info("%s", exc)
|
||||||
_LOGGER.info("%s", exc)
|
cur_interval = 0.5 + exc.retry_after
|
||||||
cur_interval = 0.5 + exc.retry_after
|
except TimedOut as toe:
|
||||||
except TimedOut as toe:
|
_LOGGER.debug("Timed out %s: %s", description, toe)
|
||||||
_LOGGER.debug("Timed out %s: %s", description, toe)
|
# If failure is due to timeout, we should retry asap.
|
||||||
# If failure is due to timeout, we should retry asap.
|
cur_interval = 0
|
||||||
cur_interval = 0
|
except InvalidToken as pex:
|
||||||
except InvalidToken as pex:
|
_LOGGER.error("Invalid token; aborting")
|
||||||
_LOGGER.error("Invalid token; aborting")
|
raise pex
|
||||||
raise pex
|
except TelegramError as telegram_exc:
|
||||||
except TelegramError as telegram_exc:
|
_LOGGER.error("Error while %s: %s", description, telegram_exc)
|
||||||
_LOGGER.error("Error while %s: %s", description, telegram_exc)
|
on_err_cb(telegram_exc)
|
||||||
on_err_cb(telegram_exc)
|
|
||||||
|
|
||||||
# increase waiting times on subsequent errors up to 30secs
|
# increase waiting times on subsequent errors up to 30secs
|
||||||
cur_interval = 1 if cur_interval == 0 else min(30, 1.5 * cur_interval)
|
cur_interval = 1 if cur_interval == 0 else min(30, 1.5 * cur_interval)
|
||||||
else:
|
else:
|
||||||
cur_interval = interval
|
cur_interval = interval
|
||||||
|
|
||||||
if cur_interval:
|
if cur_interval:
|
||||||
await asyncio.sleep(cur_interval)
|
await asyncio.sleep(cur_interval)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
_LOGGER.debug("Network loop retry %s was cancelled", description)
|
|
||||||
|
|
||||||
async def _bootstrap(
|
async def _bootstrap(
|
||||||
self,
|
self,
|
||||||
|
@ -804,6 +827,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
bootstrap_on_err_cb,
|
bootstrap_on_err_cb,
|
||||||
"bootstrap del webhook",
|
"bootstrap del webhook",
|
||||||
bootstrap_interval,
|
bootstrap_interval,
|
||||||
|
stop_event=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset the retries counter for the next _network_loop_retry call
|
# Reset the retries counter for the next _network_loop_retry call
|
||||||
|
@ -817,6 +841,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
bootstrap_on_err_cb,
|
bootstrap_on_err_cb,
|
||||||
"bootstrap set webhook",
|
"bootstrap set webhook",
|
||||||
bootstrap_interval,
|
bootstrap_interval,
|
||||||
|
stop_event=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
|
@ -852,7 +877,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
"""Stops the polling task by awaiting it."""
|
"""Stops the polling task by awaiting it."""
|
||||||
if self.__polling_task:
|
if self.__polling_task:
|
||||||
_LOGGER.debug("Waiting background polling task to finish up.")
|
_LOGGER.debug("Waiting background polling task to finish up.")
|
||||||
self.__polling_task.cancel()
|
self.__polling_task_stop_event.set()
|
||||||
|
|
||||||
with contextlib.suppress(asyncio.CancelledError):
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await self.__polling_task
|
await self.__polling_task
|
||||||
|
@ -860,6 +885,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
# after start_polling(), but lets better be safe than sorry ...
|
# after start_polling(), but lets better be safe than sorry ...
|
||||||
|
|
||||||
self.__polling_task = None
|
self.__polling_task = None
|
||||||
|
self.__polling_task_stop_event.clear()
|
||||||
|
|
||||||
if self.__polling_cleanup_cb:
|
if self.__polling_cleanup_cb:
|
||||||
await self.__polling_cleanup_cb()
|
await self.__polling_cleanup_cb()
|
||||||
|
|
|
@ -1432,6 +1432,7 @@ class TestApplication:
|
||||||
)
|
)
|
||||||
def test_run_polling_basic(self, app, monkeypatch, caplog):
|
def test_run_polling_basic(self, app, monkeypatch, caplog):
|
||||||
exception_event = threading.Event()
|
exception_event = threading.Event()
|
||||||
|
exception_testing_done = threading.Event()
|
||||||
update_event = threading.Event()
|
update_event = threading.Event()
|
||||||
exception = TelegramError("This is a test error")
|
exception = TelegramError("This is a test error")
|
||||||
assertions = {}
|
assertions = {}
|
||||||
|
@ -1439,8 +1440,14 @@ class TestApplication:
|
||||||
async def get_updates(*args, **kwargs):
|
async def get_updates(*args, **kwargs):
|
||||||
if exception_event.is_set():
|
if exception_event.is_set():
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
# This makes sure that other coroutines have a chance of running as well
|
# This makes sure that other coroutines have a chance of running as well
|
||||||
await asyncio.sleep(0)
|
if exception_testing_done.is_set() and app.updater.running:
|
||||||
|
# the longer sleep makes sure that we can exit also while get_updates is running
|
||||||
|
await asyncio.sleep(20)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
update_event.set()
|
update_event.set()
|
||||||
return [self.message_update]
|
return [self.message_update]
|
||||||
|
|
||||||
|
@ -1466,10 +1473,12 @@ class TestApplication:
|
||||||
exception_event.set()
|
exception_event.set()
|
||||||
time.sleep(0.05)
|
time.sleep(0.05)
|
||||||
assertions["exception_handling"] = self.received == exception.message
|
assertions["exception_handling"] = self.received == exception.message
|
||||||
|
exception_testing_done.set()
|
||||||
|
|
||||||
# So that the get_updates call on shutdown doesn't fail
|
# So that the get_updates call on shutdown doesn't fail
|
||||||
exception_event.clear()
|
exception_event.clear()
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
os.kill(os.getpid(), signal.SIGINT)
|
os.kill(os.getpid(), signal.SIGINT)
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
|
@ -234,7 +234,7 @@ class TestUpdater:
|
||||||
updates.task_done()
|
updates.task_done()
|
||||||
return [next_update]
|
return [next_update]
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0.1)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
orig_del_webhook = updater.bot.delete_webhook
|
orig_del_webhook = updater.bot.delete_webhook
|
||||||
|
@ -520,10 +520,13 @@ class TestUpdater:
|
||||||
):
|
):
|
||||||
raise_exception = True
|
raise_exception = True
|
||||||
get_updates_event = asyncio.Event()
|
get_updates_event = asyncio.Event()
|
||||||
|
second_get_updates_event = asyncio.Event()
|
||||||
|
|
||||||
async def get_updates(*args, **kwargs):
|
async def get_updates(*args, **kwargs):
|
||||||
# So that the main task has a chance to be called
|
# So that the main task has a chance to be called
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
if get_updates_event.is_set():
|
||||||
|
second_get_updates_event.set()
|
||||||
|
|
||||||
if not raise_exception:
|
if not raise_exception:
|
||||||
return []
|
return []
|
||||||
|
@ -548,6 +551,9 @@ class TestUpdater:
|
||||||
|
|
||||||
# Also makes sure that the error handler was called
|
# Also makes sure that the error handler was called
|
||||||
await get_updates_event.wait()
|
await get_updates_event.wait()
|
||||||
|
# wait for get_updates to be called a second time - only now we can expect that
|
||||||
|
# all error handling for the previous call has finished
|
||||||
|
await second_get_updates_event.wait()
|
||||||
|
|
||||||
if callback_should_be_called:
|
if callback_should_be_called:
|
||||||
# Make sure that the error handler was called
|
# Make sure that the error handler was called
|
||||||
|
@ -588,17 +594,14 @@ class TestUpdater:
|
||||||
async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog):
|
async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog):
|
||||||
update_queue = asyncio.Queue()
|
update_queue = asyncio.Queue()
|
||||||
await update_queue.put(Update(update_id=1))
|
await update_queue.put(Update(update_id=1))
|
||||||
await update_queue.put(Update(update_id=2))
|
|
||||||
first_update_event = asyncio.Event()
|
first_update_event = asyncio.Event()
|
||||||
second_update_event = asyncio.Event()
|
second_update_event = asyncio.Event()
|
||||||
|
|
||||||
async def get_updates(*args, **kwargs):
|
async def get_updates(*args, **kwargs):
|
||||||
self.message_count = kwargs.get("offset")
|
self.message_count = kwargs.get("offset")
|
||||||
update = await update_queue.get()
|
update = await update_queue.get()
|
||||||
if update.update_id == 1:
|
first_update_event.set()
|
||||||
first_update_event.set()
|
await second_update_event.wait()
|
||||||
else:
|
|
||||||
await second_update_event.wait()
|
|
||||||
return [update]
|
return [update]
|
||||||
|
|
||||||
monkeypatch.setattr(updater.bot, "get_updates", get_updates)
|
monkeypatch.setattr(updater.bot, "get_updates", get_updates)
|
||||||
|
@ -611,8 +614,8 @@ class TestUpdater:
|
||||||
# Unfortunately we need to use the private attribute here to produce the problem
|
# Unfortunately we need to use the private attribute here to produce the problem
|
||||||
updater._running = False
|
updater._running = False
|
||||||
second_update_event.set()
|
second_update_event.set()
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
assert caplog.records
|
assert caplog.records
|
||||||
assert any(
|
assert any(
|
||||||
"Updater stopped unexpectedly." in record.getMessage()
|
"Updater stopped unexpectedly." in record.getMessage()
|
||||||
|
@ -621,7 +624,7 @@ class TestUpdater:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure that the update_id offset wasn't increased
|
# Make sure that the update_id offset wasn't increased
|
||||||
assert self.message_count == 2
|
assert self.message_count < 1
|
||||||
|
|
||||||
async def test_start_polling_not_running_after_failure(self, updater, monkeypatch):
|
async def test_start_polling_not_running_after_failure(self, updater, monkeypatch):
|
||||||
# Unfortunately we have to use some internal logic to trigger an exception
|
# Unfortunately we have to use some internal logic to trigger an exception
|
||||||
|
|
Loading…
Reference in a new issue