Enhance Application.create_task (#3543)

This commit is contained in:
sam-mosleh 2023-02-20 21:53:27 +03:00 committed by GitHub
parent ee6c8a5995
commit c6b6b0a370
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 56 additions and 16 deletions

View file

@ -100,6 +100,7 @@ The following wonderful people contributed directly or indirectly to this projec
- `Riko Naka <https://github.com/rikonaka>`_ - `Riko Naka <https://github.com/rikonaka>`_
- `Rizlas <https://github.com/rizlas>`_ - `Rizlas <https://github.com/rizlas>`_
- `Sahil Sharma <https://github.com/sahilsharma811>`_ - `Sahil Sharma <https://github.com/sahilsharma811>`_
- `Sam Mosleh <https://github.com/sam-mosleh>`_
- `Sascha <https://github.com/saschalalala>`_ - `Sascha <https://github.com/saschalalala>`_
- `Shelomentsev D <https://github.com/shelomentsevd>`_ - `Shelomentsev D <https://github.com/shelomentsevd>`_
- `Shivam Saini <https://github.com/shivamsn97>`_ - `Shivam Saini <https://github.com/shivamsn97>`_

View file

@ -31,10 +31,12 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncContextManager, AsyncContextManager,
Awaitable,
Callable, Callable,
Coroutine, Coroutine,
DefaultDict, DefaultDict,
Dict, Dict,
Generator,
Generic, Generic,
List, List,
Mapping, Mapping,
@ -71,7 +73,6 @@ if TYPE_CHECKING:
DEFAULT_GROUP: int = 0 DEFAULT_GROUP: int = 0
_AppType = TypeVar("_AppType", bound="Application") # pylint: disable=invalid-name _AppType = TypeVar("_AppType", bound="Application") # pylint: disable=invalid-name
_RT = TypeVar("_RT")
_STOP_SIGNAL = object() _STOP_SIGNAL = object()
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -934,7 +935,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
loop.close() loop.close()
def create_task( def create_task(
self, coroutine: Coroutine[Any, Any, RT], update: object = None self,
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: object = None,
) -> "asyncio.Task[RT]": ) -> "asyncio.Task[RT]":
"""Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by """Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by
the :paramref:`coroutine` with :meth:`process_error`. the :paramref:`coroutine` with :meth:`process_error`.
@ -948,7 +951,10 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
.. seealso:: :wiki:`Concurrency` .. seealso:: :wiki:`Concurrency`
Args: Args:
coroutine (:term:`coroutine function`): The coroutine to run as task. coroutine (:term:`awaitable`): The awaitable to run as task.
.. versionchanged:: 20.2
Accepts :class:`asyncio.Future` and generator-based coroutine functions.
update (:obj:`object`, optional): If set, will be passed to :meth:`process_error` update (:obj:`object`, optional): If set, will be passed to :meth:`process_error`
as additional information for the error handlers. Moreover, the corresponding as additional information for the error handlers. Moreover, the corresponding
:attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of :attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of
@ -960,13 +966,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
return self.__create_task(coroutine=coroutine, update=update) return self.__create_task(coroutine=coroutine, update=update)
def __create_task( def __create_task(
self, coroutine: Coroutine, update: object = None, is_error_handler: bool = False self,
) -> asyncio.Task: coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: object = None,
is_error_handler: bool = False,
) -> "asyncio.Task[RT]":
# Unfortunately, we can't know if `coroutine` runs one of the error handler functions # Unfortunately, we can't know if `coroutine` runs one of the error handler functions
# but by passing `is_error_handler=True` from `process_error`, we can make sure that we # but by passing `is_error_handler=True` from `process_error`, we can make sure that we
# get at most one recursion of the user calls `create_task` manually with an error handler # get at most one recursion of the user calls `create_task` manually with an error handler
# function # function
task = asyncio.create_task( task: "asyncio.Task[RT]" = asyncio.create_task(
self.__create_task_callback( self.__create_task_callback(
coroutine=coroutine, update=update, is_error_handler=is_error_handler coroutine=coroutine, update=update, is_error_handler=is_error_handler
) )
@ -995,11 +1004,13 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
async def __create_task_callback( async def __create_task_callback(
self, self,
coroutine: Coroutine[Any, Any, _RT], coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: object = None, update: object = None,
is_error_handler: bool = False, is_error_handler: bool = False,
) -> _RT: ) -> RT:
try: try:
if isinstance(coroutine, Generator):
return await asyncio.create_task(coroutine)
return await coroutine return await coroutine
except asyncio.CancelledError as cancel: except asyncio.CancelledError as cancel:
# TODO: in py3.8+, CancelledError is a subclass of BaseException, so we can drop this # TODO: in py3.8+, CancelledError is a subclass of BaseException, so we can drop this
@ -1562,7 +1573,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
update: Optional[object], update: Optional[object],
error: Exception, error: Exception,
job: "Job[CCT]" = None, job: "Job[CCT]" = None,
coroutine: Coroutine[Any, Any, Any] = None, coroutine: Union[
Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]
] = None,
) -> bool: ) -> bool:
"""Processes an error by passing it to all error handlers registered with """Processes an error by passing it to all error handlers registered with
:meth:`add_error_handler`. If one of the error handlers raises :meth:`add_error_handler`. If one of the error handlers raises

View file

@ -20,14 +20,16 @@
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Coroutine, Awaitable,
Dict, Dict,
Generator,
Generic, Generic,
List, List,
Match, Match,
NoReturn, NoReturn,
Optional, Optional,
Type, Type,
Union,
) )
from telegram._callbackquery import CallbackQuery from telegram._callbackquery import CallbackQuery
@ -37,7 +39,7 @@ from telegram.ext._extbot import ExtBot
from telegram.ext._utils.types import BD, BT, CD, UD from telegram.ext._utils.types import BD, BT, CD, UD
if TYPE_CHECKING: if TYPE_CHECKING:
from asyncio import Queue from asyncio import Future, Queue
from telegram.ext import Application, Job, JobQueue # noqa: F401 from telegram.ext import Application, Job, JobQueue # noqa: F401
from telegram.ext._utils.types import CCT from telegram.ext._utils.types import CCT
@ -96,8 +98,8 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
.. versionadded:: 20.0 .. versionadded:: 20.0
Attributes: Attributes:
coroutine (:term:`coroutine function`): Optional. Only present in error handlers if the coroutine (:term:`awaitable`): Optional. Only present in error handlers if the
error was caused by a coroutine run with :meth:`Application.create_task` or a handler error was caused by an awaitable run with :meth:`Application.create_task` or a handler
callback with :attr:`block=False <BaseHandler.block>`. callback with :attr:`block=False <BaseHandler.block>`.
matches (List[:meth:`re.Match <re.Match.expand>`]): Optional. If the associated update matches (List[:meth:`re.Match <re.Match.expand>`]): Optional. If the associated update
originated from a :class:`filters.Regex`, this will contain a list of match objects for originated from a :class:`filters.Regex`, this will contain a list of match objects for
@ -143,7 +145,9 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
self.matches: Optional[List[Match[str]]] = None self.matches: Optional[List[Match[str]]] = None
self.error: Optional[Exception] = None self.error: Optional[Exception] = None
self.job: Optional["Job[CCT]"] = None self.job: Optional["Job[CCT]"] = None
self.coroutine: Optional[Coroutine[Any, Any, Any]] = None self.coroutine: Optional[
Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]]
] = None
@property @property
def application(self) -> "Application[BT, CCT, UD, CD, BD, Any]": def application(self) -> "Application[BT, CCT, UD, CD, BD, Any]":
@ -275,7 +279,7 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
error: Exception, error: Exception,
application: "Application[BT, CCT, UD, CD, BD, Any]", application: "Application[BT, CCT, UD, CD, BD, Any]",
job: "Job[Any]" = None, job: "Job[Any]" = None,
coroutine: Coroutine[Any, Any, Any] = None, coroutine: Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]] = None,
) -> "CCT": ) -> "CCT":
""" """
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error
@ -295,13 +299,15 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
job (:class:`telegram.ext.Job`, optional): The job associated with the error. job (:class:`telegram.ext.Job`, optional): The job associated with the error.
.. versionadded:: 20.0 .. versionadded:: 20.0
coroutine (:term:`coroutine function`, optional): The coroutine function associated coroutine (:term:`awaitable`, optional): The awaitable associated
with this error if the error was caused by a coroutine run with with this error if the error was caused by a coroutine run with
:meth:`Application.create_task` or a handler callback with :meth:`Application.create_task` or a handler callback with
:attr:`block=False <BaseHandler.block>`. :attr:`block=False <BaseHandler.block>`.
.. versionadded:: 20.0 .. versionadded:: 20.0
.. versionchanged:: 20.2
Accepts :class:`asyncio.Future` and generator-based coroutine functions.
Returns: Returns:
:class:`telegram.ext.CallbackContext` :class:`telegram.ext.CallbackContext`
""" """

View file

@ -1287,6 +1287,26 @@ class TestApplication:
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
assert stop_task.done() assert stop_task.done()
async def test_create_task_awaiting_future(self, app):
async def callback():
await asyncio.sleep(0.01)
return 42
# `asyncio.gather` returns an `asyncio.Future` and not an
# `asyncio.Task`
out = await app.create_task(asyncio.gather(callback()))
assert out == [42]
async def test_create_task_awaiting_generator(self, app):
event = asyncio.Event()
def gen():
yield
event.set()
await app.create_task(gen())
assert event.is_set()
async def test_no_concurrent_updates(self, app): async def test_no_concurrent_updates(self, app):
queue = asyncio.Queue() queue = asyncio.Queue()
event_1 = asyncio.Event() event_1 = asyncio.Event()