Add Defaults.run_async (#2210)

* Add Defaults.run_async support

Signed-off-by: starry69 <starry369126@outlook.com>

* Address some requested changes.

Signed-off-by: starry69 <starry369126@outlook.com>

* Add tests for defaults.run_async

Signed-off-by: starry69 <starry369126@outlook.com>

* Fix tests logic & add default value support for dp.add_error_handler

Signed-off-by: starry69 <starry369126@outlook.com>

* Fix tests, with requested changes

Signed-off-by: starry69 <starry369126@outlook.com>

* Add tests for error_handler

Signed-off-by: starry69 <starry369126@outlook.com>

* try to fix pre-commit

Signed-off-by: starry69 <starry369126@outlook.com>

* Enhance tests & address suggested changes

Signed-off-by: starry69 <starry369126@outlook.com>

* Improve docs

Signed-off-by: starry69 <starry369126@outlook.com>
This commit is contained in:
Stɑrry Shivɑm 2020-11-18 02:01:01 +05:30 committed by GitHub
parent 8d9bb26cca
commit 425716f966
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 101 additions and 16 deletions

View file

@ -34,6 +34,7 @@ from typing import (
from telegram import Update
from telegram.utils.types import HandlerArg
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
@ -130,7 +131,7 @@ class CallbackQueryHandler(Handler):
pass_groupdict: bool = False,
pass_user_data: bool = False,
pass_chat_data: bool = False,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
super().__init__(
callback,

View file

@ -25,6 +25,7 @@ from telegram import MessageEntity, Update
from telegram.ext import BaseFilter, Filters
from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.types import HandlerArg, SLT
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
@ -141,7 +142,7 @@ class CommandHandler(Handler):
pass_job_queue: bool = False,
pass_user_data: bool = False,
pass_chat_data: bool = False,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
super().__init__(
callback,
@ -350,7 +351,7 @@ class PrefixHandler(CommandHandler):
pass_job_queue: bool = False,
pass_user_data: bool = False,
pass_chat_data: bool = False,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
self._prefix: List[str] = list()

View file

@ -43,6 +43,9 @@ class Defaults:
be ignored. Default: :obj:`True` in group chats and :obj:`False` in private chats.
tzinfo (:obj:`tzinfo`): A timezone to be used for all date(time) objects appearing
throughout PTB.
run_async (:obj:`bool`): Optional. Default setting for the ``run_async`` parameter of
handlers and error handlers registered through :meth:`Dispatcher.add_handler` and
:meth:`Dispatcher.add_error_handler`.
Parameters:
parse_mode (:obj:`str`, optional): Send Markdown or HTML, if you want Telegram apps to show
@ -61,6 +64,9 @@ class Defaults:
appearing throughout PTB, i.e. if a timezone naive date(time) object is passed
somewhere, it will be assumed to be in ``tzinfo``. Must be a timezone provided by the
``pytz`` module. Defaults to UTC.
run_async (:obj:`bool`, optional): Default setting for the ``run_async`` parameter of
handlers and error handlers registered through :meth:`Dispatcher.add_handler` and
:meth:`Dispatcher.add_error_handler`. Defaults to :obj:`False`.
"""
def __init__(
@ -73,6 +79,7 @@ class Defaults:
timeout: Union[float, DefaultValue] = DEFAULT_NONE,
quote: bool = None,
tzinfo: pytz.BaseTzInfo = pytz.utc,
run_async: bool = False,
):
self._parse_mode = parse_mode
self._disable_notification = disable_notification
@ -80,6 +87,7 @@ class Defaults:
self._timeout = timeout
self._quote = quote
self._tzinfo = tzinfo
self._run_async = run_async
@property
def parse_mode(self) -> Optional[str]:
@ -147,6 +155,17 @@ class Defaults:
"not have any effect."
)
@property
def run_async(self) -> Optional[bool]:
return self._run_async
@run_async.setter
def run_async(self, value: Any) -> NoReturn:
raise AttributeError(
"You can not assign a new value to defaults after because it would "
"not have any effect."
)
def __hash__(self) -> int:
return hash(
(
@ -156,6 +175,7 @@ class Defaults:
self._timeout,
self._quote,
self._tzinfo,
self._run_async,
)
)

View file

@ -36,6 +36,7 @@ from telegram.ext.handler import Handler
from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.promise import Promise
from telegram.utils.types import HandlerArg
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
if TYPE_CHECKING:
from telegram import Bot
@ -191,7 +192,7 @@ class Dispatcher:
"""Dict[:obj:`int`, List[:class:`telegram.ext.Handler`]]: Holds the handlers per group."""
self.groups: List[int] = []
"""List[:obj:`int`]: A list with all groups."""
self.error_handlers: Dict[Callable, bool] = {}
self.error_handlers: Dict[Callable, Union[bool, DefaultValue]] = {}
"""Dict[:obj:`callable`, :obj:`bool`]: A dict, where the keys are error handlers and the
values indicate whether they are to be run asynchronously."""
@ -588,7 +589,7 @@ class Dispatcher:
def add_error_handler(
self,
callback: Callable[[Any, CallbackContext], None],
run_async: bool = False, # pylint: disable=W0621
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, # pylint: disable=W0621
) -> None:
"""Registers an error handler in the Dispatcher. This handler will receive every error
which happens in your bot.
@ -616,6 +617,11 @@ class Dispatcher:
if callback in self.error_handlers:
self.logger.debug('The callback is already registered as an error handler. Ignoring.')
return
if run_async is DEFAULT_FALSE and self.bot.defaults:
if self.bot.defaults.run_async:
run_async = True
self.error_handlers[callback] = run_async
def remove_error_handler(self, callback: Callable[[Any, CallbackContext], None]) -> None:

View file

@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar, Union
from telegram import Update
from telegram.utils.promise import Promise
from telegram.utils.types import HandlerArg
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher
@ -96,7 +97,7 @@ class Handler(ABC):
pass_job_queue: bool = False,
pass_user_data: bool = False,
pass_chat_data: bool = False,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
self.callback: Callable[[HandlerArg, 'CallbackContext'], RT] = callback
self.pass_update_queue = pass_update_queue
@ -143,14 +144,19 @@ class Handler(ABC):
the dispatcher.
"""
run_async = self.run_async
if self.run_async is DEFAULT_FALSE and dispatcher.bot.defaults:
if dispatcher.bot.defaults.run_async:
run_async = True
if context:
self.collect_additional_context(context, update, dispatcher, check_result)
if self.run_async:
if run_async:
return dispatcher.run_async(self.callback, update, context, update=update)
return self.callback(update, context)
optional_args = self.collect_optional_args(dispatcher, update, check_result)
if self.run_async:
if run_async:
return dispatcher.run_async(
self.callback, dispatcher.bot, update, update=update, **optional_args
)

View file

@ -33,6 +33,7 @@ from typing import (
from telegram import Update
from telegram.utils.types import HandlerArg
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
@ -129,7 +130,7 @@ class InlineQueryHandler(Handler):
pass_groupdict: bool = False,
pass_user_data: bool = False,
pass_chat_data: bool = False,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
super().__init__(
callback,

View file

@ -25,6 +25,7 @@ from telegram import Update
from telegram.ext import BaseFilter, Filters
from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.types import HandlerArg
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
@ -131,7 +132,7 @@ class MessageHandler(Handler):
message_updates: bool = None,
channel_post_updates: bool = None,
edited_updates: bool = None,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
super().__init__(

View file

@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Pattern, TypeVa
from telegram.ext import Filters, MessageHandler
from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.types import HandlerArg
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher
@ -121,7 +122,7 @@ class RegexHandler(MessageHandler):
message_updates: bool = True,
channel_post_updates: bool = False,
edited_updates: bool = False,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
warnings.warn(
'RegexHandler is deprecated. See https://git.io/fxJuV for more info',

View file

@ -18,9 +18,10 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the StringCommandHandler class."""
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, Union
from telegram.utils.types import HandlerArg
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
@ -89,7 +90,7 @@ class StringCommandHandler(Handler):
pass_args: bool = False,
pass_update_queue: bool = False,
pass_job_queue: bool = False,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
super().__init__(
callback,

View file

@ -22,6 +22,7 @@ import re
from typing import TYPE_CHECKING, Any, Callable, Dict, Match, Optional, Pattern, TypeVar, Union
from telegram.utils.types import HandlerArg
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
@ -99,7 +100,7 @@ class StringRegexHandler(Handler):
pass_groupdict: bool = False,
pass_update_queue: bool = False,
pass_job_queue: bool = False,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
super().__init__(
callback,

View file

@ -18,7 +18,8 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the TypeHandler class."""
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, Union
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
@ -80,7 +81,7 @@ class TypeHandler(Handler):
strict: bool = False,
pass_update_queue: bool = False,
pass_job_queue: bool = False,
run_async: bool = False,
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE,
):
super().__init__(
callback,

View file

@ -445,3 +445,6 @@ class DefaultValue:
DEFAULT_NONE: DefaultValue = DefaultValue(None)
""":class:`DefaultValue`: Default `None`"""
DEFAULT_FALSE: DefaultValue = DefaultValue(False)
""":class:`DefaultValue`: Default `False`"""

View file

@ -39,6 +39,8 @@ class TestDefault:
defaults.quote = True
with pytest.raises(AttributeError):
defaults.tzinfo = True
with pytest.raises(AttributeError):
defaults.run_async = True
def test_equality(self):
a = Defaults(parse_mode='HTML', quote=True)

View file

@ -27,6 +27,7 @@ from telegram import TelegramError, Message, User, Chat, Update, Bot, MessageEnt
from telegram.ext import (
MessageHandler,
Filters,
Defaults,
CommandHandler,
CallbackContext,
JobQueue,
@ -174,6 +175,45 @@ class TestDispatcher:
assert self.count == 1
@pytest.mark.parametrize(['run_async', 'expected_output'], [(True, 5), (False, 0)])
def test_default_run_async_error_handler(self, dp, monkeypatch, run_async, expected_output):
def mock_async_err_handler(*args, **kwargs):
self.count = 5
# set defaults value to dp.bot
dp.bot.defaults = Defaults(run_async=run_async)
try:
dp.add_handler(MessageHandler(Filters.all, self.callback_raise_error))
dp.add_error_handler(self.error_handler)
monkeypatch.setattr(dp, 'run_async', mock_async_err_handler)
dp.process_update(self.message_update)
assert self.count == expected_output
finally:
# reset dp.bot.defaults values
dp.bot.defaults = None
@pytest.mark.parametrize(
['run_async', 'expected_output'], [(True, 'running async'), (False, None)]
)
def test_default_run_async(self, monkeypatch, dp, run_async, expected_output):
def mock_run_async(*args, **kwargs):
self.received = 'running async'
# set defaults value to dp.bot
dp.bot.defaults = Defaults(run_async=run_async)
try:
dp.add_handler(MessageHandler(Filters.all, lambda u, c: None))
monkeypatch.setattr(dp, 'run_async', mock_run_async)
dp.process_update(self.message_update)
assert self.received == expected_output
finally:
# reset defaults value
dp.bot.defaults = None
def test_run_async_multiple(self, bot, dp, dp2):
def get_dispatcher_name(q):
q.put(current_thread().name)