diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ba3c919ef..f9dbe6885 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -60,7 +60,7 @@ jobs: shell: bash --noprofile --norc {0} - name: Submit coverage - uses: codecov/codecov-action@v1.0.13 + uses: codecov/codecov-action@v1 with: env_vars: OS,PYTHON name: ${{ matrix.os }}-${{ matrix.python-version }} @@ -79,7 +79,7 @@ jobs: run: git submodule update --init --recursive - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -108,7 +108,7 @@ jobs: run: git submodule update --init --recursive - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/docs/source/telegram.ext.contexttypes.rst b/docs/source/telegram.ext.contexttypes.rst new file mode 100644 index 000000000..d0cc0a29a --- /dev/null +++ b/docs/source/telegram.ext.contexttypes.rst @@ -0,0 +1,8 @@ +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/contexttypes.py + +telegram.ext.ContextTypes +========================= + +.. autoclass:: telegram.ext.ContextTypes + :members: + :show-inheritance: diff --git a/docs/source/telegram.ext.rst b/docs/source/telegram.ext.rst index 9bd855a87..316950446 100644 --- a/docs/source/telegram.ext.rst +++ b/docs/source/telegram.ext.rst @@ -7,11 +7,12 @@ telegram.ext package telegram.ext.dispatcher telegram.ext.dispatcherhandlerstop telegram.ext.callbackcontext - telegram.ext.defaults telegram.ext.job telegram.ext.jobqueue telegram.ext.messagequeue telegram.ext.delayqueue + telegram.ext.contexttypes + telegram.ext.defaults Handlers -------- @@ -51,4 +52,5 @@ utils .. toctree:: - telegram.ext.utils.promise \ No newline at end of file + telegram.ext.utils.promise + telegram.ext.utils.types \ No newline at end of file diff --git a/docs/source/telegram.ext.utils.types.rst b/docs/source/telegram.ext.utils.types.rst new file mode 100644 index 000000000..5c501ecf8 --- /dev/null +++ b/docs/source/telegram.ext.utils.types.rst @@ -0,0 +1,8 @@ +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/utils/types.py + +telegram.ext.utils.types Module +================================ + +.. automodule:: telegram.ext.utils.types + :members: + :show-inheritance: diff --git a/examples/README.md b/examples/README.md index 5b05c53ef..7d8f19225 100644 --- a/examples/README.md +++ b/examples/README.md @@ -47,7 +47,10 @@ A basic example of a bot that can accept payments. Don't forget to enable and co A basic example on how to set up a custom error handler. ### [`chatmemberbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/chatmemberbot.py) -A basic example on how `(my_)chat_member` updates can be used. +A basic example on how `(my_)chat_member` updates can be used. + +### [`contexttypesbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/contexttypesbot.py) +This example showcases how `telegram.ext.ContextTypes` can be used to customize the `context` argument of handler and job callbacks. ## Pure API The [`rawapibot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/rawapibot.py) example uses only the pure, "bare-metal" API wrapper. diff --git a/examples/contexttypesbot.py b/examples/contexttypesbot.py new file mode 100644 index 000000000..cfe485a61 --- /dev/null +++ b/examples/contexttypesbot.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python +# pylint: disable=C0116,W0613 +# This program is dedicated to the public domain under the CC0 license. + +""" +Simple Bot to showcase `telegram.ext.ContextTypes`. + +Usage: +Press Ctrl-C on the command line or send a signal to the process to stop the +bot. +""" + +from collections import defaultdict +from typing import DefaultDict, Optional, Set + +from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, ParseMode +from telegram.ext import ( + Updater, + CommandHandler, + CallbackContext, + ContextTypes, + CallbackQueryHandler, + TypeHandler, + Dispatcher, +) + + +class ChatData: + """Custom class for chat_data. Here we store data per message.""" + + def __init__(self) -> None: + self.clicks_per_message: DefaultDict[int, int] = defaultdict(int) + + +# The [dict, ChatData, dict] is for type checkers like mypy +class CustomContext(CallbackContext[dict, ChatData, dict]): + """Custom class for context.""" + + def __init__(self, dispatcher: Dispatcher): + super().__init__(dispatcher=dispatcher) + self._message_id: Optional[int] = None + + @property + def bot_user_ids(self) -> Set[int]: + """Custom shortcut to access a value stored in the bot_data dict""" + return self.bot_data.setdefault('user_ids', set()) + + @property + def message_clicks(self) -> Optional[int]: + """Access the number of clicks for the message this context object was built for.""" + if self._message_id: + return self.chat_data.clicks_per_message[self._message_id] + return None + + @message_clicks.setter + def message_clicks(self, value: int) -> None: + """Allow to change the count""" + if not self._message_id: + raise RuntimeError('There is no message associated with this context obejct.') + self.chat_data.clicks_per_message[self._message_id] = value + + @classmethod + def from_update(cls, update: object, dispatcher: 'Dispatcher') -> 'CustomContext': + """Override from_update to set _message_id.""" + # Make sure to call super() + context = super().from_update(update, dispatcher) + + if context.chat_data and isinstance(update, Update) and update.effective_message: + context._message_id = update.effective_message.message_id # pylint: disable=W0212 + + # Remember to return the object + return context + + +def start(update: Update, context: CustomContext) -> None: + """Display a message with a button.""" + update.message.reply_html( + 'This button was clicked 0 times.', + reply_markup=InlineKeyboardMarkup.from_button( + InlineKeyboardButton(text='Click me!', callback_data='button') + ), + ) + + +def count_click(update: Update, context: CustomContext) -> None: + """Update the click count for the message.""" + context.message_clicks += 1 + update.callback_query.answer() + update.effective_message.edit_text( + f'This button was clicked {context.message_clicks} times.', + reply_markup=InlineKeyboardMarkup.from_button( + InlineKeyboardButton(text='Click me!', callback_data='button') + ), + parse_mode=ParseMode.HTML, + ) + + +def print_users(update: Update, context: CustomContext) -> None: + """Show which users have been using this bot.""" + update.message.reply_text( + 'The following user IDs have used this bot: ' + f'{", ".join(map(str, context.bot_user_ids))}' + ) + + +def track_users(update: Update, context: CustomContext) -> None: + """Store the user id of the incoming update, if any.""" + if update.effective_user: + context.bot_user_ids.add(update.effective_user.id) + + +def main() -> None: + """Run the bot.""" + context_types = ContextTypes(context=CustomContext, chat_data=ChatData) + updater = Updater("TOKEN", context_types=context_types) + + dispatcher = updater.dispatcher + # run track_users in its own group to not interfere with the user handlers + dispatcher.add_handler(TypeHandler(Update, track_users), group=-1) + dispatcher.add_handler(CommandHandler("start", start)) + dispatcher.add_handler(CallbackQueryHandler(count_click)) + dispatcher.add_handler(CommandHandler("print_users", print_users)) + + updater.start_polling() + updater.idle() + + +if __name__ == '__main__': + main() diff --git a/setup.cfg b/setup.cfg index dd6d88012..a38acfd36 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,7 @@ omit = [coverage:report] exclude_lines = if TYPE_CHECKING: + ... [mypy] warn_unused_ignores = True diff --git a/telegram/ext/__init__.py b/telegram/ext/__init__.py index f536c2be2..93f561514 100644 --- a/telegram/ext/__init__.py +++ b/telegram/ext/__init__.py @@ -16,6 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +# pylint: disable=C0413 """Extensions over the Telegram Bot API to facilitate bot making""" from .basepersistence import BasePersistence @@ -23,7 +24,20 @@ from .picklepersistence import PicklePersistence from .dictpersistence import DictPersistence from .handler import Handler from .callbackcontext import CallbackContext +from .contexttypes import ContextTypes from .dispatcher import Dispatcher, DispatcherHandlerStop, run_async + +# https://bugs.python.org/issue41451, fixed on 3.7+, doesn't actually remove slots +# try-except is just here in case the __init__ is called twice (like in the tests) +# this block is also the reason for the pylint-ignore at the top of the file +try: + del Dispatcher.__slots__ # type: ignore[has-type] +except AttributeError as exc: + if str(exc) == '__slots__': + pass + else: + raise exc + from .jobqueue import JobQueue, Job from .updater import Updater from .callbackqueryhandler import CallbackQueryHandler @@ -47,38 +61,39 @@ from .chatmemberhandler import ChatMemberHandler from .defaults import Defaults __all__ = ( - 'Dispatcher', - 'JobQueue', - 'Job', - 'Updater', + 'BaseFilter', + 'BasePersistence', + 'CallbackContext', 'CallbackQueryHandler', + 'ChatMemberHandler', 'ChosenInlineResultHandler', 'CommandHandler', + 'ContextTypes', + 'ConversationHandler', + 'Defaults', + 'DelayQueue', + 'DictPersistence', + 'Dispatcher', + 'DispatcherHandlerStop', + 'Filters', 'Handler', 'InlineQueryHandler', - 'MessageHandler', - 'BaseFilter', + 'Job', + 'JobQueue', 'MessageFilter', - 'UpdateFilter', - 'Filters', + 'MessageHandler', + 'MessageQueue', + 'PicklePersistence', + 'PollAnswerHandler', + 'PollHandler', + 'PreCheckoutQueryHandler', + 'PrefixHandler', 'RegexHandler', + 'ShippingQueryHandler', 'StringCommandHandler', 'StringRegexHandler', 'TypeHandler', - 'ConversationHandler', - 'PreCheckoutQueryHandler', - 'ShippingQueryHandler', - 'MessageQueue', - 'DelayQueue', - 'DispatcherHandlerStop', + 'UpdateFilter', + 'Updater', 'run_async', - 'CallbackContext', - 'BasePersistence', - 'PicklePersistence', - 'DictPersistence', - 'PrefixHandler', - 'PollAnswerHandler', - 'PollHandler', - 'ChatMemberHandler', - 'Defaults', ) diff --git a/telegram/ext/basepersistence.py b/telegram/ext/basepersistence.py index c0c248e5c..94453bec5 100644 --- a/telegram/ext/basepersistence.py +++ b/telegram/ext/basepersistence.py @@ -21,16 +21,17 @@ import warnings from sys import version_info as py_ver from abc import ABC, abstractmethod from copy import copy -from typing import DefaultDict, Dict, Optional, Tuple, cast, ClassVar +from typing import Dict, Optional, Tuple, cast, ClassVar, Generic, DefaultDict from telegram.utils.deprecate import set_new_attribute_deprecated from telegram import Bot from telegram.utils.types import ConversationDict +from telegram.ext.utils.types import UD, CD, BD -class BasePersistence(ABC): +class BasePersistence(Generic[UD, CD, BD], ABC): """Interface class for adding persistence to your bot. Subclass this object for different implementations of a persistent bot. @@ -38,16 +39,20 @@ class BasePersistence(ABC): * :meth:`get_bot_data` * :meth:`update_bot_data` + * :meth:`refresh_bot_data` * :meth:`get_chat_data` * :meth:`update_chat_data` + * :meth:`refresh_chat_data` * :meth:`get_user_data` * :meth:`update_user_data` + * :meth:`refresh_user_data` * :meth:`get_conversations` * :meth:`update_conversation` * :meth:`flush` If you don't actually need one of those methods, a simple ``pass`` is enough. For example, if - ``store_bot_data=False``, you don't need :meth:`get_bot_data` and :meth:`update_bot_data`. + ``store_bot_data=False``, you don't need :meth:`get_bot_data`, :meth:`update_bot_data` or + :meth:`refresh_bot_data`. Warning: Persistence will try to replace :class:`telegram.Bot` instances by :attr:`REPLACED_BOT` and @@ -93,6 +98,10 @@ class BasePersistence(ABC): def __new__( cls, *args: object, **kwargs: object # pylint: disable=W0613 ) -> 'BasePersistence': + """This overrides the get_* and update_* methods to use insert/replace_bot. + That has the side effect that we always pass deepcopied data to those methods, so in + Pickle/DictPersistence we don't have to worry about copying the data again. + """ instance = super().__new__(cls) get_user_data = instance.get_user_data get_chat_data = instance.get_chat_data @@ -101,22 +110,22 @@ class BasePersistence(ABC): update_chat_data = instance.update_chat_data update_bot_data = instance.update_bot_data - def get_user_data_insert_bot() -> DefaultDict[int, Dict[object, object]]: + def get_user_data_insert_bot() -> DefaultDict[int, UD]: return instance.insert_bot(get_user_data()) - def get_chat_data_insert_bot() -> DefaultDict[int, Dict[object, object]]: + def get_chat_data_insert_bot() -> DefaultDict[int, CD]: return instance.insert_bot(get_chat_data()) - def get_bot_data_insert_bot() -> Dict[object, object]: + def get_bot_data_insert_bot() -> BD: return instance.insert_bot(get_bot_data()) - def update_user_data_replace_bot(user_id: int, data: Dict) -> None: + def update_user_data_replace_bot(user_id: int, data: UD) -> None: return update_user_data(user_id, instance.replace_bot(data)) - def update_chat_data_replace_bot(chat_id: int, data: Dict) -> None: + def update_chat_data_replace_bot(chat_id: int, data: CD) -> None: return update_chat_data(chat_id, instance.replace_bot(data)) - def update_bot_data_replace_bot(data: Dict) -> None: + def update_bot_data_replace_bot(data: BD) -> None: return update_bot_data(instance.replace_bot(data)) # We want to ignore TGDeprecation warnings so we use obj.__setattr__. Adds to __dict__ @@ -334,33 +343,33 @@ class BasePersistence(ABC): return obj @abstractmethod - def get_user_data(self) -> DefaultDict[int, Dict[object, object]]: + def get_user_data(self) -> DefaultDict[int, UD]: """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a persistence object. It should return the ``user_data`` if stored, or an empty - ``defaultdict(dict)``. + :obj:`defaultdict(telegram.ext.utils.types.UD)` with integer keys. Returns: - :obj:`defaultdict`: The restored user data. + DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.UD`]: The restored user data. """ @abstractmethod - def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]: + def get_chat_data(self) -> DefaultDict[int, CD]: """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a persistence object. It should return the ``chat_data`` if stored, or an empty - ``defaultdict(dict)``. + :obj:`defaultdict(telegram.ext.utils.types.CD)` with integer keys. Returns: - :obj:`defaultdict`: The restored chat data. + DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.CD`]: The restored chat data. """ @abstractmethod - def get_bot_data(self) -> Dict[object, object]: + def get_bot_data(self) -> BD: """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a persistence object. It should return the ``bot_data`` if stored, or an empty - :obj:`dict`. + :class:`telegram.ext.utils.types.BD`. Returns: - :obj:`dict`: The restored bot data. + :class:`telegram.ext.utils.types.BD`: The restored bot data. """ @abstractmethod @@ -391,32 +400,70 @@ class BasePersistence(ABC): """ @abstractmethod - def update_user_data(self, user_id: int, data: Dict) -> None: + def update_user_data(self, user_id: int, data: UD) -> None: """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has handled an update. Args: user_id (:obj:`int`): The user the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id]. + data (:class:`telegram.ext.utils.types.UD`): The + :attr:`telegram.ext.dispatcher.user_data` ``[user_id]``. """ @abstractmethod - def update_chat_data(self, chat_id: int, data: Dict) -> None: + def update_chat_data(self, chat_id: int, data: CD) -> None: """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has handled an update. Args: chat_id (:obj:`int`): The chat the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id]. + data (:class:`telegram.ext.utils.types.CD`): The + :attr:`telegram.ext.dispatcher.chat_data` ``[chat_id]``. """ @abstractmethod - def update_bot_data(self, data: Dict) -> None: + def update_bot_data(self, data: BD) -> None: """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has handled an update. Args: - data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.bot_data` . + data (:class:`telegram.ext.utils.types.BD`): The + :attr:`telegram.ext.dispatcher.bot_data`. + """ + + def refresh_user_data(self, user_id: int, user_data: UD) -> None: + """Will be called by the :class:`telegram.ext.Dispatcher` before passing the + :attr:`user_data` to a callback. Can be used to update data stored in :attr:`user_data` + from an external source. + + .. versionadded:: 13.6 + + Args: + user_id (:obj:`int`): The user ID this :attr:`user_data` is associated with. + user_data (:class:`telegram.ext.utils.types.UD`): The ``user_data`` of a single user. + """ + + def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: + """Will be called by the :class:`telegram.ext.Dispatcher` before passing the + :attr:`chat_data` to a callback. Can be used to update data stored in :attr:`chat_data` + from an external source. + + .. versionadded:: 13.6 + + Args: + chat_id (:obj:`int`): The chat ID this :attr:`chat_data` is associated with. + chat_data (:class:`telegram.ext.utils.types.CD`): The ``chat_data`` of a single chat. + """ + + def refresh_bot_data(self, bot_data: BD) -> None: + """Will be called by the :class:`telegram.ext.Dispatcher` before passing the + :attr:`bot_data` to a callback. Can be used to update data stored in :attr:`bot_data` + from an external source. + + .. versionadded:: 13.6 + + Args: + bot_data (:class:`telegram.ext.utils.types.BD`): The ``bot_data``. """ def flush(self) -> None: diff --git a/telegram/ext/callbackcontext.py b/telegram/ext/callbackcontext.py index d27840b2e..626af5f83 100644 --- a/telegram/ext/callbackcontext.py +++ b/telegram/ext/callbackcontext.py @@ -19,16 +19,31 @@ # pylint: disable=R0201 """This module contains the CallbackContext class.""" from queue import Queue -from typing import TYPE_CHECKING, Dict, List, Match, NoReturn, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + List, + Match, + NoReturn, + Optional, + Tuple, + Union, + Generic, + Type, + TypeVar, +) from telegram import Update +from telegram.ext.utils.types import UD, CD, BD if TYPE_CHECKING: from telegram import Bot from telegram.ext import Dispatcher, Job, JobQueue +CC = TypeVar('CC', bound='CallbackContext') -class CallbackContext: + +class CallbackContext(Generic[UD, CD, BD]): """ This is a context object passed to the callback called by :class:`telegram.ext.Handler` or by the :class:`telegram.ext.Dispatcher` in an error handler added by @@ -50,6 +65,9 @@ class CallbackContext: almost certainly execute the callbacks for an update out of order, and the attributes that you think you added will not be present. + Args: + dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher associated with this context. + Attributes: matches (List[:obj:`re match object`]): Optional. If the associated update originated from a regex-supported handler or had a :class:`Filters.regex`, this will contain a list of @@ -75,9 +93,8 @@ class CallbackContext: __slots__ = ( '_dispatcher', - '_bot_data', - '_chat_data', - '_user_data', + '_chat_id_and_data', + '_user_id_and_data', 'args', 'matches', 'error', @@ -97,9 +114,8 @@ class CallbackContext: 'CallbackContext should not be used with a non context aware ' 'dispatcher!' ) self._dispatcher = dispatcher - self._bot_data = dispatcher.bot_data - self._chat_data: Optional[Dict[object, object]] = None - self._user_data: Optional[Dict[object, object]] = None + self._chat_id_and_data: Optional[Tuple[int, CD]] = None + self._user_id_and_data: Optional[Tuple[int, UD]] = None self.args: Optional[List[str]] = None self.matches: Optional[List[Match]] = None self.error: Optional[Exception] = None @@ -113,11 +129,11 @@ class CallbackContext: return self._dispatcher @property - def bot_data(self) -> Dict: + def bot_data(self) -> BD: """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each update it will be the same ``dict``. """ - return self._bot_data + return self.dispatcher.bot_data @bot_data.setter def bot_data(self, value: object) -> NoReturn: @@ -126,7 +142,7 @@ class CallbackContext: ) @property - def chat_data(self) -> Optional[Dict]: + def chat_data(self) -> Optional[CD]: """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each update from the same chat id it will be the same ``dict``. @@ -136,7 +152,9 @@ class CallbackContext: `_. """ - return self._chat_data + if self._chat_id_and_data: + return self._chat_id_and_data[1] + return None @chat_data.setter def chat_data(self, value: object) -> NoReturn: @@ -145,11 +163,13 @@ class CallbackContext: ) @property - def user_data(self) -> Optional[Dict]: + def user_data(self) -> Optional[UD]: """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each update from the same user it will be the same ``dict``. """ - return self._user_data + if self._user_id_and_data: + return self._user_id_and_data[1] + return None @user_data.setter def user_data(self, value: object) -> NoReturn: @@ -157,15 +177,32 @@ class CallbackContext: "You can not assign a new value to user_data, see https://git.io/Jt6ic" ) + def refresh_data(self) -> None: + """If :attr:`dispatcher` uses persistence, calls + :meth:`telegram.ext.BasePersistence.refresh_bot_data` on :attr:`bot_data`, + :meth:`telegram.ext.BasePersistence.refresh_chat_data` on :attr:`chat_data` and + :meth:`telegram.ext.BasePersistence.refresh_user_data` on :attr:`user_data`, if + appropriate. + + .. versionadded:: 13.6 + """ + if self.dispatcher.persistence: + if self.dispatcher.persistence.store_bot_data: + self.dispatcher.persistence.refresh_bot_data(self.bot_data) + if self.dispatcher.persistence.store_chat_data and self._chat_id_and_data is not None: + self.dispatcher.persistence.refresh_chat_data(*self._chat_id_and_data) + if self.dispatcher.persistence.store_user_data and self._user_id_and_data is not None: + self.dispatcher.persistence.refresh_user_data(*self._user_id_and_data) + @classmethod def from_error( - cls, + cls: Type[CC], update: object, error: Exception, dispatcher: 'Dispatcher', async_args: Union[List, Tuple] = None, async_kwargs: Dict[str, object] = None, - ) -> 'CallbackContext': + ) -> CC: """ Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error handlers. @@ -195,7 +232,7 @@ class CallbackContext: return self @classmethod - def from_update(cls, update: object, dispatcher: 'Dispatcher') -> 'CallbackContext': + def from_update(cls: Type[CC], update: object, dispatcher: 'Dispatcher') -> CC: """ Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the handlers. @@ -217,13 +254,19 @@ class CallbackContext: user = update.effective_user if chat: - self._chat_data = dispatcher.chat_data[chat.id] # pylint: disable=W0212 + self._chat_id_and_data = ( + chat.id, + dispatcher.chat_data[chat.id], # pylint: disable=W0212 + ) if user: - self._user_data = dispatcher.user_data[user.id] # pylint: disable=W0212 + self._user_id_and_data = ( + user.id, + dispatcher.user_data[user.id], # pylint: disable=W0212 + ) return self @classmethod - def from_job(cls, job: 'Job', dispatcher: 'Dispatcher') -> 'CallbackContext': + def from_job(cls: Type[CC], job: 'Job', dispatcher: 'Dispatcher') -> CC: """ Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to a job callback. diff --git a/telegram/ext/callbackqueryhandler.py b/telegram/ext/callbackqueryhandler.py index 95eb8afb1..452578049 100644 --- a/telegram/ext/callbackqueryhandler.py +++ b/telegram/ext/callbackqueryhandler.py @@ -35,14 +35,15 @@ from telegram import Update from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from .handler import Handler +from .utils.types import CCT if TYPE_CHECKING: - from telegram.ext import CallbackContext, Dispatcher + from telegram.ext import Dispatcher RT = TypeVar('RT') -class CallbackQueryHandler(Handler[Update]): +class CallbackQueryHandler(Handler[Update, CCT]): """Handler class to handle Telegram callback queries. Optionally based on a regex. Read the documentation of the ``re`` module for more information. @@ -124,7 +125,7 @@ class CallbackQueryHandler(Handler[Update]): def __init__( self, - callback: Callable[[Update, 'CallbackContext'], RT], + callback: Callable[[Update, CCT], RT], pass_update_queue: bool = False, pass_job_queue: bool = False, pattern: Union[str, Pattern] = None, @@ -191,7 +192,7 @@ class CallbackQueryHandler(Handler[Update]): def collect_additional_context( self, - context: 'CallbackContext', + context: CCT, update: Update, dispatcher: 'Dispatcher', check_result: Union[bool, Match], diff --git a/telegram/ext/chatmemberhandler.py b/telegram/ext/chatmemberhandler.py index c098db904..9499cfd24 100644 --- a/telegram/ext/chatmemberhandler.py +++ b/telegram/ext/chatmemberhandler.py @@ -17,19 +17,17 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the ChatMemberHandler classes.""" -from typing import ClassVar, TypeVar, Union, Callable, TYPE_CHECKING +from typing import ClassVar, TypeVar, Union, Callable from telegram import Update from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from .handler import Handler - -if TYPE_CHECKING: - from telegram.ext import CallbackContext +from .utils.types import CCT RT = TypeVar('RT') -class ChatMemberHandler(Handler[Update]): +class ChatMemberHandler(Handler[Update, CCT]): """Handler class to handle Telegram updates that contain a chat member update. .. versionadded:: 13.4 @@ -107,7 +105,7 @@ class ChatMemberHandler(Handler[Update]): def __init__( self, - callback: Callable[[Update, 'CallbackContext'], RT], + callback: Callable[[Update, CCT], RT], chat_member_types: int = MY_CHAT_MEMBER, pass_update_queue: bool = False, pass_job_queue: bool = False, diff --git a/telegram/ext/choseninlineresulthandler.py b/telegram/ext/choseninlineresulthandler.py index 2ae2814f6..ec3528945 100644 --- a/telegram/ext/choseninlineresulthandler.py +++ b/telegram/ext/choseninlineresulthandler.py @@ -24,6 +24,7 @@ from telegram import Update from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from .handler import Handler +from .utils.types import CCT RT = TypeVar('RT') @@ -31,7 +32,7 @@ if TYPE_CHECKING: from telegram.ext import CallbackContext, Dispatcher -class ChosenInlineResultHandler(Handler[Update]): +class ChosenInlineResultHandler(Handler[Update, CCT]): """Handler class to handle Telegram updates that contain a chosen inline result. Note: diff --git a/telegram/ext/commandhandler.py b/telegram/ext/commandhandler.py index 6d6a1b6e6..1f0a32118 100644 --- a/telegram/ext/commandhandler.py +++ b/telegram/ext/commandhandler.py @@ -27,15 +27,16 @@ from telegram.utils.deprecate import TelegramDeprecationWarning from telegram.utils.types import SLT from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE +from .utils.types import CCT from .handler import Handler if TYPE_CHECKING: - from telegram.ext import CallbackContext, Dispatcher + from telegram.ext import Dispatcher RT = TypeVar('RT') -class CommandHandler(Handler[Update]): +class CommandHandler(Handler[Update, CCT]): """Handler class to handle Telegram commands. Commands are Telegram messages that start with ``/``, optionally followed by an ``@`` and the @@ -134,7 +135,7 @@ class CommandHandler(Handler[Update]): def __init__( self, command: SLT[str], - callback: Callable[[Update, 'CallbackContext'], RT], + callback: Callable[[Update, CCT], RT], filters: BaseFilter = None, allow_edited: bool = None, pass_args: bool = False, @@ -231,7 +232,7 @@ class CommandHandler(Handler[Update]): def collect_additional_context( self, - context: 'CallbackContext', + context: CCT, update: Update, dispatcher: 'Dispatcher', check_result: Optional[Union[bool, Tuple[List[str], Optional[bool]]]], @@ -359,7 +360,7 @@ class PrefixHandler(CommandHandler): self, prefix: SLT[str], command: SLT[str], - callback: Callable[[Update, 'CallbackContext'], RT], + callback: Callable[[Update, CCT], RT], filters: BaseFilter = None, pass_args: bool = False, pass_update_queue: bool = False, diff --git a/telegram/ext/contexttypes.py b/telegram/ext/contexttypes.py new file mode 100644 index 000000000..2156e7f62 --- /dev/null +++ b/telegram/ext/contexttypes.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2020 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +# pylint: disable=R0201 +"""This module contains the auxiliary class ContextTypes.""" +from typing import Type, Generic, overload, Dict # pylint: disable=W0611 + +from telegram.ext.callbackcontext import CallbackContext +from telegram.ext.utils.types import CCT, UD, CD, BD + + +class ContextTypes(Generic[CCT, UD, CD, BD]): + """ + Convenience class to gather customizable types of the :class:`telegram.ext.CallbackContext` + interface. + + .. versionadded:: 13.6 + + Args: + context (:obj:`type`, optional): Determines the type of the ``context`` argument of all + (error-)handler callbacks and job callbacks. Must be a subclass of + :class:`telegram.ext.CallbackContext`. Defaults to + :class:`telegram.ext.CallbackContext`. + bot_data (:obj:`type`, optional): Determines the type of ``context.bot_data`` of all + (error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support + instantiating without arguments. + chat_data (:obj:`type`, optional): Determines the type of ``context.chat_data`` of all + (error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support + instantiating without arguments. + user_data (:obj:`type`, optional): Determines the type of ``context.user_data`` of all + (error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support + instantiating without arguments. + + """ + + __slots__ = ('_context', '_bot_data', '_chat_data', '_user_data') + + @overload + def __init__( + self: 'ContextTypes[CallbackContext, Dict, Dict, Dict]', + ): + ... + + @overload + def __init__(self: 'ContextTypes[CCT, Dict, Dict, Dict]', context: Type[CCT]): + ... + + @overload + def __init__(self: 'ContextTypes[CallbackContext, UD, Dict, Dict]', bot_data: Type[UD]): + ... + + @overload + def __init__(self: 'ContextTypes[CallbackContext, Dict, CD, Dict]', chat_data: Type[CD]): + ... + + @overload + def __init__(self: 'ContextTypes[CallbackContext, Dict, Dict, BD]', user_data: Type[BD]): + ... + + @overload + def __init__( + self: 'ContextTypes[CCT, UD, Dict, Dict]', context: Type[CCT], bot_data: Type[UD] + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CCT, Dict, CD, Dict]', context: Type[CCT], chat_data: Type[CD] + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CCT, Dict, Dict, BD]', context: Type[CCT], user_data: Type[BD] + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CallbackContext, UD, CD, Dict]', + bot_data: Type[UD], + chat_data: Type[CD], + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CallbackContext, UD, Dict, BD]', + bot_data: Type[UD], + user_data: Type[BD], + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CallbackContext, Dict, CD, BD]', + chat_data: Type[CD], + user_data: Type[BD], + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CCT, UD, CD, Dict]', + context: Type[CCT], + bot_data: Type[UD], + chat_data: Type[CD], + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CCT, UD, Dict, BD]', + context: Type[CCT], + bot_data: Type[UD], + user_data: Type[BD], + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CCT, Dict, CD, BD]', + context: Type[CCT], + chat_data: Type[CD], + user_data: Type[BD], + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CallbackContext, UD, CD, BD]', + bot_data: Type[UD], + chat_data: Type[CD], + user_data: Type[BD], + ): + ... + + @overload + def __init__( + self: 'ContextTypes[CCT, UD, CD, BD]', + context: Type[CCT], + bot_data: Type[UD], + chat_data: Type[CD], + user_data: Type[BD], + ): + ... + + def __init__( # type: ignore[no-untyped-def] + self, + context=CallbackContext, + bot_data=dict, + chat_data=dict, + user_data=dict, + ): + if not issubclass(context, CallbackContext): + raise ValueError('context must be a subclass of CallbackContext.') + + # We make all those only accessible via properties because we don't currently support + # changing this at runtime, so overriding the attributes doesn't make sense + self._context = context + self._bot_data = bot_data + self._chat_data = chat_data + self._user_data = user_data + + @property + def context(self) -> Type[CCT]: + return self._context + + @property + def bot_data(self) -> Type[BD]: + return self._bot_data + + @property + def chat_data(self) -> Type[CD]: + return self._chat_data + + @property + def user_data(self) -> Type[UD]: + return self._user_data diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index ffaaf969c..081e10f95 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -38,6 +38,7 @@ from telegram.ext import ( ) from telegram.ext.utils.promise import Promise from telegram.utils.types import ConversationDict +from telegram.ext.utils.types import CCT if TYPE_CHECKING: from telegram.ext import Dispatcher, Job @@ -61,7 +62,7 @@ class _ConversationTimeoutContext: self.callback_context = callback_context -class ConversationHandler(Handler[Update]): +class ConversationHandler(Handler[Update, CCT]): """ A handler to hold a conversation with a single or multiple users through Telegram updates by managing four collections of other handlers. @@ -215,9 +216,9 @@ class ConversationHandler(Handler[Update]): # pylint: disable=W0231 def __init__( self, - entry_points: List[Handler], - states: Dict[object, List[Handler]], - fallbacks: List[Handler], + entry_points: List[Handler[Update, CCT]], + states: Dict[object, List[Handler[Update, CCT]]], + fallbacks: List[Handler[Update, CCT]], allow_reentry: bool = False, per_chat: bool = True, per_user: bool = True, diff --git a/telegram/ext/dictpersistence.py b/telegram/ext/dictpersistence.py index e5df61605..ad9360442 100644 --- a/telegram/ext/dictpersistence.py +++ b/telegram/ext/dictpersistence.py @@ -17,7 +17,6 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the DictPersistence class.""" -from copy import deepcopy from typing import DefaultDict, Dict, Optional, Tuple from collections import defaultdict @@ -37,7 +36,7 @@ except ImportError: class DictPersistence(BasePersistence): - """Using python's dicts and json for making your bot persistent. + """Using Python's :obj:`dict` and ``json`` for making your bot persistent. Note: This class does *not* implement a :meth:`flush` method, meaning that data managed by @@ -202,7 +201,7 @@ class DictPersistence(BasePersistence): pass else: self._user_data = defaultdict(dict) - return deepcopy(self.user_data) # type: ignore[arg-type] + return self.user_data # type: ignore[return-value] def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]: """Returns the chat_data created from the ``chat_data_json`` or an empty @@ -215,7 +214,7 @@ class DictPersistence(BasePersistence): pass else: self._chat_data = defaultdict(dict) - return deepcopy(self.chat_data) # type: ignore[arg-type] + return self.chat_data # type: ignore[return-value] def get_bot_data(self) -> Dict[object, object]: """Returns the bot_data created from the ``bot_data_json`` or an empty :obj:`dict`. @@ -227,7 +226,7 @@ class DictPersistence(BasePersistence): pass else: self._bot_data = {} - return deepcopy(self.bot_data) # type: ignore[arg-type] + return self.bot_data # type: ignore[return-value] def get_conversations(self, name: str) -> ConversationDict: """Returns the conversations created from the ``conversations_json`` or an empty @@ -264,7 +263,7 @@ class DictPersistence(BasePersistence): Args: user_id (:obj:`int`): The user the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id]. + data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` ``[user_id]``. """ if self._user_data is None: self._user_data = defaultdict(dict) @@ -278,7 +277,7 @@ class DictPersistence(BasePersistence): Args: chat_id (:obj:`int`): The chat the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id]. + data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` ``[chat_id]``. """ if self._chat_data is None: self._chat_data = defaultdict(dict) @@ -295,5 +294,26 @@ class DictPersistence(BasePersistence): """ if self._bot_data == data: return - self._bot_data = data.copy() + self._bot_data = data self._bot_data_json = None + + def refresh_user_data(self, user_id: int, user_data: Dict) -> None: + """Does nothing. + + .. versionadded:: 13.6 + .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_user_data` + """ + + def refresh_chat_data(self, chat_id: int, chat_data: Dict) -> None: + """Does nothing. + + .. versionadded:: 13.6 + .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_chat_data` + """ + + def refresh_bot_data(self, bot_data: Dict) -> None: + """Does nothing. + + .. versionadded:: 13.6 + .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_bot_data` + """ diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py index 4dbd2382d..db5f0958a 100644 --- a/telegram/ext/dispatcher.py +++ b/telegram/ext/dispatcher.py @@ -26,16 +26,30 @@ from functools import wraps from queue import Empty, Queue from threading import BoundedSemaphore, Event, Lock, Thread, current_thread from time import sleep -from typing import TYPE_CHECKING, Callable, DefaultDict, Dict, List, Optional, Union, Set +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + List, + Optional, + Set, + Union, + Generic, + TypeVar, + overload, + cast, + DefaultDict, +) from uuid import uuid4 from telegram import TelegramError, Update -from telegram.ext import BasePersistence +from telegram.ext import BasePersistence, ContextTypes from telegram.ext.callbackcontext import CallbackContext from telegram.ext.handler import Handler from telegram.utils.deprecate import TelegramDeprecationWarning, set_new_attribute_deprecated from telegram.ext.utils.promise import Promise from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE +from telegram.ext.utils.types import CCT, UD, CD, BD if TYPE_CHECKING: from telegram import Bot @@ -43,6 +57,8 @@ if TYPE_CHECKING: DEFAULT_GROUP: int = 0 +UT = TypeVar('UT') + def run_async( func: Callable[[Update, CallbackContext], object] @@ -105,7 +121,7 @@ class DispatcherHandlerStop(Exception): self.state = state -class Dispatcher: +class Dispatcher(Generic[CCT, UD, CD, BD]): """This class dispatches all kinds of updates to its registered handlers. Args: @@ -120,6 +136,12 @@ class Dispatcher: use_context (:obj:`bool`, optional): If set to :obj:`True` uses the context based callback API (ignored if `dispatcher` argument is used). Defaults to :obj:`True`. **New users**: set this to :obj:`True`. + context_types (:class:`telegram.ext.ContextTypes`, optional): Pass an instance + of :class:`telegram.ext.ContextTypes` to customize the types used in the + ``context`` interface. If not passed, the defaults documented in + :class:`telegram.ext.ContextTypes` will be used. + + .. versionadded:: 13.6 Attributes: bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers. @@ -133,6 +155,10 @@ class Dispatcher: bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot. persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to store data that should be persistent over restarts. + context_types (:class:`telegram.ext.ContextTypes`): Container for the types used + in the ``context`` interface. + + .. versionadded:: 13.6 """ @@ -158,6 +184,7 @@ class Dispatcher: 'bot', '__dict__', '__weakref__', + 'context_types', ) __singleton_lock = Lock() @@ -165,6 +192,33 @@ class Dispatcher: __singleton = None logger = logging.getLogger(__name__) + @overload + def __init__( + self: 'Dispatcher[CallbackContext[Dict, Dict, Dict], Dict, Dict, Dict]', + bot: 'Bot', + update_queue: Queue, + workers: int = 4, + exception_event: Event = None, + job_queue: 'JobQueue' = None, + persistence: BasePersistence = None, + use_context: bool = True, + ): + ... + + @overload + def __init__( + self: 'Dispatcher[CCT, UD, CD, BD]', + bot: 'Bot', + update_queue: Queue, + workers: int = 4, + exception_event: Event = None, + job_queue: 'JobQueue' = None, + persistence: BasePersistence = None, + use_context: bool = True, + context_types: ContextTypes[CCT, UD, CD, BD] = None, + ): + ... + def __init__( self, bot: 'Bot', @@ -174,12 +228,14 @@ class Dispatcher: job_queue: 'JobQueue' = None, persistence: BasePersistence = None, use_context: bool = True, + context_types: ContextTypes[CCT, UD, CD, BD] = None, ): self.bot = bot self.update_queue = update_queue self.job_queue = job_queue self.workers = workers self.use_context = use_context + self.context_types = cast(ContextTypes[CCT, UD, CD, BD], context_types or ContextTypes()) if not use_context: warnings.warn( @@ -193,9 +249,9 @@ class Dispatcher: 'Asynchronous callbacks can not be processed without at least one worker thread.' ) - self.user_data: DefaultDict[int, Dict[object, object]] = defaultdict(dict) - self.chat_data: DefaultDict[int, Dict[object, object]] = defaultdict(dict) - self.bot_data = {} + self.user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data) + self.chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data) + self.bot_data = self.context_types.bot_data() self.persistence: Optional[BasePersistence] = None self._update_persistence_lock = Lock() if persistence: @@ -213,8 +269,11 @@ class Dispatcher: raise ValueError("chat_data must be of type defaultdict") if self.persistence.store_bot_data: self.bot_data = self.persistence.get_bot_data() - if not isinstance(self.bot_data, dict): - raise ValueError("bot_data must be of type dict") + if not isinstance(self.bot_data, self.context_types.bot_data): + raise ValueError( + f"bot_data must be of type {self.context_types.bot_data.__name__}" + ) + else: self.persistence = None @@ -477,7 +536,8 @@ class Dispatcher: check = handler.check_update(update) if check is not None and check is not False: if not context and self.use_context: - context = CallbackContext.from_update(update, self) + context = self.context_types.context.from_update(update, self) + context.refresh_data() handled = True sync_modes.append(handler.run_async) handler.handle_update(update, self, check, context) @@ -510,7 +570,7 @@ class Dispatcher: if not handled_only_async: self.update_persistence(update=update) - def add_handler(self, handler: Handler, group: int = DEFAULT_GROUP) -> None: + def add_handler(self, handler: Handler[UT, CCT], group: int = DEFAULT_GROUP) -> None: """Register a handler. TL;DR: Order and priority counts. 0 or 1 handlers per group will be used. End handling of @@ -542,14 +602,22 @@ class Dispatcher: raise TypeError(f'handler is not an instance of {Handler.__name__}') if not isinstance(group, int): raise TypeError('group is not int') - if isinstance(handler, ConversationHandler) and handler.persistent and handler.name: + # For some reason MyPy infers the type of handler is here, + # so for now we just ignore all the errors + if ( + isinstance(handler, ConversationHandler) + and handler.persistent # type: ignore[attr-defined] + and handler.name # type: ignore[attr-defined] + ): if not self.persistence: raise ValueError( - f"ConversationHandler {handler.name} can not be persistent if dispatcher has " - f"no persistence" + f"ConversationHandler {handler.name} " # type: ignore[attr-defined] + f"can not be persistent if dispatcher has no persistence" ) - handler.persistence = self.persistence - handler.conversations = self.persistence.get_conversations(handler.name) + handler.persistence = self.persistence # type: ignore[attr-defined] + handler.conversations = ( # type: ignore[attr-defined] + self.persistence.get_conversations(handler.name) # type: ignore[attr-defined] + ) if group not in self.handlers: self.handlers[group] = [] @@ -643,7 +711,7 @@ class Dispatcher: def add_error_handler( self, - callback: Callable[[object, CallbackContext], None], + callback: Callable[[object, CCT], None], run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, # pylint: disable=W0621 ) -> None: """Registers an error handler in the Dispatcher. This handler will receive every error @@ -678,7 +746,7 @@ class Dispatcher: self.error_handlers[callback] = run_async - def remove_error_handler(self, callback: Callable[[object, CallbackContext], None]) -> None: + def remove_error_handler(self, callback: Callable[[object, CCT], None]) -> None: """Removes an error handler. Args: @@ -705,7 +773,7 @@ class Dispatcher: if self.error_handlers: for callback, run_async in self.error_handlers.items(): # pylint: disable=W0621 if self.use_context: - context = CallbackContext.from_error( + context = self.context_types.context.from_error( update, error, self, async_args=async_args, async_kwargs=async_kwargs ) if run_async: diff --git a/telegram/ext/handler.py b/telegram/ext/handler.py index 17a86d166..befaf4139 100644 --- a/telegram/ext/handler.py +++ b/telegram/ext/handler.py @@ -26,15 +26,16 @@ from telegram.utils.deprecate import set_new_attribute_deprecated from telegram import Update from telegram.ext.utils.promise import Promise from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE +from telegram.ext.utils.types import CCT if TYPE_CHECKING: - from telegram.ext import CallbackContext, Dispatcher + from telegram.ext import Dispatcher RT = TypeVar('RT') UT = TypeVar('UT') -class Handler(Generic[UT], ABC): +class Handler(Generic[UT, CCT], ABC): """The base class for all update handlers. Create custom handlers by inheriting from it. Note: @@ -115,7 +116,7 @@ class Handler(Generic[UT], ABC): def __init__( self, - callback: Callable[[UT, 'CallbackContext'], RT], + callback: Callable[[UT, CCT], RT], pass_update_queue: bool = False, pass_job_queue: bool = False, pass_user_data: bool = False, @@ -165,7 +166,7 @@ class Handler(Generic[UT], ABC): update: UT, dispatcher: 'Dispatcher', check_result: object, - context: 'CallbackContext' = None, + context: CCT = None, ) -> Union[RT, Promise]: """ This method is called if it was determined that an update should indeed @@ -205,7 +206,7 @@ class Handler(Generic[UT], ABC): def collect_additional_context( self, - context: 'CallbackContext', + context: CCT, update: UT, dispatcher: 'Dispatcher', check_result: Any, diff --git a/telegram/ext/inlinequeryhandler.py b/telegram/ext/inlinequeryhandler.py index 3216345c3..11103e71f 100644 --- a/telegram/ext/inlinequeryhandler.py +++ b/telegram/ext/inlinequeryhandler.py @@ -35,14 +35,15 @@ from telegram import Update from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from .handler import Handler +from .utils.types import CCT if TYPE_CHECKING: - from telegram.ext import CallbackContext, Dispatcher + from telegram.ext import Dispatcher RT = TypeVar('RT') -class InlineQueryHandler(Handler[Update]): +class InlineQueryHandler(Handler[Update, CCT]): """ Handler class to handle Telegram inline queries. Optionally based on a regex. Read the documentation of the ``re`` module for more information. @@ -133,7 +134,7 @@ class InlineQueryHandler(Handler[Update]): def __init__( self, - callback: Callable[[Update, 'CallbackContext'], RT], + callback: Callable[[Update, CCT], RT], pass_update_queue: bool = False, pass_job_queue: bool = False, pattern: Union[str, Pattern] = None, @@ -207,7 +208,7 @@ class InlineQueryHandler(Handler[Update]): def collect_additional_context( self, - context: 'CallbackContext', + context: CCT, update: Update, dispatcher: 'Dispatcher', check_result: Optional[Union[bool, Match]], diff --git a/telegram/ext/jobqueue.py b/telegram/ext/jobqueue.py index 2ea2fec5c..837cac561 100644 --- a/telegram/ext/jobqueue.py +++ b/telegram/ext/jobqueue.py @@ -72,7 +72,7 @@ class JobQueue: def _build_args(self, job: 'Job') -> List[Union[CallbackContext, 'Bot', 'Job']]: if self._dispatcher.use_context: - return [CallbackContext.from_job(job, self._dispatcher)] + return [self._dispatcher.context_types.context.from_job(job, self._dispatcher)] return [self._dispatcher.bot, job] def _tz_now(self) -> datetime.datetime: @@ -585,7 +585,7 @@ class Job: """Executes the callback function independently of the jobs schedule.""" try: if dispatcher.use_context: - self.callback(CallbackContext.from_job(self, dispatcher)) + self.callback(dispatcher.context_types.context.from_job(self, dispatcher)) else: self.callback(dispatcher.bot, self) # type: ignore[arg-type,call-arg] except Exception as exc: diff --git a/telegram/ext/messagehandler.py b/telegram/ext/messagehandler.py index ed1078124..c3f0c015c 100644 --- a/telegram/ext/messagehandler.py +++ b/telegram/ext/messagehandler.py @@ -27,14 +27,15 @@ from telegram.utils.deprecate import TelegramDeprecationWarning from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from .handler import Handler +from .utils.types import CCT if TYPE_CHECKING: - from telegram.ext import CallbackContext, Dispatcher + from telegram.ext import Dispatcher RT = TypeVar('RT') -class MessageHandler(Handler[Update]): +class MessageHandler(Handler[Update, CCT]): """Handler class to handle telegram messages. They might contain text, media or status updates. Note: @@ -125,7 +126,7 @@ class MessageHandler(Handler[Update]): def __init__( self, filters: BaseFilter, - callback: Callable[[Update, 'CallbackContext'], RT], + callback: Callable[[Update, CCT], RT], pass_update_queue: bool = False, pass_job_queue: bool = False, pass_user_data: bool = False, @@ -197,7 +198,7 @@ class MessageHandler(Handler[Update]): def collect_additional_context( self, - context: 'CallbackContext', + context: CCT, update: Update, dispatcher: 'Dispatcher', check_result: Optional[Union[bool, Dict[str, object]]], diff --git a/telegram/ext/picklepersistence.py b/telegram/ext/picklepersistence.py index 3127d3baf..d015924b7 100644 --- a/telegram/ext/picklepersistence.py +++ b/telegram/ext/picklepersistence.py @@ -19,14 +19,23 @@ """This module contains the PicklePersistence class.""" import pickle from collections import defaultdict -from copy import deepcopy -from typing import Any, DefaultDict, Dict, Optional, Tuple +from typing import ( + Any, + Dict, + Optional, + Tuple, + overload, + cast, + DefaultDict, +) from telegram.ext import BasePersistence -from telegram.utils.types import ConversationDict +from telegram.utils.types import ConversationDict # pylint: disable=W0611 +from .utils.types import UD, CD, BD +from .contexttypes import ContextTypes -class PicklePersistence(BasePersistence): +class PicklePersistence(BasePersistence[UD, CD, BD]): """Using python's builtin pickle for making you bot persistent. Warning: @@ -54,6 +63,12 @@ class PicklePersistence(BasePersistence): :meth:`flush` is called and keep data in memory until that happens. When :obj:`False` will store data on any transaction *and* on call to :meth:`flush`. Default is :obj:`False`. + context_types (:class:`telegram.ext.ContextTypes`, optional): Pass an instance + of :class:`telegram.ext.ContextTypes` to customize the types used in the + ``context`` interface. If not passed, the defaults documented in + :class:`telegram.ext.ContextTypes` will be used. + + .. versionadded:: 13.6 Attributes: filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file` @@ -71,6 +86,10 @@ class PicklePersistence(BasePersistence): :meth:`flush` is called and keep data in memory until that happens. When :obj:`False` will store data on any transaction *and* on call to :meth:`flush`. Default is :obj:`False`. + context_types (:class:`telegram.ext.ContextTypes`): Container for the types used + in the ``context`` interface. + + .. versionadded:: 13.6 """ __slots__ = ( @@ -81,8 +100,34 @@ class PicklePersistence(BasePersistence): 'chat_data', 'bot_data', 'conversations', + 'context_types', ) + @overload + def __init__( + self: 'PicklePersistence[Dict, Dict, Dict]', + filename: str, + store_user_data: bool = True, + store_chat_data: bool = True, + store_bot_data: bool = True, + single_file: bool = True, + on_flush: bool = False, + ): + ... + + @overload + def __init__( + self: 'PicklePersistence[UD, CD, BD]', + filename: str, + store_user_data: bool = True, + store_chat_data: bool = True, + store_bot_data: bool = True, + single_file: bool = True, + on_flush: bool = False, + context_types: ContextTypes[Any, UD, CD, BD] = None, + ): + ... + def __init__( self, filename: str, @@ -91,6 +136,7 @@ class PicklePersistence(BasePersistence): store_bot_data: bool = True, single_file: bool = True, on_flush: bool = False, + context_types: ContextTypes[Any, UD, CD, BD] = None, ): super().__init__( store_user_data=store_user_data, @@ -100,26 +146,27 @@ class PicklePersistence(BasePersistence): self.filename = filename self.single_file = single_file self.on_flush = on_flush - self.user_data: Optional[DefaultDict[int, Dict]] = None - self.chat_data: Optional[DefaultDict[int, Dict]] = None - self.bot_data: Optional[Dict] = None + self.user_data: Optional[DefaultDict[int, UD]] = None + self.chat_data: Optional[DefaultDict[int, CD]] = None + self.bot_data: Optional[BD] = None self.conversations: Optional[Dict[str, Dict[Tuple, object]]] = None + self.context_types = cast(ContextTypes[Any, UD, CD, BD], context_types or ContextTypes()) def _load_singlefile(self) -> None: try: filename = self.filename with open(self.filename, "rb") as file: data = pickle.load(file) - self.user_data = defaultdict(dict, data['user_data']) - self.chat_data = defaultdict(dict, data['chat_data']) + self.user_data = defaultdict(self.context_types.user_data, data['user_data']) + self.chat_data = defaultdict(self.context_types.chat_data, data['chat_data']) # For backwards compatibility with files not containing bot data - self.bot_data = data.get('bot_data', {}) + self.bot_data = data.get('bot_data', self.context_types.bot_data()) self.conversations = data['conversations'] except OSError: self.conversations = {} - self.user_data = defaultdict(dict) - self.chat_data = defaultdict(dict) - self.bot_data = {} + self.user_data = defaultdict(self.context_types.user_data) + self.chat_data = defaultdict(self.context_types.chat_data) + self.bot_data = self.context_types.bot_data() except pickle.UnpicklingError as exc: raise TypeError(f"File {filename} does not contain valid pickle data") from exc except Exception as exc: @@ -152,11 +199,11 @@ class PicklePersistence(BasePersistence): with open(filename, "wb") as file: pickle.dump(data, file) - def get_user_data(self) -> DefaultDict[int, Dict[object, object]]: + def get_user_data(self) -> DefaultDict[int, UD]: """Returns the user_data from the pickle file if it exists or an empty :obj:`defaultdict`. Returns: - :obj:`defaultdict`: The restored user data. + DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.UD`]: The restored user data. """ if self.user_data: pass @@ -164,19 +211,19 @@ class PicklePersistence(BasePersistence): filename = f"{self.filename}_user_data" data = self._load_file(filename) if not data: - data = defaultdict(dict) + data = defaultdict(self.context_types.user_data) else: - data = defaultdict(dict, data) + data = defaultdict(self.context_types.user_data, data) self.user_data = data else: self._load_singlefile() - return deepcopy(self.user_data) # type: ignore[arg-type] + return self.user_data # type: ignore[return-value] - def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]: + def get_chat_data(self) -> DefaultDict[int, CD]: """Returns the chat_data from the pickle file if it exists or an empty :obj:`defaultdict`. Returns: - :obj:`defaultdict`: The restored chat data. + DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.CD`]: The restored chat data. """ if self.chat_data: pass @@ -184,19 +231,20 @@ class PicklePersistence(BasePersistence): filename = f"{self.filename}_chat_data" data = self._load_file(filename) if not data: - data = defaultdict(dict) + data = defaultdict(self.context_types.chat_data) else: - data = defaultdict(dict, data) + data = defaultdict(self.context_types.chat_data, data) self.chat_data = data else: self._load_singlefile() - return deepcopy(self.chat_data) # type: ignore[arg-type] + return self.chat_data # type: ignore[return-value] - def get_bot_data(self) -> Dict[object, object]: - """Returns the bot_data from the pickle file if it exists or an empty :obj:`dict`. + def get_bot_data(self) -> BD: + """Returns the bot_data from the pickle file if it exists or an empty object of type + :class:`telegram.ext.utils.types.BD`. Returns: - :obj:`dict`: The restored bot data. + :class:`telegram.ext.utils.types.BD`: The restored bot data. """ if self.bot_data: pass @@ -204,11 +252,11 @@ class PicklePersistence(BasePersistence): filename = f"{self.filename}_bot_data" data = self._load_file(filename) if not data: - data = {} + data = self.context_types.bot_data() self.bot_data = data else: self._load_singlefile() - return deepcopy(self.bot_data) # type: ignore[arg-type] + return self.bot_data # type: ignore[return-value] def get_conversations(self, name: str) -> ConversationDict: """Returns the conversations from the pickle file if it exsists or an empty dict. @@ -254,15 +302,16 @@ class PicklePersistence(BasePersistence): else: self._dump_singlefile() - def update_user_data(self, user_id: int, data: Dict) -> None: + def update_user_data(self, user_id: int, data: UD) -> None: """Will update the user_data and depending on :attr:`on_flush` save the pickle file. Args: user_id (:obj:`int`): The user the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id]. + data (:class:`telegram.ext.utils.types.UD`): The + :attr:`telegram.ext.dispatcher.user_data` ``[user_id]``. """ if self.user_data is None: - self.user_data = defaultdict(dict) + self.user_data = defaultdict(self.context_types.user_data) if self.user_data.get(user_id) == data: return self.user_data[user_id] = data @@ -273,15 +322,16 @@ class PicklePersistence(BasePersistence): else: self._dump_singlefile() - def update_chat_data(self, chat_id: int, data: Dict) -> None: + def update_chat_data(self, chat_id: int, data: CD) -> None: """Will update the chat_data and depending on :attr:`on_flush` save the pickle file. Args: chat_id (:obj:`int`): The chat the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id]. + data (:class:`telegram.ext.utils.types.CD`): The + :attr:`telegram.ext.dispatcher.chat_data` ``[chat_id]``. """ if self.chat_data is None: - self.chat_data = defaultdict(dict) + self.chat_data = defaultdict(self.context_types.chat_data) if self.chat_data.get(chat_id) == data: return self.chat_data[chat_id] = data @@ -292,15 +342,16 @@ class PicklePersistence(BasePersistence): else: self._dump_singlefile() - def update_bot_data(self, data: Dict) -> None: + def update_bot_data(self, data: BD) -> None: """Will update the bot_data and depending on :attr:`on_flush` save the pickle file. Args: - data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.bot_data`. + data (:class:`telegram.ext.utils.types.BD`): The + :attr:`telegram.ext.dispatcher.bot_data`. """ if self.bot_data == data: return - self.bot_data = data.copy() + self.bot_data = data if not self.on_flush: if not self.single_file: filename = f"{self.filename}_bot_data" @@ -308,6 +359,27 @@ class PicklePersistence(BasePersistence): else: self._dump_singlefile() + def refresh_user_data(self, user_id: int, user_data: UD) -> None: + """Does nothing. + + .. versionadded:: 13.6 + .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_user_data` + """ + + def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: + """Does nothing. + + .. versionadded:: 13.6 + .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_chat_data` + """ + + def refresh_bot_data(self, bot_data: BD) -> None: + """Does nothing. + + .. versionadded:: 13.6 + .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_bot_data` + """ + def flush(self) -> None: """Will save all data in memory to pickle file(s).""" if self.single_file: diff --git a/telegram/ext/pollanswerhandler.py b/telegram/ext/pollanswerhandler.py index 73cd36bce..199bcb3ad 100644 --- a/telegram/ext/pollanswerhandler.py +++ b/telegram/ext/pollanswerhandler.py @@ -22,9 +22,10 @@ from telegram import Update from .handler import Handler +from .utils.types import CCT -class PollAnswerHandler(Handler[Update]): +class PollAnswerHandler(Handler[Update, CCT]): """Handler class to handle Telegram updates that contain a poll answer. Note: diff --git a/telegram/ext/pollhandler.py b/telegram/ext/pollhandler.py index c0719d5e5..7b67e76ff 100644 --- a/telegram/ext/pollhandler.py +++ b/telegram/ext/pollhandler.py @@ -22,9 +22,10 @@ from telegram import Update from .handler import Handler +from .utils.types import CCT -class PollHandler(Handler[Update]): +class PollHandler(Handler[Update, CCT]): """Handler class to handle Telegram updates that contain a poll. Note: diff --git a/telegram/ext/precheckoutqueryhandler.py b/telegram/ext/precheckoutqueryhandler.py index 1a93cccca..3a2eee30d 100644 --- a/telegram/ext/precheckoutqueryhandler.py +++ b/telegram/ext/precheckoutqueryhandler.py @@ -22,9 +22,10 @@ from telegram import Update from .handler import Handler +from .utils.types import CCT -class PreCheckoutQueryHandler(Handler[Update]): +class PreCheckoutQueryHandler(Handler[Update, CCT]): """Handler class to handle Telegram PreCheckout callback queries. Note: diff --git a/telegram/ext/regexhandler.py b/telegram/ext/regexhandler.py index 50b82a0e8..399e4df7d 100644 --- a/telegram/ext/regexhandler.py +++ b/telegram/ext/regexhandler.py @@ -26,9 +26,10 @@ from telegram import Update from telegram.ext import Filters, MessageHandler from telegram.utils.deprecate import TelegramDeprecationWarning from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE +from telegram.ext.utils.types import CCT if TYPE_CHECKING: - from telegram.ext import CallbackContext, Dispatcher + from telegram.ext import Dispatcher RT = TypeVar('RT') @@ -113,7 +114,7 @@ class RegexHandler(MessageHandler): def __init__( self, pattern: Union[str, Pattern], - callback: Callable[[Update, 'CallbackContext'], RT], + callback: Callable[[Update, CCT], RT], pass_groups: bool = False, pass_groupdict: bool = False, pass_update_queue: bool = False, diff --git a/telegram/ext/shippingqueryhandler.py b/telegram/ext/shippingqueryhandler.py index 53ea28494..e4229ceb7 100644 --- a/telegram/ext/shippingqueryhandler.py +++ b/telegram/ext/shippingqueryhandler.py @@ -21,9 +21,10 @@ from telegram import Update from .handler import Handler +from .utils.types import CCT -class ShippingQueryHandler(Handler[Update]): +class ShippingQueryHandler(Handler[Update, CCT]): """Handler class to handle Telegram shipping callback queries. Note: diff --git a/telegram/ext/stringcommandhandler.py b/telegram/ext/stringcommandhandler.py index 98aa6518b..1d84892e4 100644 --- a/telegram/ext/stringcommandhandler.py +++ b/telegram/ext/stringcommandhandler.py @@ -23,14 +23,15 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, Union from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from .handler import Handler +from .utils.types import CCT if TYPE_CHECKING: - from telegram.ext import CallbackContext, Dispatcher + from telegram.ext import Dispatcher RT = TypeVar('RT') -class StringCommandHandler(Handler[str]): +class StringCommandHandler(Handler[str, CCT]): """Handler class to handle string commands. Commands are string updates that start with ``/``. The handler will add a ``list`` to the :class:`CallbackContext` named :attr:`CallbackContext.args`. It will contain a list of strings, @@ -90,7 +91,7 @@ class StringCommandHandler(Handler[str]): def __init__( self, command: str, - callback: Callable[[str, 'CallbackContext'], RT], + callback: Callable[[str, CCT], RT], pass_args: bool = False, pass_update_queue: bool = False, pass_job_queue: bool = False, @@ -137,7 +138,7 @@ class StringCommandHandler(Handler[str]): def collect_additional_context( self, - context: 'CallbackContext', + context: CCT, update: str, dispatcher: 'Dispatcher', check_result: Optional[List[str]], diff --git a/telegram/ext/stringregexhandler.py b/telegram/ext/stringregexhandler.py index db3ee8440..282c48ad7 100644 --- a/telegram/ext/stringregexhandler.py +++ b/telegram/ext/stringregexhandler.py @@ -24,14 +24,15 @@ from typing import TYPE_CHECKING, Callable, Dict, Match, Optional, Pattern, Type from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from .handler import Handler +from .utils.types import CCT if TYPE_CHECKING: - from telegram.ext import CallbackContext, Dispatcher + from telegram.ext import Dispatcher RT = TypeVar('RT') -class StringRegexHandler(Handler[str]): +class StringRegexHandler(Handler[str, CCT]): """Handler class to handle string updates based on a regex which checks the update content. Read the documentation of the ``re`` module for more information. The ``re.match`` function is @@ -96,7 +97,7 @@ class StringRegexHandler(Handler[str]): def __init__( self, pattern: Union[str, Pattern], - callback: Callable[[str, 'CallbackContext'], RT], + callback: Callable[[str, CCT], RT], pass_groups: bool = False, pass_groupdict: bool = False, pass_update_queue: bool = False, @@ -153,7 +154,7 @@ class StringRegexHandler(Handler[str]): def collect_additional_context( self, - context: 'CallbackContext', + context: CCT, update: str, dispatcher: 'Dispatcher', check_result: Optional[Match], diff --git a/telegram/ext/typehandler.py b/telegram/ext/typehandler.py index a4f7d7319..531d10c30 100644 --- a/telegram/ext/typehandler.py +++ b/telegram/ext/typehandler.py @@ -18,19 +18,17 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the TypeHandler class.""" -from typing import TYPE_CHECKING, Callable, Type, TypeVar, Union +from typing import Callable, Type, TypeVar, Union from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from .handler import Handler - -if TYPE_CHECKING: - from telegram.ext import CallbackContext +from .utils.types import CCT RT = TypeVar('RT') UT = TypeVar('UT') -class TypeHandler(Handler[UT]): +class TypeHandler(Handler[UT, CCT]): """Handler class to handle updates of custom types. Warning: @@ -80,7 +78,7 @@ class TypeHandler(Handler[UT]): def __init__( self, type: Type[UT], # pylint: disable=W0622 - callback: Callable[[UT, 'CallbackContext'], RT], + callback: Callable[[UT, CCT], RT], strict: bool = False, pass_update_queue: bool = False, pass_job_queue: bool = False, diff --git a/telegram/ext/updater.py b/telegram/ext/updater.py index e9975d86f..30bb0f889 100644 --- a/telegram/ext/updater.py +++ b/telegram/ext/updater.py @@ -25,21 +25,34 @@ from queue import Queue from signal import SIGABRT, SIGINT, SIGTERM, signal from threading import Event, Lock, Thread, current_thread from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, no_type_check +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + no_type_check, + Generic, + overload, +) from telegram import Bot, TelegramError from telegram.error import InvalidToken, RetryAfter, TimedOut, Unauthorized -from telegram.ext import Dispatcher, JobQueue +from telegram.ext import Dispatcher, JobQueue, ContextTypes from telegram.utils.deprecate import TelegramDeprecationWarning, set_new_attribute_deprecated from telegram.utils.helpers import get_signal_name from telegram.utils.request import Request +from telegram.ext.utils.types import CCT, UD, CD, BD from telegram.ext.utils.webhookhandler import WebhookAppClass, WebhookServer if TYPE_CHECKING: - from telegram.ext import BasePersistence, Defaults + from telegram.ext import BasePersistence, Defaults, CallbackContext -class Updater: +class Updater(Generic[CCT, UD, CD, BD]): """ This class, which employs the :class:`telegram.ext.Dispatcher`, provides a frontend to :class:`telegram.Bot` to the programmer, so they can focus on coding the bot. Its purpose is to @@ -85,6 +98,12 @@ class Updater: used). defaults (:class:`telegram.ext.Defaults`, optional): An object containing default values to be used if not set explicitly in the bot methods. + context_types (:class:`telegram.ext.ContextTypes`, optional): Pass an instance + of :class:`telegram.ext.ContextTypes` to customize the types used in the + ``context`` interface. If not passed, the defaults documented in + :class:`telegram.ext.ContextTypes` will be used. + + .. versionadded:: 13.6 Raises: ValueError: If both :attr:`token` and :attr:`bot` are passed or none of them. @@ -124,7 +143,52 @@ class Updater: '__dict__', ) + @overload def __init__( + self: 'Updater[CallbackContext, dict, dict, dict]', + token: str = None, + base_url: str = None, + workers: int = 4, + bot: Bot = None, + private_key: bytes = None, + private_key_password: bytes = None, + user_sig_handler: Callable = None, + request_kwargs: Dict[str, Any] = None, + persistence: 'BasePersistence' = None, # pylint: disable=E0601 + defaults: 'Defaults' = None, + use_context: bool = True, + base_file_url: str = None, + ): + ... + + @overload + def __init__( + self: 'Updater[CCT, UD, CD, BD]', + token: str = None, + base_url: str = None, + workers: int = 4, + bot: Bot = None, + private_key: bytes = None, + private_key_password: bytes = None, + user_sig_handler: Callable = None, + request_kwargs: Dict[str, Any] = None, + persistence: 'BasePersistence' = None, + defaults: 'Defaults' = None, + use_context: bool = True, + base_file_url: str = None, + context_types: ContextTypes[CCT, UD, CD, BD] = None, + ): + ... + + @overload + def __init__( + self: 'Updater[CCT, UD, CD, BD]', + user_sig_handler: Callable = None, + dispatcher: Dispatcher[CCT, UD, CD, BD] = None, + ): + ... + + def __init__( # type: ignore[no-untyped-def,misc] self, token: str = None, base_url: str = None, @@ -137,8 +201,9 @@ class Updater: persistence: 'BasePersistence' = None, defaults: 'Defaults' = None, use_context: bool = True, - dispatcher: Dispatcher = None, + dispatcher=None, base_file_url: str = None, + context_types: ContextTypes[CCT, UD, CD, BD] = None, ): if defaults and bot: @@ -161,10 +226,12 @@ class Updater: raise ValueError('`dispatcher` and `bot` are mutually exclusive') if persistence is not None: raise ValueError('`dispatcher` and `persistence` are mutually exclusive') - if workers is not None: - raise ValueError('`dispatcher` and `workers` are mutually exclusive') if use_context != dispatcher.use_context: raise ValueError('`dispatcher` and `use_context` are mutually exclusive') + if context_types is not None: + raise ValueError('`dispatcher` and `context_types` are mutually exclusive') + if workers is not None: + raise ValueError('`dispatcher` and `workers` are mutually exclusive') self.logger = logging.getLogger(__name__) self._request = None @@ -212,6 +279,7 @@ class Updater: exception_event=self.__exception_event, persistence=persistence, use_context=use_context, + context_types=context_types, ) self.job_queue.set_dispatcher(self.dispatcher) else: diff --git a/telegram/ext/utils/types.py b/telegram/ext/utils/types.py new file mode 100644 index 000000000..fbaedd165 --- /dev/null +++ b/telegram/ext/utils/types.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2021 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains custom typing aliases. + +.. versionadded:: 13.6 +""" +from typing import TypeVar, TYPE_CHECKING + +if TYPE_CHECKING: + from telegram.ext import CallbackContext # noqa: F401 + +CCT = TypeVar('CCT', bound='CallbackContext') +"""An instance of :class:`telegram.ext.CallbackContext` or a custom subclass. + +.. versionadded:: 13.6 +""" +UD = TypeVar('UD') +"""Type of the user data for a single user. + +.. versionadded:: 13.6 +""" +CD = TypeVar('CD') +"""Type of the chat data for a single user. + +.. versionadded:: 13.6 +""" +BD = TypeVar('BD') +"""Type of the bot data. + +.. versionadded:: 13.6 +""" diff --git a/telegram/utils/types.py b/telegram/utils/types.py index 1ab5f4df2..1ffcb2e44 100644 --- a/telegram/utils/types.py +++ b/telegram/utils/types.py @@ -18,11 +18,21 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains custom typing aliases.""" from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypeVar, Union +from typing import ( + IO, + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, +) if TYPE_CHECKING: - from telegram import InputFile - from telegram.utils.helpers import DefaultValue + from telegram import InputFile # noqa: F401 + from telegram.utils.helpers import DefaultValue # noqa: F401 FileLike = Union[IO, 'InputFile'] """Either an open file handler or a :class:`telegram.InputFile`.""" diff --git a/tests/test_callbackcontext.py b/tests/test_callbackcontext.py index cb9bc8034..ad4cfb387 100644 --- a/tests/test_callbackcontext.py +++ b/tests/test_callbackcontext.py @@ -22,6 +22,10 @@ import pytest from telegram import Update, Message, Chat, User, TelegramError from telegram.ext import CallbackContext +""" +CallbackContext.refresh_data is tested in TestBasePersistence +""" + class TestCallbackContext: def test_slot_behaviour(self, cdp, recwarn, mro_slots): diff --git a/tests/test_contexttypes.py b/tests/test_contexttypes.py new file mode 100644 index 000000000..20dd405f9 --- /dev/null +++ b/tests/test_contexttypes.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2021 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import pytest + +from telegram.ext import ContextTypes, CallbackContext + + +class SubClass(CallbackContext): + pass + + +class TestContextTypes: + def test_slot_behaviour(self, mro_slots): + instance = ContextTypes() + for attr in instance.__slots__: + assert getattr(instance, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(instance)) == len(set(mro_slots(instance))), "duplicate slot" + with pytest.raises(AttributeError): + instance.custom + + def test_data_init(self): + ct = ContextTypes(SubClass, int, float, bool) + assert ct.context is SubClass + assert ct.bot_data is int + assert ct.chat_data is float + assert ct.user_data is bool + + with pytest.raises(ValueError, match='subclass of CallbackContext'): + ContextTypes(context=bool) + + def test_data_assignment(self): + ct = ContextTypes() + + with pytest.raises(AttributeError): + ct.bot_data = bool + with pytest.raises(AttributeError): + ct.user_data = bool + with pytest.raises(AttributeError): + ct.chat_data = bool diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 672250999..bcadadcd5 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -32,6 +32,7 @@ from telegram.ext import ( CallbackContext, JobQueue, BasePersistence, + ContextTypes, ) from telegram.ext.dispatcher import run_async, Dispatcher, DispatcherHandlerStop from telegram.utils.deprecate import TelegramDeprecationWarning @@ -45,6 +46,10 @@ def dp2(bot): yield from create_dp(bot) +class CustomContext(CallbackContext): + pass + + class TestDispatcher: message_update = Update( 1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') @@ -747,6 +752,15 @@ class TestDispatcher: def update_conversation(self, name, key, new_state): pass + def refresh_bot_data(self, bot_data): + pass + + def refresh_user_data(self, user_id, user_data): + pass + + def refresh_chat_data(self, chat_id, chat_data): + pass + def callback(update, context): pass @@ -807,6 +821,15 @@ class TestDispatcher: def get_chat_data(self): pass + def refresh_bot_data(self, bot_data): + pass + + def refresh_user_data(self, user_id, user_data): + pass + + def refresh_chat_data(self, chat_id, chat_data): + pass + def callback(update, context): pass @@ -923,3 +946,62 @@ class TestDispatcher: assert self.count == expected finally: dp.bot.defaults = None + + def test_custom_context_init(self, bot): + cc = ContextTypes( + context=CustomContext, + user_data=int, + chat_data=float, + bot_data=complex, + ) + + dispatcher = Dispatcher(bot, Queue(), context_types=cc) + + assert isinstance(dispatcher.user_data[1], int) + assert isinstance(dispatcher.chat_data[1], float) + assert isinstance(dispatcher.bot_data, complex) + + def test_custom_context_error_handler(self, bot): + def error_handler(_, context): + self.received = ( + type(context), + type(context.user_data), + type(context.chat_data), + type(context.bot_data), + ) + + dispatcher = Dispatcher( + bot, + Queue(), + context_types=ContextTypes( + context=CustomContext, bot_data=int, user_data=float, chat_data=complex + ), + ) + dispatcher.add_error_handler(error_handler) + dispatcher.add_handler(MessageHandler(Filters.all, self.callback_raise_error)) + + dispatcher.process_update(self.message_update) + sleep(0.1) + assert self.received == (CustomContext, float, complex, int) + + def test_custom_context_handler_callback(self, bot): + def callback(_, context): + self.received = ( + type(context), + type(context.user_data), + type(context.chat_data), + type(context.bot_data), + ) + + dispatcher = Dispatcher( + bot, + Queue(), + context_types=ContextTypes( + context=CustomContext, bot_data=int, user_data=float, chat_data=complex + ), + ) + dispatcher.add_handler(MessageHandler(Filters.all, callback)) + + dispatcher.process_update(self.message_update) + sleep(0.1) + assert self.received == (CustomContext, float, complex, int) diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py index d0214a069..2851827dc 100644 --- a/tests/test_jobqueue.py +++ b/tests/test_jobqueue.py @@ -29,7 +29,11 @@ import pytest import pytz from apscheduler.schedulers import SchedulerNotRunningError from flaky import flaky -from telegram.ext import JobQueue, Updater, Job, CallbackContext +from telegram.ext import JobQueue, Updater, Job, CallbackContext, Dispatcher, ContextTypes + + +class CustomContext(CallbackContext): + pass @pytest.fixture(scope='function') @@ -519,3 +523,25 @@ class TestJobQueue: assert len(caplog.records) == 1 rec = caplog.records[-1] assert 'No error handlers are registered' in rec.getMessage() + + def test_custom_context(self, bot, job_queue): + dispatcher = Dispatcher( + bot, + Queue(), + context_types=ContextTypes( + context=CustomContext, bot_data=int, user_data=float, chat_data=complex + ), + ) + job_queue.set_dispatcher(dispatcher) + + def callback(context): + self.result = ( + type(context), + context.user_data, + context.chat_data, + type(context.bot_data), + ) + + job_queue.run_once(callback, 0.1) + sleep(0.15) + assert self.result == (CustomContext, None, None, int) diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 10bf010cb..0abe68c37 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -46,6 +46,7 @@ from telegram.ext import ( DictPersistence, TypeHandler, JobQueue, + ContextTypes, ) @@ -135,12 +136,16 @@ def bot_data(): @pytest.fixture(scope="function") def chat_data(): - return defaultdict(dict, {-12345: {'test1': 'test2'}, -67890: {3: 'test4'}}) + return defaultdict( + dict, {-12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, -67890: {3: 'test4'}} + ) @pytest.fixture(scope="function") def user_data(): - return defaultdict(dict, {12345: {'test1': 'test2'}, 67890: {3: 'test4'}}) + return defaultdict( + dict, {12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, 67890: {3: 'test4'}} + ) @pytest.fixture(scope='function') @@ -172,6 +177,12 @@ def job_queue(bot): class TestBasePersistence: + test_flag = False + + @pytest.fixture(scope='function', autouse=True) + def reset(self): + self.test_flag = False + def test_slot_behaviour(self, bot_persistence, mro_slots, recwarn): inst = bot_persistence for attr in inst.__slots__: @@ -254,8 +265,17 @@ class TestBasePersistence: u.dispatcher.chat_data[442233]['test5'] = 'test6' assert u.dispatcher.chat_data[442233]['test5'] == 'test6' + @pytest.mark.parametrize('run_async', [True, False], ids=['synchronous', 'run_async']) def test_dispatcher_integration_handlers( - self, caplog, bot, base_persistence, chat_data, user_data, bot_data + self, + cdp, + caplog, + bot, + base_persistence, + chat_data, + user_data, + bot_data, + run_async, ): def get_user_data(): return user_data @@ -269,112 +289,10 @@ class TestBasePersistence: base_persistence.get_user_data = get_user_data base_persistence.get_chat_data = get_chat_data base_persistence.get_bot_data = get_bot_data - # base_persistence.update_chat_data = lambda x: x - # base_persistence.update_user_data = lambda x: x - updater = Updater(bot=bot, persistence=base_persistence, use_context=True) - dp = updater.dispatcher + base_persistence.refresh_bot_data = lambda x: x + base_persistence.refresh_chat_data = lambda x, y: x + base_persistence.refresh_user_data = lambda x, y: x - def callback_known_user(update, context): - if not context.user_data['test1'] == 'test2': - pytest.fail('user_data corrupt') - if not context.bot_data == bot_data: - pytest.fail('bot_data corrupt') - - def callback_known_chat(update, context): - if not context.chat_data['test3'] == 'test4': - pytest.fail('chat_data corrupt') - if not context.bot_data == bot_data: - pytest.fail('bot_data corrupt') - - def callback_unknown_user_or_chat(update, context): - if not context.user_data == {}: - pytest.fail('user_data corrupt') - if not context.chat_data == {}: - pytest.fail('chat_data corrupt') - if not context.bot_data == bot_data: - pytest.fail('bot_data corrupt') - context.user_data[1] = 'test7' - context.chat_data[2] = 'test8' - context.bot_data['test0'] = 'test0' - - known_user = MessageHandler( - Filters.user(user_id=12345), - callback_known_user, - pass_chat_data=True, - pass_user_data=True, - ) - known_chat = MessageHandler( - Filters.chat(chat_id=-67890), - callback_known_chat, - pass_chat_data=True, - pass_user_data=True, - ) - unknown = MessageHandler( - Filters.all, callback_unknown_user_or_chat, pass_chat_data=True, pass_user_data=True - ) - dp.add_handler(known_user) - dp.add_handler(known_chat) - dp.add_handler(unknown) - user1 = User(id=12345, first_name='test user', is_bot=False) - user2 = User(id=54321, first_name='test user', is_bot=False) - chat1 = Chat(id=-67890, type='group') - chat2 = Chat(id=-987654, type='group') - m = Message(1, None, chat2, from_user=user1) - u = Update(0, m) - with caplog.at_level(logging.ERROR): - dp.process_update(u) - rec = caplog.records[-1] - assert rec.getMessage() == 'No error handlers are registered, logging exception.' - assert rec.levelname == 'ERROR' - rec = caplog.records[-2] - assert rec.getMessage() == 'No error handlers are registered, logging exception.' - assert rec.levelname == 'ERROR' - rec = caplog.records[-3] - assert rec.getMessage() == 'No error handlers are registered, logging exception.' - assert rec.levelname == 'ERROR' - m.from_user = user2 - m.chat = chat1 - u = Update(1, m) - dp.process_update(u) - m.chat = chat2 - u = Update(2, m) - - def save_bot_data(data): - if 'test0' not in data: - pytest.fail() - - def save_chat_data(data): - if -987654 not in data: - pytest.fail() - - def save_user_data(data): - if 54321 not in data: - pytest.fail() - - base_persistence.update_chat_data = save_chat_data - base_persistence.update_user_data = save_user_data - base_persistence.update_bot_data = save_bot_data - dp.process_update(u) - - assert dp.user_data[54321][1] == 'test7' - assert dp.chat_data[-987654][2] == 'test8' - assert dp.bot_data['test0'] == 'test0' - - def test_dispatcher_integration_handlers_run_async( - self, cdp, caplog, bot, base_persistence, chat_data, user_data, bot_data - ): - def get_user_data(): - return user_data - - def get_chat_data(): - return chat_data - - def get_bot_data(): - return bot_data - - base_persistence.get_user_data = get_user_data - base_persistence.get_chat_data = get_chat_data - base_persistence.get_bot_data = get_bot_data cdp.persistence = base_persistence cdp.user_data = user_data cdp.chat_data = chat_data @@ -408,21 +326,21 @@ class TestBasePersistence: callback_known_user, pass_chat_data=True, pass_user_data=True, - run_async=True, + run_async=run_async, ) known_chat = MessageHandler( Filters.chat(chat_id=-67890), callback_known_chat, pass_chat_data=True, pass_user_data=True, - run_async=True, + run_async=run_async, ) unknown = MessageHandler( Filters.all, callback_unknown_user_or_chat, pass_chat_data=True, pass_user_data=True, - run_async=True, + run_async=run_async, ) cdp.add_handler(known_user) cdp.add_handler(known_chat) @@ -437,12 +355,16 @@ class TestBasePersistence: cdp.process_update(u) sleep(0.1) - rec = caplog.records[-1] - assert rec.getMessage() == 'No error handlers are registered, logging exception.' - assert rec.levelname == 'ERROR' - rec = caplog.records[-2] - assert rec.getMessage() == 'No error handlers are registered, logging exception.' - assert rec.levelname == 'ERROR' + + # In base_persistence.update_*_data we currently just raise NotImplementedError + # This makes sure that this doesn't break the processing and is properly handled by + # the error handler + # We override `update_*_data` further below. + assert len(caplog.records) == 3 + for rec in caplog.records: + assert rec.getMessage() == 'No error handlers are registered, logging exception.' + assert rec.levelname == 'ERROR' + m.from_user = user2 m.chat = chat1 u = Update(1, m) @@ -473,6 +395,105 @@ class TestBasePersistence: assert cdp.chat_data[-987654][2] == 'test8' assert cdp.bot_data['test0'] == 'test0' + @pytest.mark.parametrize( + 'store_user_data', [True, False], ids=['store_user_data-True', 'store_user_data-False'] + ) + @pytest.mark.parametrize( + 'store_chat_data', [True, False], ids=['store_chat_data-True', 'store_chat_data-False'] + ) + @pytest.mark.parametrize( + 'store_bot_data', [True, False], ids=['store_bot_data-True', 'store_bot_data-False'] + ) + @pytest.mark.parametrize('run_async', [True, False], ids=['synchronous', 'run_async']) + def test_persistence_dispatcher_integration_refresh_data( + self, + cdp, + base_persistence, + chat_data, + bot_data, + user_data, + store_bot_data, + store_chat_data, + store_user_data, + run_async, + ): + base_persistence.refresh_bot_data = lambda x: x.setdefault( + 'refreshed', x.get('refreshed', 0) + 1 + ) + # x is the user/chat_id + base_persistence.refresh_chat_data = lambda x, y: y.setdefault('refreshed', x) + base_persistence.refresh_user_data = lambda x, y: y.setdefault('refreshed', x) + base_persistence.store_bot_data = store_bot_data + base_persistence.store_chat_data = store_chat_data + base_persistence.store_user_data = store_user_data + cdp.persistence = base_persistence + + self.test_flag = True + + def callback_with_user_and_chat(update, context): + if store_user_data: + if context.user_data.get('refreshed') != update.effective_user.id: + self.test_flag = 'user_data was not refreshed' + else: + if 'refreshed' in context.user_data: + self.test_flag = 'user_data was wrongly refreshed' + if store_chat_data: + if context.chat_data.get('refreshed') != update.effective_chat.id: + self.test_flag = 'chat_data was not refreshed' + else: + if 'refreshed' in context.chat_data: + self.test_flag = 'chat_data was wrongly refreshed' + if store_bot_data: + if context.bot_data.get('refreshed') != 1: + self.test_flag = 'bot_data was not refreshed' + else: + if 'refreshed' in context.bot_data: + self.test_flag = 'bot_data was wrongly refreshed' + + def callback_without_user_and_chat(_, context): + if store_bot_data: + if context.bot_data.get('refreshed') != 1: + self.test_flag = 'bot_data was not refreshed' + else: + if 'refreshed' in context.bot_data: + self.test_flag = 'bot_data was wrongly refreshed' + + with_user_and_chat = MessageHandler( + Filters.user(user_id=12345), + callback_with_user_and_chat, + pass_chat_data=True, + pass_user_data=True, + run_async=run_async, + ) + without_user_and_chat = MessageHandler( + Filters.all, + callback_without_user_and_chat, + pass_chat_data=True, + pass_user_data=True, + run_async=run_async, + ) + cdp.add_handler(with_user_and_chat) + cdp.add_handler(without_user_and_chat) + user = User(id=12345, first_name='test user', is_bot=False) + chat = Chat(id=-987654, type='group') + m = Message(1, None, chat, from_user=user) + + # has user and chat + u = Update(0, m) + cdp.process_update(u) + + assert self.test_flag is True + + # has neither user nor hat + m.from_user = None + m.chat = None + u = Update(1, m) + cdp.process_update(u) + + assert self.test_flag is True + + sleep(0.1) + def test_persistence_dispatcher_arbitrary_update_types(self, dp, base_persistence, caplog): # Updates used with TypeHandler doesn't necessarily have the proper attributes for # persistence, makes sure it works anyways @@ -816,6 +837,10 @@ def update(bot): return Update(0, message=message) +class CustomMapping(defaultdict): + pass + + class TestPicklePersistence: def test_slot_behaviour(self, mro_slots, recwarn, pickle_persistence): inst = pickle_persistence @@ -986,25 +1011,34 @@ class TestPicklePersistence: def test_updating_multi_file(self, pickle_persistence, good_pickle_files): user_data = pickle_persistence.get_user_data() - user_data[54321]['test9'] = 'test 10' + user_data[12345]['test3']['test4'] = 'test6' assert not pickle_persistence.user_data == user_data - pickle_persistence.update_user_data(54321, user_data[54321]) + pickle_persistence.update_user_data(12345, user_data[12345]) + user_data[12345]['test3']['test4'] = 'test7' + assert not pickle_persistence.user_data == user_data + pickle_persistence.update_user_data(12345, user_data[12345]) assert pickle_persistence.user_data == user_data with open('pickletest_user_data', 'rb') as f: user_data_test = defaultdict(dict, pickle.load(f)) assert user_data_test == user_data chat_data = pickle_persistence.get_chat_data() - chat_data[54321]['test9'] = 'test 10' + chat_data[-12345]['test3']['test4'] = 'test6' assert not pickle_persistence.chat_data == chat_data - pickle_persistence.update_chat_data(54321, chat_data[54321]) + pickle_persistence.update_chat_data(-12345, chat_data[-12345]) + chat_data[-12345]['test3']['test4'] = 'test7' + assert not pickle_persistence.chat_data == chat_data + pickle_persistence.update_chat_data(-12345, chat_data[-12345]) assert pickle_persistence.chat_data == chat_data with open('pickletest_chat_data', 'rb') as f: chat_data_test = defaultdict(dict, pickle.load(f)) assert chat_data_test == chat_data bot_data = pickle_persistence.get_bot_data() - bot_data['test6'] = 'test 7' + bot_data['test3']['test4'] = 'test6' + assert not pickle_persistence.bot_data == bot_data + pickle_persistence.update_bot_data(bot_data) + bot_data['test3']['test4'] = 'test7' assert not pickle_persistence.bot_data == bot_data pickle_persistence.update_bot_data(bot_data) assert pickle_persistence.bot_data == bot_data @@ -1031,25 +1065,34 @@ class TestPicklePersistence: pickle_persistence.single_file = True user_data = pickle_persistence.get_user_data() - user_data[54321]['test9'] = 'test 10' + user_data[12345]['test3']['test4'] = 'test6' assert not pickle_persistence.user_data == user_data - pickle_persistence.update_user_data(54321, user_data[54321]) + pickle_persistence.update_user_data(12345, user_data[12345]) + user_data[12345]['test3']['test4'] = 'test7' + assert not pickle_persistence.user_data == user_data + pickle_persistence.update_user_data(12345, user_data[12345]) assert pickle_persistence.user_data == user_data with open('pickletest', 'rb') as f: user_data_test = defaultdict(dict, pickle.load(f)['user_data']) assert user_data_test == user_data chat_data = pickle_persistence.get_chat_data() - chat_data[54321]['test9'] = 'test 10' + chat_data[-12345]['test3']['test4'] = 'test6' assert not pickle_persistence.chat_data == chat_data - pickle_persistence.update_chat_data(54321, chat_data[54321]) + pickle_persistence.update_chat_data(-12345, chat_data[-12345]) + chat_data[-12345]['test3']['test4'] = 'test7' + assert not pickle_persistence.chat_data == chat_data + pickle_persistence.update_chat_data(-12345, chat_data[-12345]) assert pickle_persistence.chat_data == chat_data with open('pickletest', 'rb') as f: chat_data_test = defaultdict(dict, pickle.load(f)['chat_data']) assert chat_data_test == chat_data bot_data = pickle_persistence.get_bot_data() - bot_data['test6'] = 'test 7' + bot_data['test3']['test4'] = 'test6' + assert not pickle_persistence.bot_data == bot_data + pickle_persistence.update_bot_data(bot_data) + bot_data['test3']['test4'] = 'test7' assert not pickle_persistence.bot_data == bot_data pickle_persistence.update_bot_data(bot_data) assert pickle_persistence.bot_data == bot_data @@ -1418,6 +1461,39 @@ class TestPicklePersistence: user_data = pickle_persistence.get_user_data() assert user_data[789] == {'test3': '123'} + @pytest.mark.parametrize('singlefile', [True, False]) + @pytest.mark.parametrize('ud', [int, float, complex]) + @pytest.mark.parametrize('cd', [int, float, complex]) + @pytest.mark.parametrize('bd', [int, float, complex]) + def test_with_context_types(self, ud, cd, bd, singlefile): + cc = ContextTypes(user_data=ud, chat_data=cd, bot_data=bd) + persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc) + + assert isinstance(persistence.get_user_data()[1], ud) + assert persistence.get_user_data()[1] == 0 + assert isinstance(persistence.get_chat_data()[1], cd) + assert persistence.get_chat_data()[1] == 0 + assert isinstance(persistence.get_bot_data(), bd) + assert persistence.get_bot_data() == 0 + + persistence.user_data = None + persistence.chat_data = None + persistence.update_user_data(1, ud(1)) + persistence.update_chat_data(1, cd(1)) + persistence.update_bot_data(bd(1)) + assert persistence.get_user_data()[1] == 1 + assert persistence.get_chat_data()[1] == 1 + assert persistence.get_bot_data() == 1 + + persistence.flush() + persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc) + assert isinstance(persistence.get_user_data()[1], ud) + assert persistence.get_user_data()[1] == 1 + assert isinstance(persistence.get_chat_data()[1], cd) + assert persistence.get_chat_data()[1] == 1 + assert isinstance(persistence.get_bot_data(), bd) + assert persistence.get_bot_data() == 1 + @pytest.fixture(scope='function') def user_data_json(user_data): @@ -1560,7 +1636,7 @@ class TestDictPersistence: assert dict_persistence.bot_data_json == bot_data_json assert dict_persistence.conversations_json == conversations_json - def test_json_changes( + def test_updating( self, user_data, user_data_json, @@ -1577,35 +1653,59 @@ class TestDictPersistence: bot_data_json=bot_data_json, conversations_json=conversations_json, ) - user_data_two = user_data.copy() - user_data_two.update({4: {5: 6}}) - dict_persistence.update_user_data(4, {5: 6}) - assert dict_persistence.user_data == user_data_two - assert dict_persistence.user_data_json != user_data_json - assert dict_persistence.user_data_json == json.dumps(user_data_two) - chat_data_two = chat_data.copy() - chat_data_two.update({7: {8: 9}}) - dict_persistence.update_chat_data(7, {8: 9}) - assert dict_persistence.chat_data == chat_data_two - assert dict_persistence.chat_data_json != chat_data_json - assert dict_persistence.chat_data_json == json.dumps(chat_data_two) + user_data = dict_persistence.get_user_data() + user_data[12345]['test3']['test4'] = 'test6' + assert not dict_persistence.user_data == user_data + assert not dict_persistence.user_data_json == json.dumps(user_data) + dict_persistence.update_user_data(12345, user_data[12345]) + user_data[12345]['test3']['test4'] = 'test7' + assert not dict_persistence.user_data == user_data + assert not dict_persistence.user_data_json == json.dumps(user_data) + dict_persistence.update_user_data(12345, user_data[12345]) + assert dict_persistence.user_data == user_data + assert dict_persistence.user_data_json == json.dumps(user_data) - bot_data_two = bot_data.copy() - bot_data_two.update({'7': {'8': '9'}}) - bot_data['7'] = {'8': '9'} + chat_data = dict_persistence.get_chat_data() + chat_data[-12345]['test3']['test4'] = 'test6' + assert not dict_persistence.chat_data == chat_data + assert not dict_persistence.chat_data_json == json.dumps(chat_data) + dict_persistence.update_chat_data(-12345, chat_data[-12345]) + chat_data[-12345]['test3']['test4'] = 'test7' + assert not dict_persistence.chat_data == chat_data + assert not dict_persistence.chat_data_json == json.dumps(chat_data) + dict_persistence.update_chat_data(-12345, chat_data[-12345]) + assert dict_persistence.chat_data == chat_data + assert dict_persistence.chat_data_json == json.dumps(chat_data) + + bot_data = dict_persistence.get_bot_data() + bot_data['test3']['test4'] = 'test6' + assert not dict_persistence.bot_data == bot_data + assert not dict_persistence.bot_data_json == json.dumps(bot_data) dict_persistence.update_bot_data(bot_data) - assert dict_persistence.bot_data == bot_data_two - assert dict_persistence.bot_data_json != bot_data_json - assert dict_persistence.bot_data_json == json.dumps(bot_data_two) + bot_data['test3']['test4'] = 'test7' + assert not dict_persistence.bot_data == bot_data + assert not dict_persistence.bot_data_json == json.dumps(bot_data) + dict_persistence.update_bot_data(bot_data) + assert dict_persistence.bot_data == bot_data + assert dict_persistence.bot_data_json == json.dumps(bot_data) - conversations_two = conversations.copy() - conversations_two.update({'name4': {(1, 2): 3}}) - dict_persistence.update_conversation('name4', (1, 2), 3) - assert dict_persistence.conversations == conversations_two - assert dict_persistence.conversations_json != conversations_json + conversation1 = dict_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not dict_persistence.conversations['name1'] == conversation1 + dict_persistence.update_conversation('name1', (123, 123), 5) + assert dict_persistence.conversations['name1'] == conversation1 + print(dict_persistence.conversations_json) + conversations['name1'][(123, 123)] = 5 + assert dict_persistence.conversations_json == encode_conversations_to_json(conversations) + assert dict_persistence.get_conversations('name1') == conversation1 + + dict_persistence._conversations = None + dict_persistence.update_conversation('name1', (123, 123), 5) + assert dict_persistence.conversations['name1'] == {(123, 123): 5} + assert dict_persistence.get_conversations('name1') == {(123, 123): 5} assert dict_persistence.conversations_json == encode_conversations_to_json( - conversations_two + {"name1": {(123, 123): 5}} ) def test_with_handler(self, bot, update): diff --git a/tests/test_slots.py b/tests/test_slots.py index e97a4e178..9d5169eb3 100644 --- a/tests/test_slots.py +++ b/tests/test_slots.py @@ -31,6 +31,7 @@ excluded = { 'Days', 'telegram.deprecate', 'TelegramDecryptionError', + 'ContextTypes', } # These modules/classes intentionally don't have __dict__. diff --git a/tests/test_updater.py b/tests/test_updater.py index 16c3c611a..9eda467a6 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -613,6 +613,11 @@ class TestUpdater: with pytest.raises(ValueError): Updater(dispatcher=dispatcher, use_context=use_context) + def test_mutual_exclude_custom_context_dispatcher(self): + dispatcher = Dispatcher(None, None) + with pytest.raises(ValueError): + Updater(dispatcher=dispatcher, context_types=True) + def test_defaults_warning(self, bot): with pytest.warns(TelegramDeprecationWarning, match='no effect when a Bot is passed'): Updater(bot=bot, defaults=Defaults())