mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-22 14:35:00 +01:00
Add Application.post_init
(#3078)
This commit is contained in:
parent
22419c0464
commit
63104ac0b3
4 changed files with 210 additions and 27 deletions
|
@ -179,6 +179,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
|
|||
:meth:`add_error_handler`
|
||||
context_types (:class:`telegram.ext.ContextTypes`): Specifies the types used by this
|
||||
dispatcher for the ``context`` argument of handler and job callbacks.
|
||||
post_init (:term:`coroutine function`): Optional. A callback that will be executed by
|
||||
:meth:`Application.run_polling` and :meth:`Application.run_webhook` after initializing
|
||||
the application via :meth:`initialize`.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -209,6 +212,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
|
|||
"handlers",
|
||||
"job_queue",
|
||||
"persistence",
|
||||
"post_init",
|
||||
"update_queue",
|
||||
"updater",
|
||||
"user_data",
|
||||
|
@ -224,6 +228,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
|
|||
concurrent_updates: Union[bool, int],
|
||||
persistence: Optional[BasePersistence],
|
||||
context_types: ContextTypes[CCT, UD, CD, BD],
|
||||
post_init: Optional[
|
||||
Callable[["Application[BT, CCT, UD, CD, BD, JQ]"], Coroutine[Any, Any, None]]
|
||||
],
|
||||
):
|
||||
if not was_called_by(
|
||||
inspect.currentframe(), Path(__file__).parent.resolve() / "_applicationbuilder.py"
|
||||
|
@ -240,6 +247,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
|
|||
self.updater = updater
|
||||
self.handlers: Dict[int, List[BaseHandler]] = {}
|
||||
self.error_handlers: Dict[Callable, Union[bool, DefaultValue]] = {}
|
||||
self.post_init = post_init
|
||||
|
||||
if isinstance(concurrent_updates, int) and concurrent_updates < 0:
|
||||
raise ValueError("`concurrent_updates` must be a non-negative integer!")
|
||||
|
@ -310,6 +318,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
|
|||
* The :attr:`updater`, by calling :meth:`telegram.ext.Updater.initialize`.
|
||||
* The :attr:`persistence`, by loading persistent conversations and data.
|
||||
|
||||
Does *not* call :attr:`post_init` - that is only done by :meth:`run_polling` and
|
||||
:meth:`run_webhook`.
|
||||
|
||||
.. seealso::
|
||||
:meth:`shutdown`
|
||||
"""
|
||||
|
@ -558,6 +569,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
|
|||
On unix, the app will also shut down on receiving the signals specified by
|
||||
:paramref:`stop_signals`.
|
||||
|
||||
If :attr:`post_init` is set, it will be called between :meth:`initialize` and
|
||||
:meth:`telegram.ext.Updater.start_polling`.
|
||||
|
||||
.. seealso::
|
||||
:meth:`initialize`, :meth:`start`, :meth:`stop`, :meth:`shutdown`
|
||||
:meth:`telegram.ext.Updater.start_polling`, :meth:`run_webhook`
|
||||
|
@ -663,6 +677,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
|
|||
application. Else, the webhook will be started on
|
||||
``https://listen:port/url_path``. Also calls :meth:`telegram.Bot.set_webhook` as required.
|
||||
|
||||
If :attr:`post_init` is set, it will be called between :meth:`initialize` and
|
||||
:meth:`telegram.ext.Updater.start_webhook`.
|
||||
|
||||
.. seealso::
|
||||
:meth:`initialize`, :meth:`start`, :meth:`stop`, :meth:`shutdown`
|
||||
:meth:`telegram.ext.Updater.start_webhook`, :meth:`run_polling`
|
||||
|
@ -762,6 +779,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
|
|||
|
||||
try:
|
||||
loop.run_until_complete(self.initialize())
|
||||
if self.post_init:
|
||||
loop.run_until_complete(self.post_init(self))
|
||||
loop.run_until_complete(updater_coroutine) # one of updater.start_webhook/polling
|
||||
loop.run_until_complete(self.start())
|
||||
loop.run_forever()
|
||||
|
|
|
@ -19,7 +19,18 @@
|
|||
"""This module contains the Builder classes for the telegram.ext module."""
|
||||
from asyncio import Queue
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, Generic, Optional, Type, TypeVar, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from telegram._bot import Bot
|
||||
from telegram._utils.defaultvalue import DEFAULT_FALSE, DEFAULT_NONE, DefaultValue
|
||||
|
@ -102,36 +113,37 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
|
|||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_token",
|
||||
"_base_url",
|
||||
"_base_file_url",
|
||||
"_connection_pool_size",
|
||||
"_proxy_url",
|
||||
"_connect_timeout",
|
||||
"_read_timeout",
|
||||
"_write_timeout",
|
||||
"_pool_timeout",
|
||||
"_request",
|
||||
"_get_updates_connection_pool_size",
|
||||
"_get_updates_proxy_url",
|
||||
"_get_updates_connect_timeout",
|
||||
"_get_updates_read_timeout",
|
||||
"_get_updates_write_timeout",
|
||||
"_get_updates_pool_timeout",
|
||||
"_get_updates_request",
|
||||
"_private_key",
|
||||
"_private_key_password",
|
||||
"_defaults",
|
||||
"_arbitrary_callback_data",
|
||||
"_bot",
|
||||
"_update_queue",
|
||||
"_job_queue",
|
||||
"_persistence",
|
||||
"_context_types",
|
||||
"_application_class",
|
||||
"_application_kwargs",
|
||||
"_arbitrary_callback_data",
|
||||
"_base_file_url",
|
||||
"_base_url",
|
||||
"_bot",
|
||||
"_concurrent_updates",
|
||||
"_connect_timeout",
|
||||
"_connection_pool_size",
|
||||
"_context_types",
|
||||
"_defaults",
|
||||
"_get_updates_connect_timeout",
|
||||
"_get_updates_connection_pool_size",
|
||||
"_get_updates_pool_timeout",
|
||||
"_get_updates_proxy_url",
|
||||
"_get_updates_read_timeout",
|
||||
"_get_updates_request",
|
||||
"_get_updates_write_timeout",
|
||||
"_job_queue",
|
||||
"_persistence",
|
||||
"_pool_timeout",
|
||||
"_post_init",
|
||||
"_private_key",
|
||||
"_private_key_password",
|
||||
"_proxy_url",
|
||||
"_read_timeout",
|
||||
"_request",
|
||||
"_token",
|
||||
"_update_queue",
|
||||
"_updater",
|
||||
"_write_timeout",
|
||||
)
|
||||
|
||||
def __init__(self: "InitApplicationBuilder"):
|
||||
|
@ -165,6 +177,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
|
|||
self._application_kwargs: Dict[str, object] = {}
|
||||
self._concurrent_updates: DVInput[Union[int, bool]] = DEFAULT_FALSE
|
||||
self._updater: ODVInput[Updater] = DEFAULT_NONE
|
||||
self._post_init: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None
|
||||
|
||||
def _build_request(self, get_updates: bool) -> BaseRequest:
|
||||
prefix = "_get_updates_" if get_updates else "_"
|
||||
|
@ -257,6 +270,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
|
|||
job_queue=job_queue,
|
||||
persistence=persistence,
|
||||
context_types=DefaultValue.get_value(self._context_types),
|
||||
post_init=self._post_init,
|
||||
**self._application_kwargs, # For custom Application subclasses
|
||||
)
|
||||
|
||||
|
@ -885,6 +899,40 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
|
|||
self._updater = updater
|
||||
return self
|
||||
|
||||
def post_init(
|
||||
self: BuilderType, post_init: Callable[[Application], Coroutine[Any, Any, None]]
|
||||
) -> BuilderType:
|
||||
"""
|
||||
Sets a callback to be executed by :meth:`Application.run_polling` and
|
||||
:meth:`Application.run_webhook` *after* executing :meth:`Application.initialize` but
|
||||
*before* executing :meth:`Updater.start_polling` or :meth:`Updater.start_webhook`,
|
||||
respectively.
|
||||
|
||||
Tip:
|
||||
This can be used for custom startup logic that requires to await coroutines, e.g.
|
||||
setting up the bots commands via :meth:`~telegram.Bot.set_my_commands`.
|
||||
|
||||
Example:
|
||||
.. code::
|
||||
|
||||
async def post_init(application: Application) -> None:
|
||||
await application.bot.set_my_commands([('start', 'Starts the bot')])
|
||||
|
||||
application = Application.builder().token("TOKEN").post_init(post_init).build()
|
||||
|
||||
Args:
|
||||
post_init (:term:`coroutine function`): The custom callback. Must be a
|
||||
:term:`coroutine function` and must accept exactly one positional argument, which
|
||||
is the :class:`~telegram.ext.Application`::
|
||||
|
||||
async def post_init(application: Application) -> None:
|
||||
|
||||
Returns:
|
||||
:class:`ApplicationBuilder`: The same builder with the updated argument.
|
||||
"""
|
||||
self._post_init = post_init
|
||||
return self
|
||||
|
||||
|
||||
InitApplicationBuilder = ( # This is defined all the way down here so that its type is inferred
|
||||
ApplicationBuilder[ # by Pylance correctly.
|
||||
|
|
|
@ -129,6 +129,7 @@ class TestApplication:
|
|||
context_types=ContextTypes(),
|
||||
updater=updater,
|
||||
concurrent_updates=False,
|
||||
post_init=None,
|
||||
)
|
||||
assert len(recwarn) == 1
|
||||
assert (
|
||||
|
@ -147,6 +148,10 @@ class TestApplication:
|
|||
persistence = PicklePersistence("file_path")
|
||||
context_types = ContextTypes()
|
||||
updater = Updater(bot=bot, update_queue=update_queue)
|
||||
|
||||
async def post_init(application: Application) -> None:
|
||||
pass
|
||||
|
||||
app = Application(
|
||||
bot=bot,
|
||||
update_queue=update_queue,
|
||||
|
@ -155,6 +160,7 @@ class TestApplication:
|
|||
context_types=context_types,
|
||||
updater=updater,
|
||||
concurrent_updates=concurrent_updates,
|
||||
post_init=post_init,
|
||||
)
|
||||
assert app.bot is bot
|
||||
assert app.update_queue is update_queue
|
||||
|
@ -165,6 +171,7 @@ class TestApplication:
|
|||
assert app.update_queue is updater.update_queue
|
||||
assert app.bot is updater.bot
|
||||
assert app.concurrent_updates == expected
|
||||
assert app.post_init is post_init
|
||||
|
||||
# These should be done by the builder
|
||||
assert app.persistence.bot is None
|
||||
|
@ -184,6 +191,7 @@ class TestApplication:
|
|||
context_types=context_types,
|
||||
updater=updater,
|
||||
concurrent_updates=-1,
|
||||
post_init=None,
|
||||
)
|
||||
|
||||
def test_custom_context_init(self, bot):
|
||||
|
@ -1383,6 +1391,48 @@ class TestApplication:
|
|||
for key, value in assertions.items():
|
||||
assert value, f"assertion '{key}' failed!"
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows",
|
||||
reason="Can't send signals without stopping whole process on windows",
|
||||
)
|
||||
def test_run_polling_post_init(self, bot, monkeypatch):
|
||||
events = []
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
# This makes sure that other coroutines have a chance of running as well
|
||||
await asyncio.sleep(0)
|
||||
return []
|
||||
|
||||
def thread_target():
|
||||
waited = 0
|
||||
while not app.running:
|
||||
time.sleep(0.05)
|
||||
waited += 0.05
|
||||
if waited > 5:
|
||||
pytest.fail("App apparently won't start")
|
||||
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
async def post_init(app: Application) -> None:
|
||||
events.append("post_init")
|
||||
|
||||
app = Application.builder().token(bot.token).post_init(post_init).build()
|
||||
monkeypatch.setattr(app.bot, "get_updates", get_updates)
|
||||
monkeypatch.setattr(
|
||||
app, "initialize", call_after(app.initialize, lambda _: events.append("init"))
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app.updater,
|
||||
"start_polling",
|
||||
call_after(app.updater.start_polling, lambda _: events.append("start_polling")),
|
||||
)
|
||||
|
||||
thread = Thread(target=thread_target)
|
||||
thread.start()
|
||||
app.run_polling(drop_pending_updates=True, close_loop=False)
|
||||
thread.join()
|
||||
assert events == ["init", "post_init", "start_polling"], "Wrong order of events detected!"
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows",
|
||||
reason="Can't send signals without stopping whole process on windows",
|
||||
|
@ -1511,6 +1561,65 @@ class TestApplication:
|
|||
for key, value in assertions.items():
|
||||
assert value, f"assertion '{key}' failed!"
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows",
|
||||
reason="Can't send signals without stopping whole process on windows",
|
||||
)
|
||||
def test_run_webhook_post_init(self, bot, monkeypatch):
|
||||
events = []
|
||||
|
||||
async def delete_webhook(*args, **kwargs):
|
||||
return True
|
||||
|
||||
async def set_webhook(*args, **kwargs):
|
||||
return True
|
||||
|
||||
async def get_updates(*args, **kwargs):
|
||||
# This makes sure that other coroutines have a chance of running as well
|
||||
await asyncio.sleep(0)
|
||||
return []
|
||||
|
||||
def thread_target():
|
||||
waited = 0
|
||||
while not app.running:
|
||||
time.sleep(0.05)
|
||||
waited += 0.05
|
||||
if waited > 5:
|
||||
pytest.fail("App apparently won't start")
|
||||
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
async def post_init(app: Application) -> None:
|
||||
events.append("post_init")
|
||||
|
||||
app = Application.builder().token(bot.token).post_init(post_init).build()
|
||||
monkeypatch.setattr(app.bot, "set_webhook", set_webhook)
|
||||
monkeypatch.setattr(app.bot, "delete_webhook", delete_webhook)
|
||||
monkeypatch.setattr(
|
||||
app, "initialize", call_after(app.initialize, lambda _: events.append("init"))
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
app.updater,
|
||||
"start_webhook",
|
||||
call_after(app.updater.start_webhook, lambda _: events.append("start_webhook")),
|
||||
)
|
||||
|
||||
thread = Thread(target=thread_target)
|
||||
thread.start()
|
||||
|
||||
ip = "127.0.0.1"
|
||||
port = randrange(1024, 49152)
|
||||
|
||||
app.run_webhook(
|
||||
ip_address=ip,
|
||||
port=port,
|
||||
url_path="TOKEN",
|
||||
drop_pending_updates=True,
|
||||
close_loop=False,
|
||||
)
|
||||
thread.join()
|
||||
assert events == ["init", "post_init", "start_webhook"], "Wrong order of events detected!"
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows",
|
||||
reason="Can't send signals without stopping whole process on windows",
|
||||
|
|
|
@ -106,6 +106,7 @@ class TestApplicationBuilder:
|
|||
assert app.job_queue.application is app
|
||||
|
||||
assert app.persistence is None
|
||||
assert app.post_init is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method, description", _BOT_CHECKS, ids=[entry[0] for entry in _BOT_CHECKS]
|
||||
|
@ -317,6 +318,10 @@ class TestApplicationBuilder:
|
|||
update_queue = asyncio.Queue()
|
||||
context_types = ContextTypes()
|
||||
concurrent_updates = 123
|
||||
|
||||
async def post_init(app: Application) -> None:
|
||||
pass
|
||||
|
||||
app = (
|
||||
builder.token(bot.token)
|
||||
.job_queue(job_queue)
|
||||
|
@ -324,6 +329,7 @@ class TestApplicationBuilder:
|
|||
.update_queue(update_queue)
|
||||
.context_types(context_types)
|
||||
.concurrent_updates(concurrent_updates)
|
||||
.post_init(post_init)
|
||||
).build()
|
||||
assert app.job_queue is job_queue
|
||||
assert app.job_queue.application is app
|
||||
|
@ -334,6 +340,7 @@ class TestApplicationBuilder:
|
|||
assert app.updater.bot is app.bot
|
||||
assert app.context_types is context_types
|
||||
assert app.concurrent_updates == concurrent_updates
|
||||
assert app.post_init is post_init
|
||||
|
||||
updater = Updater(bot=bot, update_queue=update_queue)
|
||||
app = ApplicationBuilder().updater(updater).build()
|
||||
|
|
Loading…
Reference in a new issue