Add Application.post_shutdown (#3126)

Co-authored-by: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com>
Co-authored-by: Harshil <37377066+harshil21@users.noreply.github.com>
This commit is contained in:
Alex 2022-07-03 15:22:50 +02:00 committed by GitHub
parent 2ecb8d5413
commit 1f0f6a8d3d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 186 additions and 2 deletions

View file

@ -182,6 +182,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
post_init (:term:`coroutine function`): Optional. A callback that will be executed by
:meth:`Application.run_polling` and :meth:`Application.run_webhook` after initializing
the application via :meth:`initialize`.
post_shutdown (:term:`coroutine function`): Optional. A callback that will be executed by
:meth:`Application.run_polling` and :meth:`Application.run_webhook` after shutting down
the application via :meth:`shutdown`.
"""
@ -213,6 +216,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
"job_queue",
"persistence",
"post_init",
"post_shutdown",
"update_queue",
"updater",
"user_data",
@ -231,6 +235,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
post_init: Optional[
Callable[["Application[BT, CCT, UD, CD, BD, JQ]"], Coroutine[Any, Any, None]]
],
post_shutdown: Optional[
Callable[["Application[BT, CCT, UD, CD, BD, JQ]"], Coroutine[Any, Any, None]]
],
):
if not was_called_by(
inspect.currentframe(), Path(__file__).parent.resolve() / "_applicationbuilder.py"
@ -248,6 +255,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
self.handlers: Dict[int, List[BaseHandler]] = {}
self.error_handlers: Dict[Callable, Union[bool, DefaultValue]] = {}
self.post_init = post_init
self.post_shutdown = post_shutdown
if isinstance(concurrent_updates, int) and concurrent_updates < 0:
raise ValueError("`concurrent_updates` must be a non-negative integer!")
@ -362,6 +370,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
* :attr:`persistence` by calling :meth:`update_persistence` and
:meth:`BasePersistence.flush`
Does *not* call :attr:`post_shutdown` - that is only done by :meth:`run_polling` and
:meth:`run_webhook`.
.. seealso::
:meth:`initialize`
@ -573,6 +584,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
If :attr:`post_init` is set, it will be called between :meth:`initialize` and
:meth:`telegram.ext.Updater.start_polling`.
If :attr:`post_shutdown` is set, it will be called after both :meth:`shutdown`
and :meth:`telegram.ext.Updater.shutdown`.
.. seealso::
:meth:`initialize`, :meth:`start`, :meth:`stop`, :meth:`shutdown`
:meth:`telegram.ext.Updater.start_polling`, :meth:`run_webhook`
@ -683,6 +697,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
If :attr:`post_init` is set, it will be called between :meth:`initialize` and
:meth:`telegram.ext.Updater.start_webhook`.
If :attr:`post_shutdown` is set, it will be called after both :meth:`shutdown`
and :meth:`telegram.ext.Updater.shutdown`.
.. seealso::
:meth:`initialize`, :meth:`start`, :meth:`stop`, :meth:`shutdown`
:meth:`telegram.ext.Updater.start_webhook`, :meth:`run_polling`
@ -813,7 +830,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
if self.running:
loop.run_until_complete(self.stop())
loop.run_until_complete(self.shutdown())
loop.run_until_complete(self.updater.shutdown()) # type: ignore[union-attr]
if self.post_shutdown:
loop.run_until_complete(self.post_shutdown(self))
finally:
if close_loop:
loop.close()

View file

@ -135,6 +135,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_persistence",
"_pool_timeout",
"_post_init",
"_post_shutdown",
"_private_key",
"_private_key_password",
"_proxy_url",
@ -178,6 +179,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
self._concurrent_updates: DVInput[Union[int, bool]] = DEFAULT_FALSE
self._updater: ODVInput[Updater] = DEFAULT_NONE
self._post_init: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None
self._post_shutdown: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None
def _build_request(self, get_updates: bool) -> BaseRequest:
prefix = "_get_updates_" if get_updates else "_"
@ -271,6 +273,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
persistence=persistence,
context_types=DefaultValue.get_value(self._context_types),
post_init=self._post_init,
post_shutdown=self._post_shutdown,
**self._application_kwargs, # For custom Application subclasses
)
@ -934,6 +937,42 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
self._post_init = post_init
return self
def post_shutdown(
self: BuilderType, post_shutdown: Callable[[Application], Coroutine[Any, Any, None]]
) -> BuilderType:
"""
Sets a callback to be executed by :meth:`Application.run_polling` and
:meth:`Application.run_webhook` *after* executing :meth:`Updater.shutdown`
and :meth:`Application.shutdown`.
Tip:
This can be used for custom shutdown logic that requires to await coroutines, e.g.
closing a database connection
Example:
.. code::
async def post_shutdown(application: Application) -> None:
await application.bot_data['database'].close()
application = Application.builder()
.token("TOKEN")
.post_shutdown(post_shutdown)
.build()
Args:
post_shutdown (:term:`coroutine function`): The custom callback. Must be a
:term:`coroutine function` and must accept exactly one positional argument, which
is the :class:`~telegram.ext.Application`::
async def post_shutdown(application: Application) -> None:
Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._post_shutdown = post_shutdown
return self
InitApplicationBuilder = ( # This is defined all the way down here so that its type is inferred
ApplicationBuilder[ # by Pylance correctly.

View file

@ -130,6 +130,7 @@ class TestApplication:
updater=updater,
concurrent_updates=False,
post_init=None,
post_shutdown=None,
)
assert len(recwarn) == 1
assert (
@ -152,6 +153,9 @@ class TestApplication:
async def post_init(application: Application) -> None:
pass
async def post_shutdown(application: Application) -> None:
pass
app = Application(
bot=bot,
update_queue=update_queue,
@ -161,6 +165,7 @@ class TestApplication:
updater=updater,
concurrent_updates=concurrent_updates,
post_init=post_init,
post_shutdown=post_shutdown,
)
assert app.bot is bot
assert app.update_queue is update_queue
@ -172,6 +177,7 @@ class TestApplication:
assert app.bot is updater.bot
assert app.concurrent_updates == expected
assert app.post_init is post_init
assert app.post_shutdown is post_shutdown
# These should be done by the builder
assert app.persistence.bot is None
@ -192,6 +198,7 @@ class TestApplication:
updater=updater,
concurrent_updates=-1,
post_init=None,
post_shutdown=None,
)
def test_custom_context_init(self, bot):
@ -1433,6 +1440,52 @@ class TestApplication:
thread.join()
assert events == ["init", "post_init", "start_polling"], "Wrong order of events detected!"
@pytest.mark.skipif(
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
)
def test_run_polling_post_shutdown(self, bot, monkeypatch):
events = []
async def get_updates(*args, **kwargs):
# This makes sure that other coroutines have a chance of running as well
await asyncio.sleep(0)
return []
def thread_target():
waited = 0
while not app.running:
time.sleep(0.05)
waited += 0.05
if waited > 5:
pytest.fail("App apparently won't start")
os.kill(os.getpid(), signal.SIGINT)
async def post_shutdown(app: Application) -> None:
events.append("post_shutdown")
app = Application.builder().token(bot.token).post_shutdown(post_shutdown).build()
monkeypatch.setattr(app.bot, "get_updates", get_updates)
monkeypatch.setattr(
app, "shutdown", call_after(app.shutdown, lambda _: events.append("shutdown"))
)
monkeypatch.setattr(
app.updater,
"shutdown",
call_after(app.updater.shutdown, lambda _: events.append("updater.shutdown")),
)
thread = Thread(target=thread_target)
thread.start()
app.run_polling(drop_pending_updates=True, close_loop=False)
thread.join()
assert events == [
"updater.shutdown",
"shutdown",
"post_shutdown",
], "Wrong order of events detected!"
@pytest.mark.skipif(
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
@ -1620,6 +1673,69 @@ class TestApplication:
thread.join()
assert events == ["init", "post_init", "start_webhook"], "Wrong order of events detected!"
@pytest.mark.skipif(
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
)
def test_run_webhook_post_shutdown(self, bot, monkeypatch):
events = []
async def delete_webhook(*args, **kwargs):
return True
async def set_webhook(*args, **kwargs):
return True
async def get_updates(*args, **kwargs):
# This makes sure that other coroutines have a chance of running as well
await asyncio.sleep(0)
return []
def thread_target():
waited = 0
while not app.running:
time.sleep(0.05)
waited += 0.05
if waited > 5:
pytest.fail("App apparently won't start")
os.kill(os.getpid(), signal.SIGINT)
async def post_shutdown(app: Application) -> None:
events.append("post_shutdown")
app = Application.builder().token(bot.token).post_shutdown(post_shutdown).build()
monkeypatch.setattr(app.bot, "set_webhook", set_webhook)
monkeypatch.setattr(app.bot, "delete_webhook", delete_webhook)
monkeypatch.setattr(
app, "shutdown", call_after(app.shutdown, lambda _: events.append("shutdown"))
)
monkeypatch.setattr(
app.updater,
"shutdown",
call_after(app.updater.shutdown, lambda _: events.append("updater.shutdown")),
)
thread = Thread(target=thread_target)
thread.start()
ip = "127.0.0.1"
port = randrange(1024, 49152)
app.run_webhook(
ip_address=ip,
port=port,
url_path="TOKEN",
drop_pending_updates=True,
close_loop=False,
)
thread.join()
assert events == [
"updater.shutdown",
"shutdown",
"post_shutdown",
], "Wrong order of events detected!"
@pytest.mark.skipif(
platform.system() == "Windows",
reason="Can't send signals without stopping whole process on windows",
@ -1718,7 +1834,12 @@ class TestApplication:
assert not app.running
assert not app.updater.running
assert set(shutdowns) == {"application", "updater"}
if method == "initialize":
# If App.initialize fails, then App.shutdown pretty much does nothing, especially
# doesn't call Updater.shutdown.
assert set(shutdowns) == {"application"}
else:
assert set(shutdowns) == {"application", "updater"}
@pytest.mark.parametrize("method", ["start_polling", "start_webhook"])
@pytest.mark.filterwarnings("ignore::telegram.warnings.PTBUserWarning")

View file

@ -107,6 +107,7 @@ class TestApplicationBuilder:
assert app.persistence is None
assert app.post_init is None
assert app.post_shutdown is None
@pytest.mark.parametrize(
"method, description", _BOT_CHECKS, ids=[entry[0] for entry in _BOT_CHECKS]
@ -322,6 +323,9 @@ class TestApplicationBuilder:
async def post_init(app: Application) -> None:
pass
async def post_shutdown(app: Application) -> None:
pass
app = (
builder.token(bot.token)
.job_queue(job_queue)
@ -330,6 +334,7 @@ class TestApplicationBuilder:
.context_types(context_types)
.concurrent_updates(concurrent_updates)
.post_init(post_init)
.post_shutdown(post_shutdown)
).build()
assert app.job_queue is job_queue
assert app.job_queue.application is app
@ -341,6 +346,7 @@ class TestApplicationBuilder:
assert app.context_types is context_types
assert app.concurrent_updates == concurrent_updates
assert app.post_init is post_init
assert app.post_shutdown is post_shutdown
updater = Updater(bot=bot, update_queue=update_queue)
app = ApplicationBuilder().updater(updater).build()