From 63104ac0b3e6bfb60f3336762719dad2cc59c1ce Mon Sep 17 00:00:00 2001 From: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 8 Jun 2022 07:44:22 +0200 Subject: [PATCH] Add `Application.post_init` (#3078) --- telegram/ext/_application.py | 19 +++++ telegram/ext/_applicationbuilder.py | 102 +++++++++++++++++++------- tests/test_application.py | 109 ++++++++++++++++++++++++++++ tests/test_applicationbuilder.py | 7 ++ 4 files changed, 210 insertions(+), 27 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 63320ec69..58a66098c 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -179,6 +179,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) :meth:`add_error_handler` context_types (:class:`telegram.ext.ContextTypes`): Specifies the types used by this dispatcher for the ``context`` argument of handler and job callbacks. + post_init (:term:`coroutine function`): Optional. A callback that will be executed by + :meth:`Application.run_polling` and :meth:`Application.run_webhook` after initializing + the application via :meth:`initialize`. """ @@ -209,6 +212,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) "handlers", "job_queue", "persistence", + "post_init", "update_queue", "updater", "user_data", @@ -224,6 +228,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) concurrent_updates: Union[bool, int], persistence: Optional[BasePersistence], context_types: ContextTypes[CCT, UD, CD, BD], + post_init: Optional[ + Callable[["Application[BT, CCT, UD, CD, BD, JQ]"], Coroutine[Any, Any, None]] + ], ): if not was_called_by( inspect.currentframe(), Path(__file__).parent.resolve() / "_applicationbuilder.py" @@ -240,6 +247,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) self.updater = updater self.handlers: Dict[int, List[BaseHandler]] = {} self.error_handlers: Dict[Callable, Union[bool, DefaultValue]] = {} + self.post_init = post_init if isinstance(concurrent_updates, int) and concurrent_updates < 0: raise ValueError("`concurrent_updates` must be a non-negative integer!") @@ -310,6 +318,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) * The :attr:`updater`, by calling :meth:`telegram.ext.Updater.initialize`. * The :attr:`persistence`, by loading persistent conversations and data. + Does *not* call :attr:`post_init` - that is only done by :meth:`run_polling` and + :meth:`run_webhook`. + .. seealso:: :meth:`shutdown` """ @@ -558,6 +569,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) On unix, the app will also shut down on receiving the signals specified by :paramref:`stop_signals`. + If :attr:`post_init` is set, it will be called between :meth:`initialize` and + :meth:`telegram.ext.Updater.start_polling`. + .. seealso:: :meth:`initialize`, :meth:`start`, :meth:`stop`, :meth:`shutdown` :meth:`telegram.ext.Updater.start_polling`, :meth:`run_webhook` @@ -663,6 +677,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) application. Else, the webhook will be started on ``https://listen:port/url_path``. Also calls :meth:`telegram.Bot.set_webhook` as required. + If :attr:`post_init` is set, it will be called between :meth:`initialize` and + :meth:`telegram.ext.Updater.start_webhook`. + .. seealso:: :meth:`initialize`, :meth:`start`, :meth:`stop`, :meth:`shutdown` :meth:`telegram.ext.Updater.start_webhook`, :meth:`run_polling` @@ -762,6 +779,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) try: loop.run_until_complete(self.initialize()) + if self.post_init: + loop.run_until_complete(self.post_init(self)) loop.run_until_complete(updater_coroutine) # one of updater.start_webhook/polling loop.run_until_complete(self.start()) loop.run_forever() diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index aaccb4766..2d3f20490 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -19,7 +19,18 @@ """This module contains the Builder classes for the telegram.ext module.""" from asyncio import Queue from pathlib import Path -from typing import TYPE_CHECKING, Dict, Generic, Optional, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + Generic, + Optional, + Type, + TypeVar, + Union, +) from telegram._bot import Bot from telegram._utils.defaultvalue import DEFAULT_FALSE, DEFAULT_NONE, DefaultValue @@ -102,36 +113,37 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): """ __slots__ = ( - "_token", - "_base_url", - "_base_file_url", - "_connection_pool_size", - "_proxy_url", - "_connect_timeout", - "_read_timeout", - "_write_timeout", - "_pool_timeout", - "_request", - "_get_updates_connection_pool_size", - "_get_updates_proxy_url", - "_get_updates_connect_timeout", - "_get_updates_read_timeout", - "_get_updates_write_timeout", - "_get_updates_pool_timeout", - "_get_updates_request", - "_private_key", - "_private_key_password", - "_defaults", - "_arbitrary_callback_data", - "_bot", - "_update_queue", - "_job_queue", - "_persistence", - "_context_types", "_application_class", "_application_kwargs", + "_arbitrary_callback_data", + "_base_file_url", + "_base_url", + "_bot", "_concurrent_updates", + "_connect_timeout", + "_connection_pool_size", + "_context_types", + "_defaults", + "_get_updates_connect_timeout", + "_get_updates_connection_pool_size", + "_get_updates_pool_timeout", + "_get_updates_proxy_url", + "_get_updates_read_timeout", + "_get_updates_request", + "_get_updates_write_timeout", + "_job_queue", + "_persistence", + "_pool_timeout", + "_post_init", + "_private_key", + "_private_key_password", + "_proxy_url", + "_read_timeout", + "_request", + "_token", + "_update_queue", "_updater", + "_write_timeout", ) def __init__(self: "InitApplicationBuilder"): @@ -165,6 +177,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): self._application_kwargs: Dict[str, object] = {} self._concurrent_updates: DVInput[Union[int, bool]] = DEFAULT_FALSE self._updater: ODVInput[Updater] = DEFAULT_NONE + self._post_init: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None def _build_request(self, get_updates: bool) -> BaseRequest: prefix = "_get_updates_" if get_updates else "_" @@ -257,6 +270,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): job_queue=job_queue, persistence=persistence, context_types=DefaultValue.get_value(self._context_types), + post_init=self._post_init, **self._application_kwargs, # For custom Application subclasses ) @@ -885,6 +899,40 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): self._updater = updater return self + def post_init( + self: BuilderType, post_init: Callable[[Application], Coroutine[Any, Any, None]] + ) -> BuilderType: + """ + Sets a callback to be executed by :meth:`Application.run_polling` and + :meth:`Application.run_webhook` *after* executing :meth:`Application.initialize` but + *before* executing :meth:`Updater.start_polling` or :meth:`Updater.start_webhook`, + respectively. + + Tip: + This can be used for custom startup logic that requires to await coroutines, e.g. + setting up the bots commands via :meth:`~telegram.Bot.set_my_commands`. + + Example: + .. code:: + + async def post_init(application: Application) -> None: + await application.bot.set_my_commands([('start', 'Starts the bot')]) + + application = Application.builder().token("TOKEN").post_init(post_init).build() + + Args: + post_init (:term:`coroutine function`): The custom callback. Must be a + :term:`coroutine function` and must accept exactly one positional argument, which + is the :class:`~telegram.ext.Application`:: + + async def post_init(application: Application) -> None: + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ + self._post_init = post_init + return self + InitApplicationBuilder = ( # This is defined all the way down here so that its type is inferred ApplicationBuilder[ # by Pylance correctly. diff --git a/tests/test_application.py b/tests/test_application.py index 8070bafb4..13c6f0738 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -129,6 +129,7 @@ class TestApplication: context_types=ContextTypes(), updater=updater, concurrent_updates=False, + post_init=None, ) assert len(recwarn) == 1 assert ( @@ -147,6 +148,10 @@ class TestApplication: persistence = PicklePersistence("file_path") context_types = ContextTypes() updater = Updater(bot=bot, update_queue=update_queue) + + async def post_init(application: Application) -> None: + pass + app = Application( bot=bot, update_queue=update_queue, @@ -155,6 +160,7 @@ class TestApplication: context_types=context_types, updater=updater, concurrent_updates=concurrent_updates, + post_init=post_init, ) assert app.bot is bot assert app.update_queue is update_queue @@ -165,6 +171,7 @@ class TestApplication: assert app.update_queue is updater.update_queue assert app.bot is updater.bot assert app.concurrent_updates == expected + assert app.post_init is post_init # These should be done by the builder assert app.persistence.bot is None @@ -184,6 +191,7 @@ class TestApplication: context_types=context_types, updater=updater, concurrent_updates=-1, + post_init=None, ) def test_custom_context_init(self, bot): @@ -1383,6 +1391,48 @@ class TestApplication: for key, value in assertions.items(): assert value, f"assertion '{key}' failed!" + @pytest.mark.skipif( + platform.system() == "Windows", + reason="Can't send signals without stopping whole process on windows", + ) + def test_run_polling_post_init(self, bot, monkeypatch): + events = [] + + async def get_updates(*args, **kwargs): + # This makes sure that other coroutines have a chance of running as well + await asyncio.sleep(0) + return [] + + def thread_target(): + waited = 0 + while not app.running: + time.sleep(0.05) + waited += 0.05 + if waited > 5: + pytest.fail("App apparently won't start") + + os.kill(os.getpid(), signal.SIGINT) + + async def post_init(app: Application) -> None: + events.append("post_init") + + app = Application.builder().token(bot.token).post_init(post_init).build() + monkeypatch.setattr(app.bot, "get_updates", get_updates) + monkeypatch.setattr( + app, "initialize", call_after(app.initialize, lambda _: events.append("init")) + ) + monkeypatch.setattr( + app.updater, + "start_polling", + call_after(app.updater.start_polling, lambda _: events.append("start_polling")), + ) + + thread = Thread(target=thread_target) + thread.start() + app.run_polling(drop_pending_updates=True, close_loop=False) + thread.join() + assert events == ["init", "post_init", "start_polling"], "Wrong order of events detected!" + @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", @@ -1511,6 +1561,65 @@ class TestApplication: for key, value in assertions.items(): assert value, f"assertion '{key}' failed!" + @pytest.mark.skipif( + platform.system() == "Windows", + reason="Can't send signals without stopping whole process on windows", + ) + def test_run_webhook_post_init(self, bot, monkeypatch): + events = [] + + async def delete_webhook(*args, **kwargs): + return True + + async def set_webhook(*args, **kwargs): + return True + + async def get_updates(*args, **kwargs): + # This makes sure that other coroutines have a chance of running as well + await asyncio.sleep(0) + return [] + + def thread_target(): + waited = 0 + while not app.running: + time.sleep(0.05) + waited += 0.05 + if waited > 5: + pytest.fail("App apparently won't start") + + os.kill(os.getpid(), signal.SIGINT) + + async def post_init(app: Application) -> None: + events.append("post_init") + + app = Application.builder().token(bot.token).post_init(post_init).build() + monkeypatch.setattr(app.bot, "set_webhook", set_webhook) + monkeypatch.setattr(app.bot, "delete_webhook", delete_webhook) + monkeypatch.setattr( + app, "initialize", call_after(app.initialize, lambda _: events.append("init")) + ) + monkeypatch.setattr( + app.updater, + "start_webhook", + call_after(app.updater.start_webhook, lambda _: events.append("start_webhook")), + ) + + thread = Thread(target=thread_target) + thread.start() + + ip = "127.0.0.1" + port = randrange(1024, 49152) + + app.run_webhook( + ip_address=ip, + port=port, + url_path="TOKEN", + drop_pending_updates=True, + close_loop=False, + ) + thread.join() + assert events == ["init", "post_init", "start_webhook"], "Wrong order of events detected!" + @pytest.mark.skipif( platform.system() == "Windows", reason="Can't send signals without stopping whole process on windows", diff --git a/tests/test_applicationbuilder.py b/tests/test_applicationbuilder.py index ad4a4548a..a2e0e5361 100644 --- a/tests/test_applicationbuilder.py +++ b/tests/test_applicationbuilder.py @@ -106,6 +106,7 @@ class TestApplicationBuilder: assert app.job_queue.application is app assert app.persistence is None + assert app.post_init is None @pytest.mark.parametrize( "method, description", _BOT_CHECKS, ids=[entry[0] for entry in _BOT_CHECKS] @@ -317,6 +318,10 @@ class TestApplicationBuilder: update_queue = asyncio.Queue() context_types = ContextTypes() concurrent_updates = 123 + + async def post_init(app: Application) -> None: + pass + app = ( builder.token(bot.token) .job_queue(job_queue) @@ -324,6 +329,7 @@ class TestApplicationBuilder: .update_queue(update_queue) .context_types(context_types) .concurrent_updates(concurrent_updates) + .post_init(post_init) ).build() assert app.job_queue is job_queue assert app.job_queue.application is app @@ -334,6 +340,7 @@ class TestApplicationBuilder: assert app.updater.bot is app.bot assert app.context_types is context_types assert app.concurrent_updates == concurrent_updates + assert app.post_init is post_init updater = Updater(bot=bot, update_queue=update_queue) app = ApplicationBuilder().updater(updater).build()