Enhance Application.create_task (#3543)

This commit is contained in:
sam-mosleh 2023-02-20 21:53:27 +03:00 committed by GitHub
parent ee6c8a5995
commit c6b6b0a370
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 56 additions and 16 deletions

View file

@ -100,6 +100,7 @@ The following wonderful people contributed directly or indirectly to this projec
- `Riko Naka <https://github.com/rikonaka>`_
- `Rizlas <https://github.com/rizlas>`_
- `Sahil Sharma <https://github.com/sahilsharma811>`_
- `Sam Mosleh <https://github.com/sam-mosleh>`_
- `Sascha <https://github.com/saschalalala>`_
- `Shelomentsev D <https://github.com/shelomentsevd>`_
- `Shivam Saini <https://github.com/shivamsn97>`_

View file

@ -31,10 +31,12 @@ from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Awaitable,
Callable,
Coroutine,
DefaultDict,
Dict,
Generator,
Generic,
List,
Mapping,
@ -71,7 +73,6 @@ if TYPE_CHECKING:
DEFAULT_GROUP: int = 0
_AppType = TypeVar("_AppType", bound="Application") # pylint: disable=invalid-name
_RT = TypeVar("_RT")
_STOP_SIGNAL = object()
_logger = logging.getLogger(__name__)
@ -934,7 +935,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
loop.close()
def create_task(
self, coroutine: Coroutine[Any, Any, RT], update: object = None
self,
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: object = None,
) -> "asyncio.Task[RT]":
"""Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by
the :paramref:`coroutine` with :meth:`process_error`.
@ -948,7 +951,10 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
.. seealso:: :wiki:`Concurrency`
Args:
coroutine (:term:`coroutine function`): The coroutine to run as task.
coroutine (:term:`awaitable`): The awaitable to run as task.
.. versionchanged:: 20.2
Accepts :class:`asyncio.Future` and generator-based coroutine functions.
update (:obj:`object`, optional): If set, will be passed to :meth:`process_error`
as additional information for the error handlers. Moreover, the corresponding
:attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of
@ -960,13 +966,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
return self.__create_task(coroutine=coroutine, update=update)
def __create_task(
self, coroutine: Coroutine, update: object = None, is_error_handler: bool = False
) -> asyncio.Task:
self,
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: object = None,
is_error_handler: bool = False,
) -> "asyncio.Task[RT]":
# Unfortunately, we can't know if `coroutine` runs one of the error handler functions
# but by passing `is_error_handler=True` from `process_error`, we can make sure that we
# get at most one recursion of the user calls `create_task` manually with an error handler
# function
task = asyncio.create_task(
task: "asyncio.Task[RT]" = asyncio.create_task(
self.__create_task_callback(
coroutine=coroutine, update=update, is_error_handler=is_error_handler
)
@ -995,11 +1004,13 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
async def __create_task_callback(
self,
coroutine: Coroutine[Any, Any, _RT],
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: object = None,
is_error_handler: bool = False,
) -> _RT:
) -> RT:
try:
if isinstance(coroutine, Generator):
return await asyncio.create_task(coroutine)
return await coroutine
except asyncio.CancelledError as cancel:
# TODO: in py3.8+, CancelledError is a subclass of BaseException, so we can drop this
@ -1562,7 +1573,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
update: Optional[object],
error: Exception,
job: "Job[CCT]" = None,
coroutine: Coroutine[Any, Any, Any] = None,
coroutine: Union[
Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]
] = None,
) -> bool:
"""Processes an error by passing it to all error handlers registered with
:meth:`add_error_handler`. If one of the error handlers raises

View file

@ -20,14 +20,16 @@
from typing import (
TYPE_CHECKING,
Any,
Coroutine,
Awaitable,
Dict,
Generator,
Generic,
List,
Match,
NoReturn,
Optional,
Type,
Union,
)
from telegram._callbackquery import CallbackQuery
@ -37,7 +39,7 @@ from telegram.ext._extbot import ExtBot
from telegram.ext._utils.types import BD, BT, CD, UD
if TYPE_CHECKING:
from asyncio import Queue
from asyncio import Future, Queue
from telegram.ext import Application, Job, JobQueue # noqa: F401
from telegram.ext._utils.types import CCT
@ -96,8 +98,8 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
.. versionadded:: 20.0
Attributes:
coroutine (:term:`coroutine function`): Optional. Only present in error handlers if the
error was caused by a coroutine run with :meth:`Application.create_task` or a handler
coroutine (:term:`awaitable`): Optional. Only present in error handlers if the
error was caused by an awaitable run with :meth:`Application.create_task` or a handler
callback with :attr:`block=False <BaseHandler.block>`.
matches (List[:meth:`re.Match <re.Match.expand>`]): Optional. If the associated update
originated from a :class:`filters.Regex`, this will contain a list of match objects for
@ -143,7 +145,9 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
self.matches: Optional[List[Match[str]]] = None
self.error: Optional[Exception] = None
self.job: Optional["Job[CCT]"] = None
self.coroutine: Optional[Coroutine[Any, Any, Any]] = None
self.coroutine: Optional[
Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]]
] = None
@property
def application(self) -> "Application[BT, CCT, UD, CD, BD, Any]":
@ -275,7 +279,7 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
error: Exception,
application: "Application[BT, CCT, UD, CD, BD, Any]",
job: "Job[Any]" = None,
coroutine: Coroutine[Any, Any, Any] = None,
coroutine: Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]] = None,
) -> "CCT":
"""
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error
@ -295,13 +299,15 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
job (:class:`telegram.ext.Job`, optional): The job associated with the error.
.. versionadded:: 20.0
coroutine (:term:`coroutine function`, optional): The coroutine function associated
coroutine (:term:`awaitable`, optional): The awaitable associated
with this error if the error was caused by a coroutine run with
:meth:`Application.create_task` or a handler callback with
:attr:`block=False <BaseHandler.block>`.
.. versionadded:: 20.0
.. versionchanged:: 20.2
Accepts :class:`asyncio.Future` and generator-based coroutine functions.
Returns:
:class:`telegram.ext.CallbackContext`
"""

View file

@ -1287,6 +1287,26 @@ class TestApplication:
await asyncio.sleep(0.05)
assert stop_task.done()
async def test_create_task_awaiting_future(self, app):
async def callback():
await asyncio.sleep(0.01)
return 42
# `asyncio.gather` returns an `asyncio.Future` and not an
# `asyncio.Task`
out = await app.create_task(asyncio.gather(callback()))
assert out == [42]
async def test_create_task_awaiting_generator(self, app):
event = asyncio.Event()
def gen():
yield
event.set()
await app.create_task(gen())
assert event.is_set()
async def test_no_concurrent_updates(self, app):
queue = asyncio.Queue()
event_1 = asyncio.Event()