Introduce BaseUpdateProcessor for Customized Concurrent Handling of Updates (#3654)

Co-authored-by: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com>
This commit is contained in:
Aditya Yadav 2023-06-02 21:47:08 +05:30 committed by GitHub
parent 4c8d7332db
commit bf54599618
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 496 additions and 72 deletions

View file

@ -0,0 +1,6 @@
BaseUpdateProcessor
===================
.. autoclass:: telegram.ext.BaseUpdateProcessor
:members:
:show-inheritance:

View file

@ -9,12 +9,14 @@ telegram.ext package
telegram.ext.application
telegram.ext.applicationbuilder
telegram.ext.applicationhandlerstop
telegram.ext.baseupdateprocessor
telegram.ext.callbackcontext
telegram.ext.contexttypes
telegram.ext.defaults
telegram.ext.extbot
telegram.ext.job
telegram.ext.jobqueue
telegram.ext.simpleupdateprocessor
telegram.ext.updater
telegram.ext.handlers-tree.rst
telegram.ext.persistence-tree.rst

View file

@ -0,0 +1,6 @@
SimpleUpdateProcessor
=====================
.. autoclass:: telegram.ext.SimpleUpdateProcessor
:members:
:show-inheritance:

View file

@ -26,6 +26,7 @@ __all__ = (
"BaseHandler",
"BasePersistence",
"BaseRateLimiter",
"BaseUpdateProcessor",
"CallbackContext",
"CallbackDataCache",
"CallbackQueryHandler",
@ -51,6 +52,7 @@ __all__ = (
"PreCheckoutQueryHandler",
"PrefixHandler",
"ShippingQueryHandler",
"SimpleUpdateProcessor",
"StringCommandHandler",
"StringRegexHandler",
"TypeHandler",
@ -63,6 +65,7 @@ from ._application import Application, ApplicationHandlerStop
from ._applicationbuilder import ApplicationBuilder
from ._basepersistence import BasePersistence, PersistenceInput
from ._baseratelimiter import BaseRateLimiter
from ._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor
from ._callbackcontext import CallbackContext
from ._callbackdatacache import CallbackDataCache, InvalidCallbackData
from ._callbackqueryhandler import CallbackQueryHandler

View file

@ -57,6 +57,7 @@ from telegram._utils.types import SCT, DVType, ODVInput
from telegram._utils.warnings import warn
from telegram.error import TelegramError
from telegram.ext._basepersistence import BasePersistence
from telegram.ext._baseupdateprocessor import BaseUpdateProcessor
from telegram.ext._contexttypes import ContextTypes
from telegram.ext._extbot import ExtBot
from telegram.ext._handler import BaseHandler
@ -228,12 +229,11 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
"_chat_data",
"_chat_ids_to_be_deleted_in_persistence",
"_chat_ids_to_be_updated_in_persistence",
"_concurrent_updates",
"_concurrent_updates_sem",
"_conversation_handler_conversations",
"_initialized",
"_job_queue",
"_running",
"_update_processor",
"_user_data",
"_user_ids_to_be_deleted_in_persistence",
"_user_ids_to_be_updated_in_persistence",
@ -259,7 +259,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
update_queue: "asyncio.Queue[object]",
updater: Optional[Updater],
job_queue: JQ,
concurrent_updates: Union[bool, int],
update_processor: "BaseUpdateProcessor",
persistence: Optional[BasePersistence[UD, CD, BD]],
context_types: ContextTypes[CCT, UD, CD, BD],
post_init: Optional[
@ -297,14 +297,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
self.post_stop: Optional[
Callable[["Application[BT, CCT, UD, CD, BD, JQ]"], Coroutine[Any, Any, None]]
] = post_stop
if isinstance(concurrent_updates, int) and concurrent_updates < 0:
raise ValueError("`concurrent_updates` must be a non-negative integer!")
if concurrent_updates is True:
concurrent_updates = 256
self._concurrent_updates_sem = asyncio.BoundedSemaphore(concurrent_updates or 1)
self._concurrent_updates: int = concurrent_updates or 0
self._update_processor = update_processor
self.bot_data: BD = self.context_types.bot_data()
self._user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data)
self._chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data)
@ -359,9 +352,13 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
""":obj:`int`: The number of concurrent updates that will be processed in parallel. A
value of ``0`` indicates updates are *not* being processed concurrently.
.. versionchanged:: NEXT.VERSION
This is now just a shortcut to :attr:`update_processor.max_concurrent_updates
<telegram.ext.BaseUpdateProcessor.max_concurrent_updates>`.
.. seealso:: :wiki:`Concurrency`
"""
return self._concurrent_updates
return self._update_processor.max_concurrent_updates
@property
def job_queue(self) -> Optional["JobQueue[CCT]"]:
@ -379,12 +376,25 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
)
return self._job_queue
@property
def update_processor(self) -> "BaseUpdateProcessor":
""":class:`telegram.ext.BaseUpdateProcessor`: The update processor used by this
application.
.. seealso:: :wiki:`Concurrency`
.. versionadded:: NEXT.VERSION
"""
return self._update_processor
async def initialize(self) -> None:
"""Initializes the Application by initializing:
* The :attr:`bot`, by calling :meth:`telegram.Bot.initialize`.
* The :attr:`updater`, by calling :meth:`telegram.ext.Updater.initialize`.
* The :attr:`persistence`, by loading persistent conversations and data.
* The :attr:`update_processor` by calling
:meth:`telegram.ext.BaseUpdateProcessor.initialize`.
Does *not* call :attr:`post_init` - that is only done by :meth:`run_polling` and
:meth:`run_webhook`.
@ -397,6 +407,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
return
await self.bot.initialize()
await self._update_processor.initialize()
if self.updater:
await self.updater.initialize()
@ -429,6 +441,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
* :attr:`updater` by calling :meth:`telegram.ext.Updater.shutdown`
* :attr:`persistence` by calling :meth:`update_persistence` and
:meth:`BasePersistence.flush`
* :attr:`update_processor` by calling :meth:`telegram.ext.BaseUpdateProcessor.shutdown`
Does *not* call :attr:`post_shutdown` - that is only done by :meth:`run_polling` and
:meth:`run_webhook`.
@ -447,6 +460,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
return
await self.bot.shutdown()
await self._update_processor.shutdown()
if self.updater:
await self.updater.shutdown()
@ -1060,11 +1075,15 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
_LOGGER.debug("Processing update %s", update)
if self._concurrent_updates:
if self._update_processor.max_concurrent_updates > 1:
# We don't await the below because it has to be run concurrently
self.create_task(self.__process_update_wrapper(update), update=update)
self.create_task(
self.__process_update_wrapper(update),
update=update,
)
else:
await self.__process_update_wrapper(update)
except asyncio.CancelledError:
# This may happen if the application is manually run via application.start() and
# then a KeyboardInterrupt is sent. We must prevent this loop to die since
@ -1075,9 +1094,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
)
async def __process_update_wrapper(self, update: object) -> None:
async with self._concurrent_updates_sem:
await self.process_update(update)
self.update_queue.task_done()
await self._update_processor.process_update(update, self.process_update(update))
self.update_queue.task_done()
async def process_update(self, update: object) -> None:
"""Processes a single update and marks the update to be updated by the persistence later.

View file

@ -36,6 +36,7 @@ from telegram._bot import Bot
from telegram._utils.defaultvalue import DEFAULT_FALSE, DEFAULT_NONE, DefaultValue
from telegram._utils.types import DVInput, DVType, FilePathInput, ODVInput
from telegram.ext._application import Application
from telegram.ext._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor
from telegram.ext._contexttypes import ContextTypes
from telegram.ext._extbot import ExtBot
from telegram.ext._jobqueue import JobQueue
@ -127,7 +128,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_base_file_url",
"_base_url",
"_bot",
"_concurrent_updates",
"_update_processor",
"_connect_timeout",
"_connection_pool_size",
"_context_types",
@ -198,7 +199,9 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
self._context_types: DVType[ContextTypes] = DefaultValue(ContextTypes())
self._application_class: DVType[Type[Application]] = DefaultValue(Application)
self._application_kwargs: Dict[str, object] = {}
self._concurrent_updates: Union[int, DefaultValue[bool]] = DEFAULT_FALSE
self._update_processor: "BaseUpdateProcessor" = SimpleUpdateProcessor(
max_concurrent_updates=1
)
self._updater: ODVInput[Updater] = DEFAULT_NONE
self._post_init: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None
self._post_shutdown: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None
@ -306,7 +309,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
bot=bot,
update_queue=update_queue,
updater=updater,
concurrent_updates=DefaultValue.get_value(self._concurrent_updates),
update_processor=self._update_processor,
job_queue=job_queue,
persistence=persistence,
context_types=DefaultValue.get_value(self._context_types),
@ -902,7 +905,9 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
self._update_queue = update_queue
return self
def concurrent_updates(self: BuilderType, concurrent_updates: Union[bool, int]) -> BuilderType:
def concurrent_updates(
self: BuilderType, concurrent_updates: Union[bool, int, "BaseUpdateProcessor"]
) -> BuilderType:
"""Specifies if and how many updates may be processed concurrently instead of one by one.
If not called, updates will be processed one by one.
@ -917,14 +922,34 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
.. seealso:: :attr:`telegram.ext.Application.concurrent_updates`
Args:
concurrent_updates (:obj:`bool` | :obj:`int`): Passing :obj:`True` will allow for
``256`` updates to be processed concurrently. Pass an integer to specify a
different number of updates that may be processed concurrently.
concurrent_updates (:obj:`bool` | :obj:`int` | :class:`BaseUpdateProcessor`): Passing
:obj:`True` will allow for ``256`` updates to be processed concurrently using
:class:`telegram.ext.SimpleUpdateProcessor`. Pass an integer to specify a different
number of updates that may be processed concurrently. Pass an instance of
:class:`telegram.ext.BaseUpdateProcessor` to use that instance for handling updates
concurrently.
.. versionchanged:: NEXT.VERSION
Now accepts :class:`BaseUpdateProcessor` instances.
Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._concurrent_updates = concurrent_updates
# Check if concurrent updates is bool and convert to integer
if concurrent_updates is True:
concurrent_updates = 256
elif concurrent_updates is False:
concurrent_updates = 1
# If `concurrent_updates` is an integer, create a `SimpleUpdateProcessor`
# instance with that integer value; otherwise, raise an error if the value
# is negative
if isinstance(concurrent_updates, int):
concurrent_updates = SimpleUpdateProcessor(concurrent_updates)
# Assign default value of concurrent_updates if it is instance of
# `BaseUpdateProcessor`
self._update_processor: BaseUpdateProcessor = concurrent_updates # type: ignore[no-redef]
return self
def job_queue(

View file

@ -0,0 +1,154 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2023
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the BaseProcessor class."""
from abc import ABC, abstractmethod
from asyncio import BoundedSemaphore
from types import TracebackType
from typing import Any, Awaitable, Optional, Type
class BaseUpdateProcessor(ABC):
"""An abstract base class for update processors. You can use this class to implement
your own update processor.
.. seealso:: :wiki:`Concurrency`
.. versionadded:: NEXT.VERSION
Args:
max_concurrent_updates (:obj:`int`): The maximum number of updates to be processed
concurrently. If this number is exceeded, new updates will be queued until the number
of currently processed updates decreases.
Raises:
:exc:`ValueError`: If :paramref:`max_concurrent_updates` is a non-positive integer.
"""
__slots__ = ("_max_concurrent_updates", "_semaphore")
def __init__(self, max_concurrent_updates: int):
self._max_concurrent_updates = max_concurrent_updates
if self.max_concurrent_updates < 1:
raise ValueError("`max_concurrent_updates` must be a positive integer!")
self._semaphore = BoundedSemaphore(self.max_concurrent_updates)
@property
def max_concurrent_updates(self) -> int:
""":obj:`int`: The maximum number of updates that can be processed concurrently."""
return self._max_concurrent_updates
@abstractmethod
async def do_process_update(
self,
update: object,
coroutine: "Awaitable[Any]",
) -> None:
"""Custom implementation of how to process an update. Must be implemented by a subclass.
Warning:
This method will be called by :meth:`process_update`. It should *not* be called
manually.
Args:
update (:obj:`object`): The update to be processed.
coroutine (:term:`Awaitable`): The coroutine that will be awaited to process the
update.
"""
@abstractmethod
async def initialize(self) -> None:
"""Initializes the processor so resources can be allocated. Must be implemented by a
subclass.
.. seealso::
:meth:`shutdown`
"""
@abstractmethod
async def shutdown(self) -> None:
"""Shutdown the processor so resources can be freed. Must be implemented by a subclass.
.. seealso::
:meth:`initialize`
"""
async def process_update(
self,
update: object,
coroutine: "Awaitable[Any]",
) -> None:
"""Calls :meth:`do_process_update` with a semaphore to limit the number of concurrent
updates.
Args:
update (:obj:`object`): The update to be processed.
coroutine (:term:`Awaitable`): The coroutine that will be awaited to process the
update.
"""
async with self._semaphore:
await self.do_process_update(update, coroutine)
async def __aenter__(self) -> "BaseUpdateProcessor":
"""Simple context manager which initializes the Processor."""
try:
await self.initialize()
return self
except Exception as exc:
await self.shutdown()
raise exc
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Shutdown the Processor from the context manager."""
await self.shutdown()
class SimpleUpdateProcessor(BaseUpdateProcessor):
"""Instance of :class:`telegram.ext.BaseUpdateProcessor` that immediately awaits the
coroutine, i.e. does not apply any additional processing. This is used by default when
:attr:`telegram.ext.ApplicationBuilder.concurrent_updates` is :obj:`int`.
.. versionadded:: NEXT.VERSION
"""
__slots__ = ()
async def do_process_update(
self,
update: object,
coroutine: "Awaitable[Any]",
) -> None:
"""Immediately awaits the coroutine, i.e. does not apply any additional processing.
Args:
update (:obj:`object`): The update to be processed.
coroutine (:term:`Awaitable`): The coroutine that will be awaited to process the
update.
"""
await coroutine
async def initialize(self) -> None:
"""Does nothing."""
async def shutdown(self) -> None:
"""Does nothing."""

View file

@ -49,6 +49,7 @@ from telegram.ext import (
JobQueue,
MessageHandler,
PicklePersistence,
SimpleUpdateProcessor,
TypeHandler,
Updater,
filters,
@ -134,7 +135,7 @@ class TestApplication:
persistence=None,
context_types=ContextTypes(),
updater=updater,
concurrent_updates=False,
update_processor=False,
post_init=None,
post_shutdown=None,
post_stop=None,
@ -147,15 +148,13 @@ class TestApplication:
assert recwarn[0].category is PTBUserWarning
assert recwarn[0].filename == __file__, "stacklevel is incorrect!"
@pytest.mark.parametrize(
("concurrent_updates", "expected"), [(0, 0), (4, 4), (False, 0), (True, 256)]
)
@pytest.mark.filterwarnings("ignore: `Application` instances should")
def test_init(self, one_time_bot, concurrent_updates, expected):
def test_init(self, one_time_bot):
update_queue = asyncio.Queue()
job_queue = JobQueue()
persistence = PicklePersistence("file_path")
context_types = ContextTypes()
update_processor = SimpleUpdateProcessor(1)
updater = Updater(bot=one_time_bot, update_queue=update_queue)
async def post_init(application: Application) -> None:
@ -174,7 +173,7 @@ class TestApplication:
persistence=persistence,
context_types=context_types,
updater=updater,
concurrent_updates=concurrent_updates,
update_processor=update_processor,
post_init=post_init,
post_shutdown=post_shutdown,
post_stop=post_stop,
@ -187,7 +186,7 @@ class TestApplication:
assert app.updater is updater
assert app.update_queue is updater.update_queue
assert app.bot is updater.bot
assert app.concurrent_updates == expected
assert app.update_processor is update_processor
assert app.post_init is post_init
assert app.post_shutdown is post_shutdown
assert app.post_stop is post_stop
@ -201,20 +200,6 @@ class TestApplication:
assert isinstance(app.chat_data[1], dict)
assert isinstance(app.user_data[1], dict)
with pytest.raises(ValueError, match="must be a non-negative"):
Application(
bot=one_time_bot,
update_queue=update_queue,
job_queue=job_queue,
persistence=persistence,
context_types=context_types,
updater=updater,
concurrent_updates=-1,
post_init=None,
post_shutdown=None,
post_stop=None,
)
def test_job_queue(self, one_time_bot, app, recwarn):
expected_warning = (
"No `JobQueue` set up. To use `JobQueue`, you must install PTB via "
@ -250,23 +235,39 @@ class TestApplication:
async def after_initialize_bot(*args, **kwargs):
self.test_flag.add("bot")
async def after_initialize_update_processor(*args, **kwargs):
self.test_flag.add("update_processor")
async def after_initialize_updater(*args, **kwargs):
self.test_flag.add("updater")
update_processor = SimpleUpdateProcessor(1)
monkeypatch.setattr(Bot, "initialize", call_after(Bot.initialize, after_initialize_bot))
monkeypatch.setattr(
SimpleUpdateProcessor,
"initialize",
call_after(SimpleUpdateProcessor.initialize, after_initialize_update_processor),
)
monkeypatch.setattr(
Updater, "initialize", call_after(Updater.initialize, after_initialize_updater)
)
if updater:
app = ApplicationBuilder().bot(one_time_bot).build()
app = (
ApplicationBuilder().bot(one_time_bot).concurrent_updates(update_processor).build()
)
await app.initialize()
assert self.test_flag == {"bot", "updater"}
assert self.test_flag == {"bot", "update_processor", "updater"}
await app.shutdown()
else:
app = ApplicationBuilder().bot(one_time_bot).updater(None).build()
app = (
ApplicationBuilder()
.bot(one_time_bot)
.updater(None)
.concurrent_updates(update_processor)
.build()
)
await app.initialize()
assert self.test_flag == {"bot"}
assert self.test_flag == {"bot", "update_processor"}
await app.shutdown()
@pytest.mark.parametrize("updater", [True, False])
@ -277,22 +278,35 @@ class TestApplication:
def after_bot_shutdown(*args, **kwargs):
self.test_flag.add("bot")
def after_shutdown_update_processor(*args, **kwargs):
self.test_flag.add("update_processor")
def after_updater_shutdown(*args, **kwargs):
self.test_flag.add("updater")
update_processor = SimpleUpdateProcessor(1)
monkeypatch.setattr(Bot, "shutdown", call_after(Bot.shutdown, after_bot_shutdown))
monkeypatch.setattr(
SimpleUpdateProcessor,
"shutdown",
call_after(SimpleUpdateProcessor.shutdown, after_shutdown_update_processor),
)
monkeypatch.setattr(
Updater, "shutdown", call_after(Updater.shutdown, after_updater_shutdown)
)
if updater:
async with ApplicationBuilder().bot(one_time_bot).build():
async with ApplicationBuilder().bot(one_time_bot).concurrent_updates(
update_processor
).build():
pass
assert self.test_flag == {"bot", "updater"}
assert self.test_flag == {"bot", "update_processor", "updater"}
else:
async with ApplicationBuilder().bot(one_time_bot).updater(None).build():
async with ApplicationBuilder().bot(one_time_bot).updater(None).concurrent_updates(
update_processor
).build():
pass
assert self.test_flag == {"bot"}
assert self.test_flag == {"bot", "update_processor"}
async def test_multiple_inits_and_shutdowns(self, app, monkeypatch):
self.received = defaultdict(int)
@ -1309,7 +1323,7 @@ class TestApplication:
await app.create_task(gen())
assert event.is_set()
async def test_no_concurrent_updates(self, app):
async def test_no_update_processor(self, app):
queue = asyncio.Queue()
event_1 = asyncio.Event()
event_2 = asyncio.Event()
@ -1337,14 +1351,14 @@ class TestApplication:
await app.stop()
@pytest.mark.parametrize("concurrent_updates", [15, 50, 100])
async def test_concurrent_updates(self, one_time_bot, concurrent_updates):
@pytest.mark.parametrize("update_processor", [15, 50, 100])
async def test_update_processor(self, one_time_bot, update_processor):
# We don't test with `True` since the large number of parallel coroutines quickly leads
# to test instabilities
app = (
Application.builder().bot(one_time_bot).concurrent_updates(concurrent_updates).build()
)
events = {i: asyncio.Event() for i in range(app.concurrent_updates + 10)}
app = Application.builder().bot(one_time_bot).concurrent_updates(update_processor).build()
events = {
i: asyncio.Event() for i in range(app.update_processor.max_concurrent_updates + 10)
}
queue = asyncio.Queue()
for event in events.values():
await queue.put(event)
@ -1356,25 +1370,28 @@ class TestApplication:
app.add_handler(TypeHandler(object, callback))
async with app:
await app.start()
for i in range(app.concurrent_updates + 10):
for i in range(app.update_processor.max_concurrent_updates + 10):
await app.update_queue.put(i)
for i in range(app.concurrent_updates + 10):
for i in range(app.update_processor.max_concurrent_updates + 10):
assert not events[i].is_set()
await asyncio.sleep(0.9)
for i in range(app.concurrent_updates):
for i in range(app.update_processor.max_concurrent_updates):
assert events[i].is_set()
for i in range(app.concurrent_updates, app.concurrent_updates + 10):
for i in range(
app.update_processor.max_concurrent_updates,
app.update_processor.max_concurrent_updates + 10,
):
assert not events[i].is_set()
await asyncio.sleep(0.5)
for i in range(app.concurrent_updates + 10):
for i in range(app.update_processor.max_concurrent_updates + 10):
assert events[i].is_set()
await app.stop()
async def test_concurrent_updates_done_on_shutdown(self, one_time_bot):
async def test_update_processor_done_on_shutdown(self, one_time_bot):
app = Application.builder().bot(one_time_bot).concurrent_updates(True).build()
event = asyncio.Event()

View file

@ -35,6 +35,7 @@ from telegram.ext import (
Updater,
)
from telegram.ext._applicationbuilder import _BOT_CHECKS
from telegram.ext._baseupdateprocessor import SimpleUpdateProcessor
from telegram.request import HTTPXRequest
from tests.auxil.constants import PRIVATE_KEY
from tests.auxil.envvars import TEST_WITH_OPT_DEPS
@ -96,7 +97,8 @@ class TestApplicationBuilder:
app = builder.token(bot.token).build()
assert isinstance(app, Application)
assert app.concurrent_updates == 0
assert isinstance(app.update_processor, SimpleUpdateProcessor)
assert app.update_processor.max_concurrent_updates == 1
assert isinstance(app.bot, ExtBot)
assert isinstance(app.bot.request, HTTPXRequest)
@ -367,12 +369,21 @@ class TestApplicationBuilder:
assert isinstance(app, CustomApplication)
assert app.arg == 2
def test_all_application_args_custom(self, builder, bot, monkeypatch):
@pytest.mark.parametrize(
("concurrent_updates", "expected"),
[
(4, SimpleUpdateProcessor(4)),
(False, SimpleUpdateProcessor(1)),
(True, SimpleUpdateProcessor(256)),
],
)
def test_all_application_args_custom(
self, builder, bot, monkeypatch, concurrent_updates, expected
):
job_queue = JobQueue()
persistence = PicklePersistence("file_path")
update_queue = asyncio.Queue()
context_types = ContextTypes()
concurrent_updates = 123
async def post_init(app: Application) -> None:
pass
@ -395,6 +406,7 @@ class TestApplicationBuilder:
.post_stop(post_stop)
.arbitrary_callback_data(True)
).build()
assert app.job_queue is job_queue
assert app.job_queue.application is app
assert app.persistence is persistence
@ -403,7 +415,9 @@ class TestApplicationBuilder:
assert app.updater.update_queue is update_queue
assert app.updater.bot is app.bot
assert app.context_types is context_types
assert app.concurrent_updates == concurrent_updates
assert isinstance(app.update_processor, SimpleUpdateProcessor)
assert app.update_processor.max_concurrent_updates == expected.max_concurrent_updates
assert app.concurrent_updates == app.update_processor.max_concurrent_updates
assert app.post_init is post_init
assert app.post_shutdown is post_shutdown
assert app.post_stop is post_stop
@ -414,6 +428,19 @@ class TestApplicationBuilder:
assert app.updater is updater
assert app.bot is updater.bot
assert app.update_queue is updater.update_queue
app = (
builder.token(bot.token)
.job_queue(job_queue)
.persistence(persistence)
.update_queue(update_queue)
.context_types(context_types)
.concurrent_updates(expected)
.post_init(post_init)
.post_shutdown(post_shutdown)
.post_stop(post_stop)
.arbitrary_callback_data(True)
).build()
assert app.update_processor is expected
@pytest.mark.parametrize("input_type", ["bytes", "str", "Path"])
def test_all_private_key_input_types(self, builder, bot, input_type):

View file

@ -0,0 +1,166 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2023
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""Here we run tests directly with SimpleUpdateProcessor because that's easier than providing dummy
implementations for SimpleUpdateProcessor and we want to test SimpleUpdateProcessor anyway."""
import asyncio
import pytest
from telegram import Update
from telegram.ext import SimpleUpdateProcessor
from tests.auxil.asyncio_helpers import call_after
from tests.auxil.slots import mro_slots
@pytest.fixture()
def mock_processor():
class MockProcessor(SimpleUpdateProcessor):
test_flag = False
async def do_process_update(self, update, coroutine):
await coroutine
self.test_flag = True
return MockProcessor(5)
class TestSimpleUpdateProcessor:
def test_slot_behaviour(self):
inst = SimpleUpdateProcessor(1)
for attr in inst.__slots__:
assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'"
assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot"
@pytest.mark.parametrize("concurrent_updates", [0, -1])
def test_init(self, concurrent_updates):
processor = SimpleUpdateProcessor(3)
assert processor.max_concurrent_updates == 3
with pytest.raises(ValueError, match="must be a positive integer"):
SimpleUpdateProcessor(concurrent_updates)
async def test_process_update(self, mock_processor):
"""Test that process_update calls do_process_update."""
update = Update(1)
async def coroutine():
pass
await mock_processor.process_update(update, coroutine())
# This flag is set in the mock processor in do_process_update, telling us that
# do_process_update was called.
assert mock_processor.test_flag
async def test_do_process_update(self):
"""Test that do_process_update calls the coroutine."""
processor = SimpleUpdateProcessor(1)
update = Update(1)
test_flag = False
async def coroutine():
nonlocal test_flag
test_flag = True
await processor.do_process_update(update, coroutine())
assert test_flag
async def test_max_concurrent_updates_enforcement(self, mock_processor):
"""Test that max_concurrent_updates is enforced, i.e. that the processor will run
at most max_concurrent_updates coroutines at the same time."""
count = 2 * mock_processor.max_concurrent_updates
events = {i: asyncio.Event() for i in range(count)}
queue = asyncio.Queue()
for event in events.values():
await queue.put(event)
async def callback():
await asyncio.sleep(0.5)
(await queue.get()).set()
# We start several calls to `process_update` at the same time, each of them taking
# 0.5 seconds to complete. We know that they are completed when the corresponding
# event is set.
tasks = [
asyncio.create_task(mock_processor.process_update(update=_, coroutine=callback()))
for _ in range(count)
]
# Right now we expect no event to be set
for i in range(count):
assert not events[i].is_set()
# After 0.5 seconds (+ some buffer), we expect that exactly max_concurrent_updates
# events are set.
await asyncio.sleep(0.75)
for i in range(mock_processor.max_concurrent_updates):
assert events[i].is_set()
for i in range(
mock_processor.max_concurrent_updates,
count,
):
assert not events[i].is_set()
# After wating another 0.5 seconds, we expect that the next max_concurrent_updates
# events are set.
await asyncio.sleep(0.5)
for i in range(count):
assert events[i].is_set()
# Sanity check: we expect that all tasks are completed.
await asyncio.gather(*tasks)
async def test_context_manager(self, monkeypatch, mock_processor):
self.test_flag = set()
async def after_initialize(*args, **kwargs):
self.test_flag.add("initialize")
async def after_shutdown(*args, **kwargs):
self.test_flag.add("stop")
monkeypatch.setattr(
SimpleUpdateProcessor,
"initialize",
call_after(SimpleUpdateProcessor.initialize, after_initialize),
)
monkeypatch.setattr(
SimpleUpdateProcessor,
"shutdown",
call_after(SimpleUpdateProcessor.shutdown, after_shutdown),
)
async with mock_processor:
pass
assert self.test_flag == {"initialize", "stop"}
async def test_context_manager_exception_on_init(self, monkeypatch, mock_processor):
async def initialize(*args, **kwargs):
raise RuntimeError("initialize")
async def shutdown(*args, **kwargs):
self.test_flag = "shutdown"
monkeypatch.setattr(SimpleUpdateProcessor, "initialize", initialize)
monkeypatch.setattr(SimpleUpdateProcessor, "shutdown", shutdown)
with pytest.raises(RuntimeError, match="initialize"):
async with mock_processor:
pass
assert self.test_flag == "shutdown"