From 19302bce2588165c687f51906b75c8c6dc059b37 Mon Sep 17 00:00:00 2001 From: Aditya <69784758+clot27@users.noreply.github.com> Date: Wed, 4 Jan 2023 21:18:48 +0530 Subject: [PATCH] Add `Application(Builder).post_stop` (#3466) --- telegram/ext/_application.py | 50 ++++++++--- telegram/ext/_applicationbuilder.py | 47 +++++++++++ tests/test_application.py | 126 ++++++++++++++++++++++++++++ tests/test_applicationbuilder.py | 6 ++ 4 files changed, 218 insertions(+), 11 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 529a1ff38..07056003a 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -204,6 +204,11 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) post_shutdown (:term:`coroutine function`): Optional. A callback that will be executed by :meth:`Application.run_polling` and :meth:`Application.run_webhook` after shutting down the application via :meth:`shutdown`. + post_stop (:term:`coroutine function`): Optional. A callback that will be executed by + :meth:`Application.run_polling` and :meth:`Application.run_webhook` after stopping + the application via :meth:`stop`. + + .. versionadded:: 20.1 """ @@ -236,6 +241,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) "persistence", "post_init", "post_shutdown", + "post_stop", "update_queue", "updater", "user_data", @@ -257,6 +263,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) post_shutdown: Optional[ Callable[["Application[BT, CCT, UD, CD, BD, JQ]"], Coroutine[Any, Any, None]] ], + post_stop: Optional[ + Callable[["Application[BT, CCT, UD, CD, BD, JQ]"], Coroutine[Any, Any, None]] + ], ): if not was_called_by( inspect.currentframe(), Path(__file__).parent.resolve() / "_applicationbuilder.py" @@ -274,6 +283,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) self.error_handlers: Dict[Callable, Union[bool, DefaultValue]] = {} self.post_init = post_init self.post_shutdown = post_shutdown + self.post_stop = post_stop if isinstance(concurrent_updates, int) and concurrent_updates < 0: raise ValueError("`concurrent_updates` must be a non-negative integer!") @@ -564,9 +574,11 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) :meth:`start` Note: - This does *not* stop :attr:`updater`. You need to either manually call - :meth:`telegram.ext.Updater.stop` or use one of :meth:`run_polling` or - :meth:`run_webhook`. + * This does *not* stop :attr:`updater`. You need to either manually call + :meth:`telegram.ext.Updater.stop` or use one of :meth:`run_polling` or + :meth:`run_webhook`. + * Does *not* call :attr:`post_stop` - that is only done by + :meth:`run_polling` and :meth:`run_webhook`. Raises: :exc:`RuntimeError`: If the application is not running. @@ -624,11 +636,18 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) On unix, the app will also shut down on receiving the signals specified by :paramref:`stop_signals`. - If :attr:`post_init` is set, it will be called between :meth:`initialize` and - :meth:`telegram.ext.Updater.start_polling`. + The order of execution by `run_polling` is roughly as follows: - If :attr:`post_shutdown` is set, it will be called after both :meth:`shutdown` - and :meth:`telegram.ext.Updater.shutdown`. + - :meth:`initialize` + - :meth:`post_init` + - :meth:`telegram.ext.Updater.start_polling` + - :meth:`start` + - Run the application until the users stops it + - :meth:`telegram.ext.Updater.stop` + - :meth:`stop` + - :meth:`post_stop` + - :meth:`shutdown` + - :meth:`post_shutdown` .. include:: inclusions/application_run_tip.rst @@ -740,11 +759,18 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) ``https://listen:port/url_path``. Also calls :meth:`telegram.Bot.set_webhook` as required. - If :attr:`post_init` is set, it will be called between :meth:`initialize` and - :meth:`telegram.ext.Updater.start_webhook`. + The order of execution by `run_webhook` is roughly as follows: - If :attr:`post_shutdown` is set, it will be called after both :meth:`shutdown` - and :meth:`telegram.ext.Updater.shutdown`. + - :meth:`initialize` + - :meth:`post_init` + - :meth:`telegram.ext.Updater.start_webhook` + - :meth:`start` + - Run the application until the users stops it + - :meth:`telegram.ext.Updater.stop` + - :meth:`stop` + - :meth:`post_stop` + - :meth:`shutdown` + - :meth:`post_shutdown` Important: If you want to use this method, you must install PTB with the optional requirement @@ -887,6 +913,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) loop.run_until_complete(self.updater.stop()) # type: ignore[union-attr] if self.running: loop.run_until_complete(self.stop()) + if self.post_stop: + loop.run_until_complete(self.post_stop(self)) loop.run_until_complete(self.shutdown()) if self.post_shutdown: loop.run_until_complete(self.post_shutdown(self)) diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index 999050d9a..c2afacb54 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -142,6 +142,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): "_pool_timeout", "_post_init", "_post_shutdown", + "_post_stop", "_private_key", "_private_key_password", "_proxy_url", @@ -196,6 +197,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): self._updater: ODVInput[Updater] = DEFAULT_NONE self._post_init: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None self._post_shutdown: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None + self._post_stop: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None self._rate_limiter: ODVInput["BaseRateLimiter"] = DEFAULT_NONE def _build_request(self, get_updates: bool) -> BaseRequest: @@ -301,6 +303,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): context_types=DefaultValue.get_value(self._context_types), post_init=self._post_init, post_shutdown=self._post_shutdown, + post_stop=self._post_stop, **self._application_kwargs, # For custom Application subclasses ) @@ -967,6 +970,8 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): application = Application.builder().token("TOKEN").post_init(post_init).build() + .. seealso:: :meth:`post_stop`, :meth:`post_shutdown` + Args: post_init (:term:`coroutine function`): The custom callback. Must be a :term:`coroutine function` and must accept exactly one positional argument, which @@ -1003,6 +1008,8 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): .post_shutdown(post_shutdown) .build() + .. seealso:: :meth:`post_init`, :meth:`post_stop` + Args: post_shutdown (:term:`coroutine function`): The custom callback. Must be a :term:`coroutine function` and must accept exactly one positional argument, which @@ -1016,6 +1023,46 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): self._post_shutdown = post_shutdown return self + def post_stop( + self: BuilderType, post_stop: Callable[[Application], Coroutine[Any, Any, None]] + ) -> BuilderType: + """ + Sets a callback to be executed by :meth:`Application.run_polling` and + :meth:`Application.run_webhook` *after* executing :meth:`Updater.stop` + and :meth:`Application.stop`. + + .. versionadded:: 20.1 + + Tip: + This can be used for custom stop logic that requires to await coroutines, e.g. + sending message to a chat before shutting down the bot + + Example: + .. code:: + + async def post_stop(application: Application) -> None: + await application.bot.send_message(123456, "Shutting down...") + + application = Application.builder() + .token("TOKEN") + .post_stop(post_stop) + .build() + + .. seealso:: :meth:`post_init`, :meth:`post_shutdown` + + Args: + post_stop (:term:`coroutine function`): The custom callback. Must be a + :term:`coroutine function` and must accept exactly one positional argument, which + is the :class:`~telegram.ext.Application`:: + + async def post_stop(application: Application) -> None: + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ + self._post_stop = post_stop + return self + def rate_limiter( self: "ApplicationBuilder[BT, CCT, UD, CD, BD, JQ]", rate_limiter: "BaseRateLimiter[RLARGS]", diff --git a/tests/test_application.py b/tests/test_application.py index 8ba30c3c7..a458f6ae7 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -131,6 +131,7 @@ class TestApplication: concurrent_updates=False, post_init=None, post_shutdown=None, + post_stop=None, ) assert len(recwarn) == 1 assert ( @@ -156,6 +157,9 @@ class TestApplication: async def post_shutdown(application: Application) -> None: pass + async def post_stop(application: Application) -> None: + pass + app = Application( bot=bot, update_queue=update_queue, @@ -166,6 +170,7 @@ class TestApplication: concurrent_updates=concurrent_updates, post_init=post_init, post_shutdown=post_shutdown, + post_stop=post_stop, ) assert app.bot is bot assert app.update_queue is update_queue @@ -178,6 +183,7 @@ class TestApplication: assert app.concurrent_updates == expected assert app.post_init is post_init assert app.post_shutdown is post_shutdown + assert app.post_stop is post_stop # These should be done by the builder assert app.persistence.bot is None @@ -199,6 +205,7 @@ class TestApplication: concurrent_updates=-1, post_init=None, post_shutdown=None, + post_stop=None, ) def test_job_queue(self, bot, app, recwarn): @@ -1501,6 +1508,57 @@ class TestApplication: "post_shutdown", ], "Wrong order of events detected!" + @pytest.mark.skipif( + platform.system() == "Windows", + reason="Can't send signals without stopping whole process on windows", + ) + def test_run_polling_post_stop(self, bot, monkeypatch): + events = [] + + async def get_updates(*args, **kwargs): + # This makes sure that other coroutines have a chance of running as well + await asyncio.sleep(0) + return [] + + def thread_target(): + waited = 0 + while not app.running: + time.sleep(0.05) + waited += 0.05 + if waited > 5: + pytest.fail("App apparently won't start") + + os.kill(os.getpid(), signal.SIGINT) + + async def post_stop(app: Application) -> None: + events.append("post_stop") + + app = Application.builder().token(bot.token).post_stop(post_stop).build() + app.bot._unfreeze() + monkeypatch.setattr(app.bot, "get_updates", get_updates) + monkeypatch.setattr(app, "stop", call_after(app.stop, lambda _: events.append("stop"))) + monkeypatch.setattr( + app.updater, + "stop", + call_after(app.updater.stop, lambda _: events.append("updater.stop")), + ) + monkeypatch.setattr( + app.updater, + "shutdown", + call_after(app.updater.shutdown, lambda _: events.append("updater.shutdown")), + ) + + thread = Thread(target=thread_target) + thread.start() + app.run_polling(drop_pending_updates=True, close_loop=False) + thread.join() + assert events == [ + "updater.stop", + "stop", + "post_stop", + "updater.shutdown", + ], "Wrong order of events detected!" + @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", @@ -1753,6 +1811,74 @@ class TestApplication: "post_shutdown", ], "Wrong order of events detected!" + @pytest.mark.skipif( + platform.system() == "Windows", + reason="Can't send signals without stopping whole process on windows", + ) + def test_run_webhook_post_stop(self, bot, monkeypatch): + events = [] + + async def delete_webhook(*args, **kwargs): + return True + + async def set_webhook(*args, **kwargs): + return True + + async def get_updates(*args, **kwargs): + # This makes sure that other coroutines have a chance of running as well + await asyncio.sleep(0) + return [] + + def thread_target(): + waited = 0 + while not app.running: + time.sleep(0.05) + waited += 0.05 + if waited > 5: + pytest.fail("App apparently won't start") + + os.kill(os.getpid(), signal.SIGINT) + + async def post_stop(app: Application) -> None: + events.append("post_stop") + + app = Application.builder().token(bot.token).post_stop(post_stop).build() + app.bot._unfreeze() + monkeypatch.setattr(app.bot, "set_webhook", set_webhook) + monkeypatch.setattr(app.bot, "delete_webhook", delete_webhook) + monkeypatch.setattr(app, "stop", call_after(app.stop, lambda _: events.append("stop"))) + monkeypatch.setattr( + app.updater, + "stop", + call_after(app.updater.stop, lambda _: events.append("updater.stop")), + ) + monkeypatch.setattr( + app.updater, + "shutdown", + call_after(app.updater.shutdown, lambda _: events.append("updater.shutdown")), + ) + + thread = Thread(target=thread_target) + thread.start() + + ip = "127.0.0.1" + port = randrange(1024, 49152) + + app.run_webhook( + ip_address=ip, + port=port, + url_path="TOKEN", + drop_pending_updates=True, + close_loop=False, + ) + thread.join() + assert events == [ + "updater.stop", + "stop", + "post_stop", + "updater.shutdown", + ], "Wrong order of events detected!" + @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", diff --git a/tests/test_applicationbuilder.py b/tests/test_applicationbuilder.py index 4d0b40dd6..3f1996bd7 100644 --- a/tests/test_applicationbuilder.py +++ b/tests/test_applicationbuilder.py @@ -135,6 +135,7 @@ class TestApplicationBuilder: assert app.persistence is None assert app.post_init is None assert app.post_shutdown is None + assert app.post_stop is None @pytest.mark.parametrize( "method, description", _BOT_CHECKS, ids=[entry[0] for entry in _BOT_CHECKS] @@ -361,6 +362,9 @@ class TestApplicationBuilder: async def post_shutdown(app: Application) -> None: pass + async def post_stop(app: Application) -> None: + pass + app = ( builder.token(bot.token) .job_queue(job_queue) @@ -370,6 +374,7 @@ class TestApplicationBuilder: .concurrent_updates(concurrent_updates) .post_init(post_init) .post_shutdown(post_shutdown) + .post_stop(post_stop) .arbitrary_callback_data(True) ).build() assert app.job_queue is job_queue @@ -383,6 +388,7 @@ class TestApplicationBuilder: assert app.concurrent_updates == concurrent_updates assert app.post_init is post_init assert app.post_shutdown is post_shutdown + assert app.post_stop is post_stop assert isinstance(app.bot.callback_data_cache, CallbackDataCache) updater = Updater(bot=bot, update_queue=update_queue)