Shield Update Fetcher Task in Application.start (#3657)

This commit is contained in:
Bibo-Joshi 2023-05-06 21:10:12 +02:00 committed by GitHub
parent 87a6890900
commit 450dc2115c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 93 additions and 14 deletions

View file

@ -1046,24 +1046,33 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
async def _update_fetcher(self) -> None:
# Continuously fetch updates from the queue. Exit only once the signal object is found.
while True:
update = await self.update_queue.get()
try:
update = await self.update_queue.get()
if update is _STOP_SIGNAL:
_LOGGER.debug("Dropping pending updates")
while not self.update_queue.empty():
if update is _STOP_SIGNAL:
_LOGGER.debug("Dropping pending updates")
while not self.update_queue.empty():
self.update_queue.task_done()
# For the _STOP_SIGNAL
self.update_queue.task_done()
return
# For the _STOP_SIGNAL
self.update_queue.task_done()
return
_LOGGER.debug("Processing update %s", update)
_LOGGER.debug("Processing update %s", update)
if self._concurrent_updates:
# We don't await the below because it has to be run concurrently
self.create_task(self.__process_update_wrapper(update), update=update)
else:
await self.__process_update_wrapper(update)
if self._concurrent_updates:
# We don't await the below because it has to be run concurrently
self.create_task(self.__process_update_wrapper(update), update=update)
else:
await self.__process_update_wrapper(update)
except asyncio.CancelledError:
# This may happen if the application is manually run via application.start() and
# then a KeyboardInterrupt is sent. We must prevent this loop to die since
# application.stop() will wait for it's clean shutdown.
_LOGGER.warning(
"Fetching updates got a asyncio.CancelledError. Ignoring as this task may only"
"be closed via `Application.stop`."
)
async def __process_update_wrapper(self, update: object) -> None:
async with self._concurrent_updates_sem:

View file

@ -1975,6 +1975,76 @@ class TestApplication:
assert set(self.received.keys()) == set(expected.keys())
assert self.received == expected
@pytest.mark.skipif(
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
)
async def test_cancellation_error_does_not_stop_polling(
self, one_time_bot, monkeypatch, caplog
):
"""
Ensures that hitting CTRL+C while polling *without* run_polling doesn't kill
the update_fetcher loop such that a shutdown is still possible.
This test is far from perfect, but it's the closest we can come with sane effort.
"""
async def get_updates(*args, **kwargs):
await asyncio.sleep(0)
return [None]
monkeypatch.setattr(one_time_bot, "get_updates", get_updates)
app = ApplicationBuilder().bot(one_time_bot).build()
original_get = app.update_queue.get
raise_cancelled_error = threading.Event()
async def get(*arg, **kwargs):
await asyncio.sleep(0.05)
if raise_cancelled_error.is_set():
raise_cancelled_error.clear()
raise asyncio.CancelledError("Mocked CancelledError")
return await original_get(*arg, **kwargs)
monkeypatch.setattr(app.update_queue, "get", get)
def thread_target():
waited = 0
while not app.running:
time.sleep(0.05)
waited += 0.05
if waited > 5:
pytest.fail("App apparently won't start")
time.sleep(0.1)
raise_cancelled_error.set()
async with app:
with caplog.at_level(logging.WARNING):
thread = Thread(target=thread_target)
await app.start()
thread.start()
assert thread.is_alive()
raise_cancelled_error.wait()
# The exit should have been caught and the app should still be running
assert not thread.is_alive()
assert app.running
# Explicit shutdown is required
await app.stop()
thread.join()
assert not thread.is_alive()
assert not app.running
# Make sure that we were warned about the necessity of a manual shutdown
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.name == "telegram.ext.Application"
assert record.getMessage().startswith(
"Fetching updates got a asyncio.CancelledError. Ignoring"
)
def test_run_without_updater(self, one_time_bot):
app = ApplicationBuilder().bot(one_time_bot).updater(None).build()