mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-22 22:45:09 +01:00
Enhance Application.create_task
(#3543)
This commit is contained in:
parent
ee6c8a5995
commit
c6b6b0a370
4 changed files with 56 additions and 16 deletions
|
@ -100,6 +100,7 @@ The following wonderful people contributed directly or indirectly to this projec
|
|||
- `Riko Naka <https://github.com/rikonaka>`_
|
||||
- `Rizlas <https://github.com/rizlas>`_
|
||||
- `Sahil Sharma <https://github.com/sahilsharma811>`_
|
||||
- `Sam Mosleh <https://github.com/sam-mosleh>`_
|
||||
- `Sascha <https://github.com/saschalalala>`_
|
||||
- `Shelomentsev D <https://github.com/shelomentsevd>`_
|
||||
- `Shivam Saini <https://github.com/shivamsn97>`_
|
||||
|
|
|
@ -31,10 +31,12 @@ from typing import (
|
|||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
List,
|
||||
Mapping,
|
||||
|
@ -71,7 +73,6 @@ if TYPE_CHECKING:
|
|||
DEFAULT_GROUP: int = 0
|
||||
|
||||
_AppType = TypeVar("_AppType", bound="Application") # pylint: disable=invalid-name
|
||||
_RT = TypeVar("_RT")
|
||||
_STOP_SIGNAL = object()
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
@ -934,7 +935,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
loop.close()
|
||||
|
||||
def create_task(
|
||||
self, coroutine: Coroutine[Any, Any, RT], update: object = None
|
||||
self,
|
||||
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
|
||||
update: object = None,
|
||||
) -> "asyncio.Task[RT]":
|
||||
"""Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by
|
||||
the :paramref:`coroutine` with :meth:`process_error`.
|
||||
|
@ -948,7 +951,10 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
.. seealso:: :wiki:`Concurrency`
|
||||
|
||||
Args:
|
||||
coroutine (:term:`coroutine function`): The coroutine to run as task.
|
||||
coroutine (:term:`awaitable`): The awaitable to run as task.
|
||||
|
||||
.. versionchanged:: 20.2
|
||||
Accepts :class:`asyncio.Future` and generator-based coroutine functions.
|
||||
update (:obj:`object`, optional): If set, will be passed to :meth:`process_error`
|
||||
as additional information for the error handlers. Moreover, the corresponding
|
||||
:attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of
|
||||
|
@ -960,13 +966,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
return self.__create_task(coroutine=coroutine, update=update)
|
||||
|
||||
def __create_task(
|
||||
self, coroutine: Coroutine, update: object = None, is_error_handler: bool = False
|
||||
) -> asyncio.Task:
|
||||
self,
|
||||
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
|
||||
update: object = None,
|
||||
is_error_handler: bool = False,
|
||||
) -> "asyncio.Task[RT]":
|
||||
# Unfortunately, we can't know if `coroutine` runs one of the error handler functions
|
||||
# but by passing `is_error_handler=True` from `process_error`, we can make sure that we
|
||||
# get at most one recursion of the user calls `create_task` manually with an error handler
|
||||
# function
|
||||
task = asyncio.create_task(
|
||||
task: "asyncio.Task[RT]" = asyncio.create_task(
|
||||
self.__create_task_callback(
|
||||
coroutine=coroutine, update=update, is_error_handler=is_error_handler
|
||||
)
|
||||
|
@ -995,11 +1004,13 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
|
||||
async def __create_task_callback(
|
||||
self,
|
||||
coroutine: Coroutine[Any, Any, _RT],
|
||||
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
|
||||
update: object = None,
|
||||
is_error_handler: bool = False,
|
||||
) -> _RT:
|
||||
) -> RT:
|
||||
try:
|
||||
if isinstance(coroutine, Generator):
|
||||
return await asyncio.create_task(coroutine)
|
||||
return await coroutine
|
||||
except asyncio.CancelledError as cancel:
|
||||
# TODO: in py3.8+, CancelledError is a subclass of BaseException, so we can drop this
|
||||
|
@ -1562,7 +1573,9 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
update: Optional[object],
|
||||
error: Exception,
|
||||
job: "Job[CCT]" = None,
|
||||
coroutine: Coroutine[Any, Any, Any] = None,
|
||||
coroutine: Union[
|
||||
Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]
|
||||
] = None,
|
||||
) -> bool:
|
||||
"""Processes an error by passing it to all error handlers registered with
|
||||
:meth:`add_error_handler`. If one of the error handlers raises
|
||||
|
|
|
@ -20,14 +20,16 @@
|
|||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Coroutine,
|
||||
Awaitable,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
List,
|
||||
Match,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from telegram._callbackquery import CallbackQuery
|
||||
|
@ -37,7 +39,7 @@ from telegram.ext._extbot import ExtBot
|
|||
from telegram.ext._utils.types import BD, BT, CD, UD
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from asyncio import Queue
|
||||
from asyncio import Future, Queue
|
||||
|
||||
from telegram.ext import Application, Job, JobQueue # noqa: F401
|
||||
from telegram.ext._utils.types import CCT
|
||||
|
@ -96,8 +98,8 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
|
|||
|
||||
.. versionadded:: 20.0
|
||||
Attributes:
|
||||
coroutine (:term:`coroutine function`): Optional. Only present in error handlers if the
|
||||
error was caused by a coroutine run with :meth:`Application.create_task` or a handler
|
||||
coroutine (:term:`awaitable`): Optional. Only present in error handlers if the
|
||||
error was caused by an awaitable run with :meth:`Application.create_task` or a handler
|
||||
callback with :attr:`block=False <BaseHandler.block>`.
|
||||
matches (List[:meth:`re.Match <re.Match.expand>`]): Optional. If the associated update
|
||||
originated from a :class:`filters.Regex`, this will contain a list of match objects for
|
||||
|
@ -143,7 +145,9 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
|
|||
self.matches: Optional[List[Match[str]]] = None
|
||||
self.error: Optional[Exception] = None
|
||||
self.job: Optional["Job[CCT]"] = None
|
||||
self.coroutine: Optional[Coroutine[Any, Any, Any]] = None
|
||||
self.coroutine: Optional[
|
||||
Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]]
|
||||
] = None
|
||||
|
||||
@property
|
||||
def application(self) -> "Application[BT, CCT, UD, CD, BD, Any]":
|
||||
|
@ -275,7 +279,7 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
|
|||
error: Exception,
|
||||
application: "Application[BT, CCT, UD, CD, BD, Any]",
|
||||
job: "Job[Any]" = None,
|
||||
coroutine: Coroutine[Any, Any, Any] = None,
|
||||
coroutine: Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]] = None,
|
||||
) -> "CCT":
|
||||
"""
|
||||
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error
|
||||
|
@ -295,13 +299,15 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
|
|||
job (:class:`telegram.ext.Job`, optional): The job associated with the error.
|
||||
|
||||
.. versionadded:: 20.0
|
||||
coroutine (:term:`coroutine function`, optional): The coroutine function associated
|
||||
coroutine (:term:`awaitable`, optional): The awaitable associated
|
||||
with this error if the error was caused by a coroutine run with
|
||||
:meth:`Application.create_task` or a handler callback with
|
||||
:attr:`block=False <BaseHandler.block>`.
|
||||
|
||||
.. versionadded:: 20.0
|
||||
|
||||
.. versionchanged:: 20.2
|
||||
Accepts :class:`asyncio.Future` and generator-based coroutine functions.
|
||||
Returns:
|
||||
:class:`telegram.ext.CallbackContext`
|
||||
"""
|
||||
|
|
|
@ -1287,6 +1287,26 @@ class TestApplication:
|
|||
await asyncio.sleep(0.05)
|
||||
assert stop_task.done()
|
||||
|
||||
async def test_create_task_awaiting_future(self, app):
|
||||
async def callback():
|
||||
await asyncio.sleep(0.01)
|
||||
return 42
|
||||
|
||||
# `asyncio.gather` returns an `asyncio.Future` and not an
|
||||
# `asyncio.Task`
|
||||
out = await app.create_task(asyncio.gather(callback()))
|
||||
assert out == [42]
|
||||
|
||||
async def test_create_task_awaiting_generator(self, app):
|
||||
event = asyncio.Event()
|
||||
|
||||
def gen():
|
||||
yield
|
||||
event.set()
|
||||
|
||||
await app.create_task(gen())
|
||||
assert event.is_set()
|
||||
|
||||
async def test_no_concurrent_updates(self, app):
|
||||
queue = asyncio.Queue()
|
||||
event_1 = asyncio.Event()
|
||||
|
|
Loading…
Reference in a new issue