diff --git a/AUTHORS.rst b/AUTHORS.rst index 1f806057b..c3aa909cc 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -100,6 +100,7 @@ The following wonderful people contributed directly or indirectly to this projec - `Riko Naka `_ - `Rizlas `_ - `Sahil Sharma `_ +- `Sam Mosleh `_ - `Sascha `_ - `Shelomentsev D `_ - `Shivam Saini `_ diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 7d987e545..ddb9c0d4e 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -31,10 +31,12 @@ from typing import ( TYPE_CHECKING, Any, AsyncContextManager, + Awaitable, Callable, Coroutine, DefaultDict, Dict, + Generator, Generic, List, Mapping, @@ -71,7 +73,6 @@ if TYPE_CHECKING: DEFAULT_GROUP: int = 0 _AppType = TypeVar("_AppType", bound="Application") # pylint: disable=invalid-name -_RT = TypeVar("_RT") _STOP_SIGNAL = object() _logger = logging.getLogger(__name__) @@ -934,7 +935,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica loop.close() def create_task( - self, coroutine: Coroutine[Any, Any, RT], update: object = None + self, + coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]], + update: object = None, ) -> "asyncio.Task[RT]": """Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by the :paramref:`coroutine` with :meth:`process_error`. @@ -948,7 +951,10 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica .. seealso:: :wiki:`Concurrency` Args: - coroutine (:term:`coroutine function`): The coroutine to run as task. + coroutine (:term:`awaitable`): The awaitable to run as task. + + .. versionchanged:: 20.2 + Accepts :class:`asyncio.Future` and generator-based coroutine functions. update (:obj:`object`, optional): If set, will be passed to :meth:`process_error` as additional information for the error handlers. Moreover, the corresponding :attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of @@ -960,13 +966,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica return self.__create_task(coroutine=coroutine, update=update) def __create_task( - self, coroutine: Coroutine, update: object = None, is_error_handler: bool = False - ) -> asyncio.Task: + self, + coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]], + update: object = None, + is_error_handler: bool = False, + ) -> "asyncio.Task[RT]": # Unfortunately, we can't know if `coroutine` runs one of the error handler functions # but by passing `is_error_handler=True` from `process_error`, we can make sure that we # get at most one recursion of the user calls `create_task` manually with an error handler # function - task = asyncio.create_task( + task: "asyncio.Task[RT]" = asyncio.create_task( self.__create_task_callback( coroutine=coroutine, update=update, is_error_handler=is_error_handler ) @@ -995,11 +1004,13 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica async def __create_task_callback( self, - coroutine: Coroutine[Any, Any, _RT], + coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]], update: object = None, is_error_handler: bool = False, - ) -> _RT: + ) -> RT: try: + if isinstance(coroutine, Generator): + return await asyncio.create_task(coroutine) return await coroutine except asyncio.CancelledError as cancel: # TODO: in py3.8+, CancelledError is a subclass of BaseException, so we can drop this @@ -1562,7 +1573,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica update: Optional[object], error: Exception, job: "Job[CCT]" = None, - coroutine: Coroutine[Any, Any, Any] = None, + coroutine: Union[ + Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT] + ] = None, ) -> bool: """Processes an error by passing it to all error handlers registered with :meth:`add_error_handler`. If one of the error handlers raises diff --git a/telegram/ext/_callbackcontext.py b/telegram/ext/_callbackcontext.py index b4834000f..b4b4f1b56 100644 --- a/telegram/ext/_callbackcontext.py +++ b/telegram/ext/_callbackcontext.py @@ -20,14 +20,16 @@ from typing import ( TYPE_CHECKING, Any, - Coroutine, + Awaitable, Dict, + Generator, Generic, List, Match, NoReturn, Optional, Type, + Union, ) from telegram._callbackquery import CallbackQuery @@ -37,7 +39,7 @@ from telegram.ext._extbot import ExtBot from telegram.ext._utils.types import BD, BT, CD, UD if TYPE_CHECKING: - from asyncio import Queue + from asyncio import Future, Queue from telegram.ext import Application, Job, JobQueue # noqa: F401 from telegram.ext._utils.types import CCT @@ -96,8 +98,8 @@ class CallbackContext(Generic[BT, UD, CD, BD]): .. versionadded:: 20.0 Attributes: - coroutine (:term:`coroutine function`): Optional. Only present in error handlers if the - error was caused by a coroutine run with :meth:`Application.create_task` or a handler + coroutine (:term:`awaitable`): Optional. Only present in error handlers if the + error was caused by an awaitable run with :meth:`Application.create_task` or a handler callback with :attr:`block=False `. matches (List[:meth:`re.Match `]): Optional. If the associated update originated from a :class:`filters.Regex`, this will contain a list of match objects for @@ -143,7 +145,9 @@ class CallbackContext(Generic[BT, UD, CD, BD]): self.matches: Optional[List[Match[str]]] = None self.error: Optional[Exception] = None self.job: Optional["Job[CCT]"] = None - self.coroutine: Optional[Coroutine[Any, Any, Any]] = None + self.coroutine: Optional[ + Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]] + ] = None @property def application(self) -> "Application[BT, CCT, UD, CD, BD, Any]": @@ -275,7 +279,7 @@ class CallbackContext(Generic[BT, UD, CD, BD]): error: Exception, application: "Application[BT, CCT, UD, CD, BD, Any]", job: "Job[Any]" = None, - coroutine: Coroutine[Any, Any, Any] = None, + coroutine: Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]] = None, ) -> "CCT": """ Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error @@ -295,13 +299,15 @@ class CallbackContext(Generic[BT, UD, CD, BD]): job (:class:`telegram.ext.Job`, optional): The job associated with the error. .. versionadded:: 20.0 - coroutine (:term:`coroutine function`, optional): The coroutine function associated + coroutine (:term:`awaitable`, optional): The awaitable associated with this error if the error was caused by a coroutine run with :meth:`Application.create_task` or a handler callback with :attr:`block=False `. .. versionadded:: 20.0 + .. versionchanged:: 20.2 + Accepts :class:`asyncio.Future` and generator-based coroutine functions. Returns: :class:`telegram.ext.CallbackContext` """ diff --git a/tests/test_application.py b/tests/test_application.py index b80c6a334..16cde5fcc 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1287,6 +1287,26 @@ class TestApplication: await asyncio.sleep(0.05) assert stop_task.done() + async def test_create_task_awaiting_future(self, app): + async def callback(): + await asyncio.sleep(0.01) + return 42 + + # `asyncio.gather` returns an `asyncio.Future` and not an + # `asyncio.Task` + out = await app.create_task(asyncio.gather(callback())) + assert out == [42] + + async def test_create_task_awaiting_generator(self, app): + event = asyncio.Event() + + def gen(): + yield + event.set() + + await app.create_task(gen()) + assert event.is_set() + async def test_no_concurrent_updates(self, app): queue = asyncio.Queue() event_1 = asyncio.Event()