mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-03-17 04:39:55 +01:00
Introduce BaseUpdateProcessor
for Customized Concurrent Handling of Updates (#3654)
Co-authored-by: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com>
This commit is contained in:
parent
4c8d7332db
commit
bf54599618
10 changed files with 496 additions and 72 deletions
6
docs/source/telegram.ext.baseupdateprocessor.rst
Normal file
6
docs/source/telegram.ext.baseupdateprocessor.rst
Normal file
|
@ -0,0 +1,6 @@
|
|||
BaseUpdateProcessor
|
||||
===================
|
||||
|
||||
.. autoclass:: telegram.ext.BaseUpdateProcessor
|
||||
:members:
|
||||
:show-inheritance:
|
|
@ -9,12 +9,14 @@ telegram.ext package
|
|||
telegram.ext.application
|
||||
telegram.ext.applicationbuilder
|
||||
telegram.ext.applicationhandlerstop
|
||||
telegram.ext.baseupdateprocessor
|
||||
telegram.ext.callbackcontext
|
||||
telegram.ext.contexttypes
|
||||
telegram.ext.defaults
|
||||
telegram.ext.extbot
|
||||
telegram.ext.job
|
||||
telegram.ext.jobqueue
|
||||
telegram.ext.simpleupdateprocessor
|
||||
telegram.ext.updater
|
||||
telegram.ext.handlers-tree.rst
|
||||
telegram.ext.persistence-tree.rst
|
||||
|
|
6
docs/source/telegram.ext.simpleupdateprocessor.rst
Normal file
6
docs/source/telegram.ext.simpleupdateprocessor.rst
Normal file
|
@ -0,0 +1,6 @@
|
|||
SimpleUpdateProcessor
|
||||
=====================
|
||||
|
||||
.. autoclass:: telegram.ext.SimpleUpdateProcessor
|
||||
:members:
|
||||
:show-inheritance:
|
|
@ -26,6 +26,7 @@ __all__ = (
|
|||
"BaseHandler",
|
||||
"BasePersistence",
|
||||
"BaseRateLimiter",
|
||||
"BaseUpdateProcessor",
|
||||
"CallbackContext",
|
||||
"CallbackDataCache",
|
||||
"CallbackQueryHandler",
|
||||
|
@ -51,6 +52,7 @@ __all__ = (
|
|||
"PreCheckoutQueryHandler",
|
||||
"PrefixHandler",
|
||||
"ShippingQueryHandler",
|
||||
"SimpleUpdateProcessor",
|
||||
"StringCommandHandler",
|
||||
"StringRegexHandler",
|
||||
"TypeHandler",
|
||||
|
@ -63,6 +65,7 @@ from ._application import Application, ApplicationHandlerStop
|
|||
from ._applicationbuilder import ApplicationBuilder
|
||||
from ._basepersistence import BasePersistence, PersistenceInput
|
||||
from ._baseratelimiter import BaseRateLimiter
|
||||
from ._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor
|
||||
from ._callbackcontext import CallbackContext
|
||||
from ._callbackdatacache import CallbackDataCache, InvalidCallbackData
|
||||
from ._callbackqueryhandler import CallbackQueryHandler
|
||||
|
|
|
@ -57,6 +57,7 @@ from telegram._utils.types import SCT, DVType, ODVInput
|
|||
from telegram._utils.warnings import warn
|
||||
from telegram.error import TelegramError
|
||||
from telegram.ext._basepersistence import BasePersistence
|
||||
from telegram.ext._baseupdateprocessor import BaseUpdateProcessor
|
||||
from telegram.ext._contexttypes import ContextTypes
|
||||
from telegram.ext._extbot import ExtBot
|
||||
from telegram.ext._handler import BaseHandler
|
||||
|
@ -228,12 +229,11 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
"_chat_data",
|
||||
"_chat_ids_to_be_deleted_in_persistence",
|
||||
"_chat_ids_to_be_updated_in_persistence",
|
||||
"_concurrent_updates",
|
||||
"_concurrent_updates_sem",
|
||||
"_conversation_handler_conversations",
|
||||
"_initialized",
|
||||
"_job_queue",
|
||||
"_running",
|
||||
"_update_processor",
|
||||
"_user_data",
|
||||
"_user_ids_to_be_deleted_in_persistence",
|
||||
"_user_ids_to_be_updated_in_persistence",
|
||||
|
@ -259,7 +259,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
update_queue: "asyncio.Queue[object]",
|
||||
updater: Optional[Updater],
|
||||
job_queue: JQ,
|
||||
concurrent_updates: Union[bool, int],
|
||||
update_processor: "BaseUpdateProcessor",
|
||||
persistence: Optional[BasePersistence[UD, CD, BD]],
|
||||
context_types: ContextTypes[CCT, UD, CD, BD],
|
||||
post_init: Optional[
|
||||
|
@ -297,14 +297,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
self.post_stop: Optional[
|
||||
Callable[["Application[BT, CCT, UD, CD, BD, JQ]"], Coroutine[Any, Any, None]]
|
||||
] = post_stop
|
||||
|
||||
if isinstance(concurrent_updates, int) and concurrent_updates < 0:
|
||||
raise ValueError("`concurrent_updates` must be a non-negative integer!")
|
||||
if concurrent_updates is True:
|
||||
concurrent_updates = 256
|
||||
self._concurrent_updates_sem = asyncio.BoundedSemaphore(concurrent_updates or 1)
|
||||
self._concurrent_updates: int = concurrent_updates or 0
|
||||
|
||||
self._update_processor = update_processor
|
||||
self.bot_data: BD = self.context_types.bot_data()
|
||||
self._user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data)
|
||||
self._chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data)
|
||||
|
@ -359,9 +352,13 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
""":obj:`int`: The number of concurrent updates that will be processed in parallel. A
|
||||
value of ``0`` indicates updates are *not* being processed concurrently.
|
||||
|
||||
.. versionchanged:: NEXT.VERSION
|
||||
This is now just a shortcut to :attr:`update_processor.max_concurrent_updates
|
||||
<telegram.ext.BaseUpdateProcessor.max_concurrent_updates>`.
|
||||
|
||||
.. seealso:: :wiki:`Concurrency`
|
||||
"""
|
||||
return self._concurrent_updates
|
||||
return self._update_processor.max_concurrent_updates
|
||||
|
||||
@property
|
||||
def job_queue(self) -> Optional["JobQueue[CCT]"]:
|
||||
|
@ -379,12 +376,25 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
)
|
||||
return self._job_queue
|
||||
|
||||
@property
|
||||
def update_processor(self) -> "BaseUpdateProcessor":
|
||||
""":class:`telegram.ext.BaseUpdateProcessor`: The update processor used by this
|
||||
application.
|
||||
|
||||
.. seealso:: :wiki:`Concurrency`
|
||||
|
||||
.. versionadded:: NEXT.VERSION
|
||||
"""
|
||||
return self._update_processor
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initializes the Application by initializing:
|
||||
|
||||
* The :attr:`bot`, by calling :meth:`telegram.Bot.initialize`.
|
||||
* The :attr:`updater`, by calling :meth:`telegram.ext.Updater.initialize`.
|
||||
* The :attr:`persistence`, by loading persistent conversations and data.
|
||||
* The :attr:`update_processor` by calling
|
||||
:meth:`telegram.ext.BaseUpdateProcessor.initialize`.
|
||||
|
||||
Does *not* call :attr:`post_init` - that is only done by :meth:`run_polling` and
|
||||
:meth:`run_webhook`.
|
||||
|
@ -397,6 +407,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
return
|
||||
|
||||
await self.bot.initialize()
|
||||
await self._update_processor.initialize()
|
||||
|
||||
if self.updater:
|
||||
await self.updater.initialize()
|
||||
|
||||
|
@ -429,6 +441,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
* :attr:`updater` by calling :meth:`telegram.ext.Updater.shutdown`
|
||||
* :attr:`persistence` by calling :meth:`update_persistence` and
|
||||
:meth:`BasePersistence.flush`
|
||||
* :attr:`update_processor` by calling :meth:`telegram.ext.BaseUpdateProcessor.shutdown`
|
||||
|
||||
Does *not* call :attr:`post_shutdown` - that is only done by :meth:`run_polling` and
|
||||
:meth:`run_webhook`.
|
||||
|
@ -447,6 +460,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
return
|
||||
|
||||
await self.bot.shutdown()
|
||||
await self._update_processor.shutdown()
|
||||
|
||||
if self.updater:
|
||||
await self.updater.shutdown()
|
||||
|
||||
|
@ -1060,11 +1075,15 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
|
||||
_LOGGER.debug("Processing update %s", update)
|
||||
|
||||
if self._concurrent_updates:
|
||||
if self._update_processor.max_concurrent_updates > 1:
|
||||
# We don't await the below because it has to be run concurrently
|
||||
self.create_task(self.__process_update_wrapper(update), update=update)
|
||||
self.create_task(
|
||||
self.__process_update_wrapper(update),
|
||||
update=update,
|
||||
)
|
||||
else:
|
||||
await self.__process_update_wrapper(update)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# This may happen if the application is manually run via application.start() and
|
||||
# then a KeyboardInterrupt is sent. We must prevent this loop to die since
|
||||
|
@ -1075,9 +1094,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
)
|
||||
|
||||
async def __process_update_wrapper(self, update: object) -> None:
|
||||
async with self._concurrent_updates_sem:
|
||||
await self.process_update(update)
|
||||
self.update_queue.task_done()
|
||||
await self._update_processor.process_update(update, self.process_update(update))
|
||||
self.update_queue.task_done()
|
||||
|
||||
async def process_update(self, update: object) -> None:
|
||||
"""Processes a single update and marks the update to be updated by the persistence later.
|
||||
|
|
|
@ -36,6 +36,7 @@ from telegram._bot import Bot
|
|||
from telegram._utils.defaultvalue import DEFAULT_FALSE, DEFAULT_NONE, DefaultValue
|
||||
from telegram._utils.types import DVInput, DVType, FilePathInput, ODVInput
|
||||
from telegram.ext._application import Application
|
||||
from telegram.ext._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor
|
||||
from telegram.ext._contexttypes import ContextTypes
|
||||
from telegram.ext._extbot import ExtBot
|
||||
from telegram.ext._jobqueue import JobQueue
|
||||
|
@ -127,7 +128,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
|
|||
"_base_file_url",
|
||||
"_base_url",
|
||||
"_bot",
|
||||
"_concurrent_updates",
|
||||
"_update_processor",
|
||||
"_connect_timeout",
|
||||
"_connection_pool_size",
|
||||
"_context_types",
|
||||
|
@ -198,7 +199,9 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
|
|||
self._context_types: DVType[ContextTypes] = DefaultValue(ContextTypes())
|
||||
self._application_class: DVType[Type[Application]] = DefaultValue(Application)
|
||||
self._application_kwargs: Dict[str, object] = {}
|
||||
self._concurrent_updates: Union[int, DefaultValue[bool]] = DEFAULT_FALSE
|
||||
self._update_processor: "BaseUpdateProcessor" = SimpleUpdateProcessor(
|
||||
max_concurrent_updates=1
|
||||
)
|
||||
self._updater: ODVInput[Updater] = DEFAULT_NONE
|
||||
self._post_init: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None
|
||||
self._post_shutdown: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None
|
||||
|
@ -306,7 +309,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
|
|||
bot=bot,
|
||||
update_queue=update_queue,
|
||||
updater=updater,
|
||||
concurrent_updates=DefaultValue.get_value(self._concurrent_updates),
|
||||
update_processor=self._update_processor,
|
||||
job_queue=job_queue,
|
||||
persistence=persistence,
|
||||
context_types=DefaultValue.get_value(self._context_types),
|
||||
|
@ -902,7 +905,9 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
|
|||
self._update_queue = update_queue
|
||||
return self
|
||||
|
||||
def concurrent_updates(self: BuilderType, concurrent_updates: Union[bool, int]) -> BuilderType:
|
||||
def concurrent_updates(
|
||||
self: BuilderType, concurrent_updates: Union[bool, int, "BaseUpdateProcessor"]
|
||||
) -> BuilderType:
|
||||
"""Specifies if and how many updates may be processed concurrently instead of one by one.
|
||||
If not called, updates will be processed one by one.
|
||||
|
||||
|
@ -917,14 +922,34 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
|
|||
.. seealso:: :attr:`telegram.ext.Application.concurrent_updates`
|
||||
|
||||
Args:
|
||||
concurrent_updates (:obj:`bool` | :obj:`int`): Passing :obj:`True` will allow for
|
||||
``256`` updates to be processed concurrently. Pass an integer to specify a
|
||||
different number of updates that may be processed concurrently.
|
||||
concurrent_updates (:obj:`bool` | :obj:`int` | :class:`BaseUpdateProcessor`): Passing
|
||||
:obj:`True` will allow for ``256`` updates to be processed concurrently using
|
||||
:class:`telegram.ext.SimpleUpdateProcessor`. Pass an integer to specify a different
|
||||
number of updates that may be processed concurrently. Pass an instance of
|
||||
:class:`telegram.ext.BaseUpdateProcessor` to use that instance for handling updates
|
||||
concurrently.
|
||||
|
||||
.. versionchanged:: NEXT.VERSION
|
||||
Now accepts :class:`BaseUpdateProcessor` instances.
|
||||
|
||||
Returns:
|
||||
:class:`ApplicationBuilder`: The same builder with the updated argument.
|
||||
"""
|
||||
self._concurrent_updates = concurrent_updates
|
||||
# Check if concurrent updates is bool and convert to integer
|
||||
if concurrent_updates is True:
|
||||
concurrent_updates = 256
|
||||
elif concurrent_updates is False:
|
||||
concurrent_updates = 1
|
||||
|
||||
# If `concurrent_updates` is an integer, create a `SimpleUpdateProcessor`
|
||||
# instance with that integer value; otherwise, raise an error if the value
|
||||
# is negative
|
||||
if isinstance(concurrent_updates, int):
|
||||
concurrent_updates = SimpleUpdateProcessor(concurrent_updates)
|
||||
|
||||
# Assign default value of concurrent_updates if it is instance of
|
||||
# `BaseUpdateProcessor`
|
||||
self._update_processor: BaseUpdateProcessor = concurrent_updates # type: ignore[no-redef]
|
||||
return self
|
||||
|
||||
def job_queue(
|
||||
|
|
154
telegram/ext/_baseupdateprocessor.py
Normal file
154
telegram/ext/_baseupdateprocessor.py
Normal file
|
@ -0,0 +1,154 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# A library that provides a Python interface to the Telegram Bot API
|
||||
# Copyright (C) 2015-2023
|
||||
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
"""This module contains the BaseProcessor class."""
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import BoundedSemaphore
|
||||
from types import TracebackType
|
||||
from typing import Any, Awaitable, Optional, Type
|
||||
|
||||
|
||||
class BaseUpdateProcessor(ABC):
|
||||
"""An abstract base class for update processors. You can use this class to implement
|
||||
your own update processor.
|
||||
|
||||
.. seealso:: :wiki:`Concurrency`
|
||||
|
||||
.. versionadded:: NEXT.VERSION
|
||||
|
||||
Args:
|
||||
max_concurrent_updates (:obj:`int`): The maximum number of updates to be processed
|
||||
concurrently. If this number is exceeded, new updates will be queued until the number
|
||||
of currently processed updates decreases.
|
||||
|
||||
Raises:
|
||||
:exc:`ValueError`: If :paramref:`max_concurrent_updates` is a non-positive integer.
|
||||
"""
|
||||
|
||||
__slots__ = ("_max_concurrent_updates", "_semaphore")
|
||||
|
||||
def __init__(self, max_concurrent_updates: int):
|
||||
self._max_concurrent_updates = max_concurrent_updates
|
||||
if self.max_concurrent_updates < 1:
|
||||
raise ValueError("`max_concurrent_updates` must be a positive integer!")
|
||||
self._semaphore = BoundedSemaphore(self.max_concurrent_updates)
|
||||
|
||||
@property
|
||||
def max_concurrent_updates(self) -> int:
|
||||
""":obj:`int`: The maximum number of updates that can be processed concurrently."""
|
||||
return self._max_concurrent_updates
|
||||
|
||||
@abstractmethod
|
||||
async def do_process_update(
|
||||
self,
|
||||
update: object,
|
||||
coroutine: "Awaitable[Any]",
|
||||
) -> None:
|
||||
"""Custom implementation of how to process an update. Must be implemented by a subclass.
|
||||
|
||||
Warning:
|
||||
This method will be called by :meth:`process_update`. It should *not* be called
|
||||
manually.
|
||||
|
||||
Args:
|
||||
update (:obj:`object`): The update to be processed.
|
||||
coroutine (:term:`Awaitable`): The coroutine that will be awaited to process the
|
||||
update.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def initialize(self) -> None:
|
||||
"""Initializes the processor so resources can be allocated. Must be implemented by a
|
||||
subclass.
|
||||
|
||||
.. seealso::
|
||||
:meth:`shutdown`
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the processor so resources can be freed. Must be implemented by a subclass.
|
||||
|
||||
.. seealso::
|
||||
:meth:`initialize`
|
||||
"""
|
||||
|
||||
async def process_update(
|
||||
self,
|
||||
update: object,
|
||||
coroutine: "Awaitable[Any]",
|
||||
) -> None:
|
||||
"""Calls :meth:`do_process_update` with a semaphore to limit the number of concurrent
|
||||
updates.
|
||||
|
||||
Args:
|
||||
update (:obj:`object`): The update to be processed.
|
||||
coroutine (:term:`Awaitable`): The coroutine that will be awaited to process the
|
||||
update.
|
||||
"""
|
||||
async with self._semaphore:
|
||||
await self.do_process_update(update, coroutine)
|
||||
|
||||
async def __aenter__(self) -> "BaseUpdateProcessor":
|
||||
"""Simple context manager which initializes the Processor."""
|
||||
try:
|
||||
await self.initialize()
|
||||
return self
|
||||
except Exception as exc:
|
||||
await self.shutdown()
|
||||
raise exc
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
"""Shutdown the Processor from the context manager."""
|
||||
await self.shutdown()
|
||||
|
||||
|
||||
class SimpleUpdateProcessor(BaseUpdateProcessor):
|
||||
"""Instance of :class:`telegram.ext.BaseUpdateProcessor` that immediately awaits the
|
||||
coroutine, i.e. does not apply any additional processing. This is used by default when
|
||||
:attr:`telegram.ext.ApplicationBuilder.concurrent_updates` is :obj:`int`.
|
||||
|
||||
.. versionadded:: NEXT.VERSION
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
async def do_process_update(
|
||||
self,
|
||||
update: object,
|
||||
coroutine: "Awaitable[Any]",
|
||||
) -> None:
|
||||
"""Immediately awaits the coroutine, i.e. does not apply any additional processing.
|
||||
|
||||
Args:
|
||||
update (:obj:`object`): The update to be processed.
|
||||
coroutine (:term:`Awaitable`): The coroutine that will be awaited to process the
|
||||
update.
|
||||
"""
|
||||
await coroutine
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Does nothing."""
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Does nothing."""
|
|
@ -49,6 +49,7 @@ from telegram.ext import (
|
|||
JobQueue,
|
||||
MessageHandler,
|
||||
PicklePersistence,
|
||||
SimpleUpdateProcessor,
|
||||
TypeHandler,
|
||||
Updater,
|
||||
filters,
|
||||
|
@ -134,7 +135,7 @@ class TestApplication:
|
|||
persistence=None,
|
||||
context_types=ContextTypes(),
|
||||
updater=updater,
|
||||
concurrent_updates=False,
|
||||
update_processor=False,
|
||||
post_init=None,
|
||||
post_shutdown=None,
|
||||
post_stop=None,
|
||||
|
@ -147,15 +148,13 @@ class TestApplication:
|
|||
assert recwarn[0].category is PTBUserWarning
|
||||
assert recwarn[0].filename == __file__, "stacklevel is incorrect!"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("concurrent_updates", "expected"), [(0, 0), (4, 4), (False, 0), (True, 256)]
|
||||
)
|
||||
@pytest.mark.filterwarnings("ignore: `Application` instances should")
|
||||
def test_init(self, one_time_bot, concurrent_updates, expected):
|
||||
def test_init(self, one_time_bot):
|
||||
update_queue = asyncio.Queue()
|
||||
job_queue = JobQueue()
|
||||
persistence = PicklePersistence("file_path")
|
||||
context_types = ContextTypes()
|
||||
update_processor = SimpleUpdateProcessor(1)
|
||||
updater = Updater(bot=one_time_bot, update_queue=update_queue)
|
||||
|
||||
async def post_init(application: Application) -> None:
|
||||
|
@ -174,7 +173,7 @@ class TestApplication:
|
|||
persistence=persistence,
|
||||
context_types=context_types,
|
||||
updater=updater,
|
||||
concurrent_updates=concurrent_updates,
|
||||
update_processor=update_processor,
|
||||
post_init=post_init,
|
||||
post_shutdown=post_shutdown,
|
||||
post_stop=post_stop,
|
||||
|
@ -187,7 +186,7 @@ class TestApplication:
|
|||
assert app.updater is updater
|
||||
assert app.update_queue is updater.update_queue
|
||||
assert app.bot is updater.bot
|
||||
assert app.concurrent_updates == expected
|
||||
assert app.update_processor is update_processor
|
||||
assert app.post_init is post_init
|
||||
assert app.post_shutdown is post_shutdown
|
||||
assert app.post_stop is post_stop
|
||||
|
@ -201,20 +200,6 @@ class TestApplication:
|
|||
assert isinstance(app.chat_data[1], dict)
|
||||
assert isinstance(app.user_data[1], dict)
|
||||
|
||||
with pytest.raises(ValueError, match="must be a non-negative"):
|
||||
Application(
|
||||
bot=one_time_bot,
|
||||
update_queue=update_queue,
|
||||
job_queue=job_queue,
|
||||
persistence=persistence,
|
||||
context_types=context_types,
|
||||
updater=updater,
|
||||
concurrent_updates=-1,
|
||||
post_init=None,
|
||||
post_shutdown=None,
|
||||
post_stop=None,
|
||||
)
|
||||
|
||||
def test_job_queue(self, one_time_bot, app, recwarn):
|
||||
expected_warning = (
|
||||
"No `JobQueue` set up. To use `JobQueue`, you must install PTB via "
|
||||
|
@ -250,23 +235,39 @@ class TestApplication:
|
|||
async def after_initialize_bot(*args, **kwargs):
|
||||
self.test_flag.add("bot")
|
||||
|
||||
async def after_initialize_update_processor(*args, **kwargs):
|
||||
self.test_flag.add("update_processor")
|
||||
|
||||
async def after_initialize_updater(*args, **kwargs):
|
||||
self.test_flag.add("updater")
|
||||
|
||||
update_processor = SimpleUpdateProcessor(1)
|
||||
monkeypatch.setattr(Bot, "initialize", call_after(Bot.initialize, after_initialize_bot))
|
||||
monkeypatch.setattr(
|
||||
SimpleUpdateProcessor,
|
||||
"initialize",
|
||||
call_after(SimpleUpdateProcessor.initialize, after_initialize_update_processor),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
Updater, "initialize", call_after(Updater.initialize, after_initialize_updater)
|
||||
)
|
||||
|
||||
if updater:
|
||||
app = ApplicationBuilder().bot(one_time_bot).build()
|
||||
app = (
|
||||
ApplicationBuilder().bot(one_time_bot).concurrent_updates(update_processor).build()
|
||||
)
|
||||
await app.initialize()
|
||||
assert self.test_flag == {"bot", "updater"}
|
||||
assert self.test_flag == {"bot", "update_processor", "updater"}
|
||||
await app.shutdown()
|
||||
else:
|
||||
app = ApplicationBuilder().bot(one_time_bot).updater(None).build()
|
||||
app = (
|
||||
ApplicationBuilder()
|
||||
.bot(one_time_bot)
|
||||
.updater(None)
|
||||
.concurrent_updates(update_processor)
|
||||
.build()
|
||||
)
|
||||
await app.initialize()
|
||||
assert self.test_flag == {"bot"}
|
||||
assert self.test_flag == {"bot", "update_processor"}
|
||||
await app.shutdown()
|
||||
|
||||
@pytest.mark.parametrize("updater", [True, False])
|
||||
|
@ -277,22 +278,35 @@ class TestApplication:
|
|||
def after_bot_shutdown(*args, **kwargs):
|
||||
self.test_flag.add("bot")
|
||||
|
||||
def after_shutdown_update_processor(*args, **kwargs):
|
||||
self.test_flag.add("update_processor")
|
||||
|
||||
def after_updater_shutdown(*args, **kwargs):
|
||||
self.test_flag.add("updater")
|
||||
|
||||
update_processor = SimpleUpdateProcessor(1)
|
||||
monkeypatch.setattr(Bot, "shutdown", call_after(Bot.shutdown, after_bot_shutdown))
|
||||
monkeypatch.setattr(
|
||||
SimpleUpdateProcessor,
|
||||
"shutdown",
|
||||
call_after(SimpleUpdateProcessor.shutdown, after_shutdown_update_processor),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
Updater, "shutdown", call_after(Updater.shutdown, after_updater_shutdown)
|
||||
)
|
||||
|
||||
if updater:
|
||||
async with ApplicationBuilder().bot(one_time_bot).build():
|
||||
async with ApplicationBuilder().bot(one_time_bot).concurrent_updates(
|
||||
update_processor
|
||||
).build():
|
||||
pass
|
||||
assert self.test_flag == {"bot", "updater"}
|
||||
assert self.test_flag == {"bot", "update_processor", "updater"}
|
||||
else:
|
||||
async with ApplicationBuilder().bot(one_time_bot).updater(None).build():
|
||||
async with ApplicationBuilder().bot(one_time_bot).updater(None).concurrent_updates(
|
||||
update_processor
|
||||
).build():
|
||||
pass
|
||||
assert self.test_flag == {"bot"}
|
||||
assert self.test_flag == {"bot", "update_processor"}
|
||||
|
||||
async def test_multiple_inits_and_shutdowns(self, app, monkeypatch):
|
||||
self.received = defaultdict(int)
|
||||
|
@ -1309,7 +1323,7 @@ class TestApplication:
|
|||
await app.create_task(gen())
|
||||
assert event.is_set()
|
||||
|
||||
async def test_no_concurrent_updates(self, app):
|
||||
async def test_no_update_processor(self, app):
|
||||
queue = asyncio.Queue()
|
||||
event_1 = asyncio.Event()
|
||||
event_2 = asyncio.Event()
|
||||
|
@ -1337,14 +1351,14 @@ class TestApplication:
|
|||
|
||||
await app.stop()
|
||||
|
||||
@pytest.mark.parametrize("concurrent_updates", [15, 50, 100])
|
||||
async def test_concurrent_updates(self, one_time_bot, concurrent_updates):
|
||||
@pytest.mark.parametrize("update_processor", [15, 50, 100])
|
||||
async def test_update_processor(self, one_time_bot, update_processor):
|
||||
# We don't test with `True` since the large number of parallel coroutines quickly leads
|
||||
# to test instabilities
|
||||
app = (
|
||||
Application.builder().bot(one_time_bot).concurrent_updates(concurrent_updates).build()
|
||||
)
|
||||
events = {i: asyncio.Event() for i in range(app.concurrent_updates + 10)}
|
||||
app = Application.builder().bot(one_time_bot).concurrent_updates(update_processor).build()
|
||||
events = {
|
||||
i: asyncio.Event() for i in range(app.update_processor.max_concurrent_updates + 10)
|
||||
}
|
||||
queue = asyncio.Queue()
|
||||
for event in events.values():
|
||||
await queue.put(event)
|
||||
|
@ -1356,25 +1370,28 @@ class TestApplication:
|
|||
app.add_handler(TypeHandler(object, callback))
|
||||
async with app:
|
||||
await app.start()
|
||||
for i in range(app.concurrent_updates + 10):
|
||||
for i in range(app.update_processor.max_concurrent_updates + 10):
|
||||
await app.update_queue.put(i)
|
||||
|
||||
for i in range(app.concurrent_updates + 10):
|
||||
for i in range(app.update_processor.max_concurrent_updates + 10):
|
||||
assert not events[i].is_set()
|
||||
|
||||
await asyncio.sleep(0.9)
|
||||
for i in range(app.concurrent_updates):
|
||||
for i in range(app.update_processor.max_concurrent_updates):
|
||||
assert events[i].is_set()
|
||||
for i in range(app.concurrent_updates, app.concurrent_updates + 10):
|
||||
for i in range(
|
||||
app.update_processor.max_concurrent_updates,
|
||||
app.update_processor.max_concurrent_updates + 10,
|
||||
):
|
||||
assert not events[i].is_set()
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
for i in range(app.concurrent_updates + 10):
|
||||
for i in range(app.update_processor.max_concurrent_updates + 10):
|
||||
assert events[i].is_set()
|
||||
|
||||
await app.stop()
|
||||
|
||||
async def test_concurrent_updates_done_on_shutdown(self, one_time_bot):
|
||||
async def test_update_processor_done_on_shutdown(self, one_time_bot):
|
||||
app = Application.builder().bot(one_time_bot).concurrent_updates(True).build()
|
||||
event = asyncio.Event()
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ from telegram.ext import (
|
|||
Updater,
|
||||
)
|
||||
from telegram.ext._applicationbuilder import _BOT_CHECKS
|
||||
from telegram.ext._baseupdateprocessor import SimpleUpdateProcessor
|
||||
from telegram.request import HTTPXRequest
|
||||
from tests.auxil.constants import PRIVATE_KEY
|
||||
from tests.auxil.envvars import TEST_WITH_OPT_DEPS
|
||||
|
@ -96,7 +97,8 @@ class TestApplicationBuilder:
|
|||
app = builder.token(bot.token).build()
|
||||
|
||||
assert isinstance(app, Application)
|
||||
assert app.concurrent_updates == 0
|
||||
assert isinstance(app.update_processor, SimpleUpdateProcessor)
|
||||
assert app.update_processor.max_concurrent_updates == 1
|
||||
|
||||
assert isinstance(app.bot, ExtBot)
|
||||
assert isinstance(app.bot.request, HTTPXRequest)
|
||||
|
@ -367,12 +369,21 @@ class TestApplicationBuilder:
|
|||
assert isinstance(app, CustomApplication)
|
||||
assert app.arg == 2
|
||||
|
||||
def test_all_application_args_custom(self, builder, bot, monkeypatch):
|
||||
@pytest.mark.parametrize(
|
||||
("concurrent_updates", "expected"),
|
||||
[
|
||||
(4, SimpleUpdateProcessor(4)),
|
||||
(False, SimpleUpdateProcessor(1)),
|
||||
(True, SimpleUpdateProcessor(256)),
|
||||
],
|
||||
)
|
||||
def test_all_application_args_custom(
|
||||
self, builder, bot, monkeypatch, concurrent_updates, expected
|
||||
):
|
||||
job_queue = JobQueue()
|
||||
persistence = PicklePersistence("file_path")
|
||||
update_queue = asyncio.Queue()
|
||||
context_types = ContextTypes()
|
||||
concurrent_updates = 123
|
||||
|
||||
async def post_init(app: Application) -> None:
|
||||
pass
|
||||
|
@ -395,6 +406,7 @@ class TestApplicationBuilder:
|
|||
.post_stop(post_stop)
|
||||
.arbitrary_callback_data(True)
|
||||
).build()
|
||||
|
||||
assert app.job_queue is job_queue
|
||||
assert app.job_queue.application is app
|
||||
assert app.persistence is persistence
|
||||
|
@ -403,7 +415,9 @@ class TestApplicationBuilder:
|
|||
assert app.updater.update_queue is update_queue
|
||||
assert app.updater.bot is app.bot
|
||||
assert app.context_types is context_types
|
||||
assert app.concurrent_updates == concurrent_updates
|
||||
assert isinstance(app.update_processor, SimpleUpdateProcessor)
|
||||
assert app.update_processor.max_concurrent_updates == expected.max_concurrent_updates
|
||||
assert app.concurrent_updates == app.update_processor.max_concurrent_updates
|
||||
assert app.post_init is post_init
|
||||
assert app.post_shutdown is post_shutdown
|
||||
assert app.post_stop is post_stop
|
||||
|
@ -414,6 +428,19 @@ class TestApplicationBuilder:
|
|||
assert app.updater is updater
|
||||
assert app.bot is updater.bot
|
||||
assert app.update_queue is updater.update_queue
|
||||
app = (
|
||||
builder.token(bot.token)
|
||||
.job_queue(job_queue)
|
||||
.persistence(persistence)
|
||||
.update_queue(update_queue)
|
||||
.context_types(context_types)
|
||||
.concurrent_updates(expected)
|
||||
.post_init(post_init)
|
||||
.post_shutdown(post_shutdown)
|
||||
.post_stop(post_stop)
|
||||
.arbitrary_callback_data(True)
|
||||
).build()
|
||||
assert app.update_processor is expected
|
||||
|
||||
@pytest.mark.parametrize("input_type", ["bytes", "str", "Path"])
|
||||
def test_all_private_key_input_types(self, builder, bot, input_type):
|
||||
|
|
166
tests/ext/test_baseupdateprocessor.py
Normal file
166
tests/ext/test_baseupdateprocessor.py
Normal file
|
@ -0,0 +1,166 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# A library that provides a Python interface to the Telegram Bot API
|
||||
# Copyright (C) 2015-2023
|
||||
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
"""Here we run tests directly with SimpleUpdateProcessor because that's easier than providing dummy
|
||||
implementations for SimpleUpdateProcessor and we want to test SimpleUpdateProcessor anyway."""
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import SimpleUpdateProcessor
|
||||
from tests.auxil.asyncio_helpers import call_after
|
||||
from tests.auxil.slots import mro_slots
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_processor():
|
||||
class MockProcessor(SimpleUpdateProcessor):
|
||||
test_flag = False
|
||||
|
||||
async def do_process_update(self, update, coroutine):
|
||||
await coroutine
|
||||
self.test_flag = True
|
||||
|
||||
return MockProcessor(5)
|
||||
|
||||
|
||||
class TestSimpleUpdateProcessor:
|
||||
def test_slot_behaviour(self):
|
||||
inst = SimpleUpdateProcessor(1)
|
||||
for attr in inst.__slots__:
|
||||
assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'"
|
||||
assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot"
|
||||
|
||||
@pytest.mark.parametrize("concurrent_updates", [0, -1])
|
||||
def test_init(self, concurrent_updates):
|
||||
processor = SimpleUpdateProcessor(3)
|
||||
assert processor.max_concurrent_updates == 3
|
||||
with pytest.raises(ValueError, match="must be a positive integer"):
|
||||
SimpleUpdateProcessor(concurrent_updates)
|
||||
|
||||
async def test_process_update(self, mock_processor):
|
||||
"""Test that process_update calls do_process_update."""
|
||||
update = Update(1)
|
||||
|
||||
async def coroutine():
|
||||
pass
|
||||
|
||||
await mock_processor.process_update(update, coroutine())
|
||||
# This flag is set in the mock processor in do_process_update, telling us that
|
||||
# do_process_update was called.
|
||||
assert mock_processor.test_flag
|
||||
|
||||
async def test_do_process_update(self):
|
||||
"""Test that do_process_update calls the coroutine."""
|
||||
processor = SimpleUpdateProcessor(1)
|
||||
update = Update(1)
|
||||
test_flag = False
|
||||
|
||||
async def coroutine():
|
||||
nonlocal test_flag
|
||||
test_flag = True
|
||||
|
||||
await processor.do_process_update(update, coroutine())
|
||||
assert test_flag
|
||||
|
||||
async def test_max_concurrent_updates_enforcement(self, mock_processor):
|
||||
"""Test that max_concurrent_updates is enforced, i.e. that the processor will run
|
||||
at most max_concurrent_updates coroutines at the same time."""
|
||||
count = 2 * mock_processor.max_concurrent_updates
|
||||
events = {i: asyncio.Event() for i in range(count)}
|
||||
queue = asyncio.Queue()
|
||||
for event in events.values():
|
||||
await queue.put(event)
|
||||
|
||||
async def callback():
|
||||
await asyncio.sleep(0.5)
|
||||
(await queue.get()).set()
|
||||
|
||||
# We start several calls to `process_update` at the same time, each of them taking
|
||||
# 0.5 seconds to complete. We know that they are completed when the corresponding
|
||||
# event is set.
|
||||
tasks = [
|
||||
asyncio.create_task(mock_processor.process_update(update=_, coroutine=callback()))
|
||||
for _ in range(count)
|
||||
]
|
||||
|
||||
# Right now we expect no event to be set
|
||||
for i in range(count):
|
||||
assert not events[i].is_set()
|
||||
|
||||
# After 0.5 seconds (+ some buffer), we expect that exactly max_concurrent_updates
|
||||
# events are set.
|
||||
await asyncio.sleep(0.75)
|
||||
for i in range(mock_processor.max_concurrent_updates):
|
||||
assert events[i].is_set()
|
||||
for i in range(
|
||||
mock_processor.max_concurrent_updates,
|
||||
count,
|
||||
):
|
||||
assert not events[i].is_set()
|
||||
|
||||
# After wating another 0.5 seconds, we expect that the next max_concurrent_updates
|
||||
# events are set.
|
||||
await asyncio.sleep(0.5)
|
||||
for i in range(count):
|
||||
assert events[i].is_set()
|
||||
|
||||
# Sanity check: we expect that all tasks are completed.
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def test_context_manager(self, monkeypatch, mock_processor):
|
||||
self.test_flag = set()
|
||||
|
||||
async def after_initialize(*args, **kwargs):
|
||||
self.test_flag.add("initialize")
|
||||
|
||||
async def after_shutdown(*args, **kwargs):
|
||||
self.test_flag.add("stop")
|
||||
|
||||
monkeypatch.setattr(
|
||||
SimpleUpdateProcessor,
|
||||
"initialize",
|
||||
call_after(SimpleUpdateProcessor.initialize, after_initialize),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
SimpleUpdateProcessor,
|
||||
"shutdown",
|
||||
call_after(SimpleUpdateProcessor.shutdown, after_shutdown),
|
||||
)
|
||||
|
||||
async with mock_processor:
|
||||
pass
|
||||
|
||||
assert self.test_flag == {"initialize", "stop"}
|
||||
|
||||
async def test_context_manager_exception_on_init(self, monkeypatch, mock_processor):
|
||||
async def initialize(*args, **kwargs):
|
||||
raise RuntimeError("initialize")
|
||||
|
||||
async def shutdown(*args, **kwargs):
|
||||
self.test_flag = "shutdown"
|
||||
|
||||
monkeypatch.setattr(SimpleUpdateProcessor, "initialize", initialize)
|
||||
monkeypatch.setattr(SimpleUpdateProcessor, "shutdown", shutdown)
|
||||
|
||||
with pytest.raises(RuntimeError, match="initialize"):
|
||||
async with mock_processor:
|
||||
pass
|
||||
|
||||
assert self.test_flag == "shutdown"
|
Loading…
Add table
Reference in a new issue