Add ContextTypes & BasePersistence.refresh_user/chat/bot_data (#2262)

This commit is contained in:
Bibo-Joshi 2021-06-06 10:37:53 +02:00 committed by GitHub
parent 5da1dd7ce9
commit fce7cc903c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
41 changed files with 1372 additions and 354 deletions

View file

@ -60,7 +60,7 @@ jobs:
shell: bash --noprofile --norc {0} shell: bash --noprofile --norc {0}
- name: Submit coverage - name: Submit coverage
uses: codecov/codecov-action@v1.0.13 uses: codecov/codecov-action@v1
with: with:
env_vars: OS,PYTHON env_vars: OS,PYTHON
name: ${{ matrix.os }}-${{ matrix.python-version }} name: ${{ matrix.os }}-${{ matrix.python-version }}
@ -79,7 +79,7 @@ jobs:
run: run:
git submodule update --init --recursive git submodule update --init --recursive
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1 uses: actions/setup-python@v2
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
@ -108,7 +108,7 @@ jobs:
run: run:
git submodule update --init --recursive git submodule update --init --recursive
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1 uses: actions/setup-python@v2
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies

View file

@ -0,0 +1,8 @@
:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/contexttypes.py
telegram.ext.ContextTypes
=========================
.. autoclass:: telegram.ext.ContextTypes
:members:
:show-inheritance:

View file

@ -7,11 +7,12 @@ telegram.ext package
telegram.ext.dispatcher telegram.ext.dispatcher
telegram.ext.dispatcherhandlerstop telegram.ext.dispatcherhandlerstop
telegram.ext.callbackcontext telegram.ext.callbackcontext
telegram.ext.defaults
telegram.ext.job telegram.ext.job
telegram.ext.jobqueue telegram.ext.jobqueue
telegram.ext.messagequeue telegram.ext.messagequeue
telegram.ext.delayqueue telegram.ext.delayqueue
telegram.ext.contexttypes
telegram.ext.defaults
Handlers Handlers
-------- --------
@ -52,3 +53,4 @@ utils
.. toctree:: .. toctree::
telegram.ext.utils.promise telegram.ext.utils.promise
telegram.ext.utils.types

View file

@ -0,0 +1,8 @@
:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/utils/types.py
telegram.ext.utils.types Module
================================
.. automodule:: telegram.ext.utils.types
:members:
:show-inheritance:

View file

@ -49,5 +49,8 @@ A basic example on how to set up a custom error handler.
### [`chatmemberbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/chatmemberbot.py) ### [`chatmemberbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/chatmemberbot.py)
A basic example on how `(my_)chat_member` updates can be used. A basic example on how `(my_)chat_member` updates can be used.
### [`contexttypesbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/contexttypesbot.py)
This example showcases how `telegram.ext.ContextTypes` can be used to customize the `context` argument of handler and job callbacks.
## Pure API ## Pure API
The [`rawapibot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/rawapibot.py) example uses only the pure, "bare-metal" API wrapper. The [`rawapibot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/rawapibot.py) example uses only the pure, "bare-metal" API wrapper.

129
examples/contexttypesbot.py Normal file
View file

@ -0,0 +1,129 @@
#!/usr/bin/env python
# pylint: disable=C0116,W0613
# This program is dedicated to the public domain under the CC0 license.
"""
Simple Bot to showcase `telegram.ext.ContextTypes`.
Usage:
Press Ctrl-C on the command line or send a signal to the process to stop the
bot.
"""
from collections import defaultdict
from typing import DefaultDict, Optional, Set
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, ParseMode
from telegram.ext import (
Updater,
CommandHandler,
CallbackContext,
ContextTypes,
CallbackQueryHandler,
TypeHandler,
Dispatcher,
)
class ChatData:
"""Custom class for chat_data. Here we store data per message."""
def __init__(self) -> None:
self.clicks_per_message: DefaultDict[int, int] = defaultdict(int)
# The [dict, ChatData, dict] is for type checkers like mypy
class CustomContext(CallbackContext[dict, ChatData, dict]):
"""Custom class for context."""
def __init__(self, dispatcher: Dispatcher):
super().__init__(dispatcher=dispatcher)
self._message_id: Optional[int] = None
@property
def bot_user_ids(self) -> Set[int]:
"""Custom shortcut to access a value stored in the bot_data dict"""
return self.bot_data.setdefault('user_ids', set())
@property
def message_clicks(self) -> Optional[int]:
"""Access the number of clicks for the message this context object was built for."""
if self._message_id:
return self.chat_data.clicks_per_message[self._message_id]
return None
@message_clicks.setter
def message_clicks(self, value: int) -> None:
"""Allow to change the count"""
if not self._message_id:
raise RuntimeError('There is no message associated with this context obejct.')
self.chat_data.clicks_per_message[self._message_id] = value
@classmethod
def from_update(cls, update: object, dispatcher: 'Dispatcher') -> 'CustomContext':
"""Override from_update to set _message_id."""
# Make sure to call super()
context = super().from_update(update, dispatcher)
if context.chat_data and isinstance(update, Update) and update.effective_message:
context._message_id = update.effective_message.message_id # pylint: disable=W0212
# Remember to return the object
return context
def start(update: Update, context: CustomContext) -> None:
"""Display a message with a button."""
update.message.reply_html(
'This button was clicked <i>0</i> times.',
reply_markup=InlineKeyboardMarkup.from_button(
InlineKeyboardButton(text='Click me!', callback_data='button')
),
)
def count_click(update: Update, context: CustomContext) -> None:
"""Update the click count for the message."""
context.message_clicks += 1
update.callback_query.answer()
update.effective_message.edit_text(
f'This button was clicked <i>{context.message_clicks}</i> times.',
reply_markup=InlineKeyboardMarkup.from_button(
InlineKeyboardButton(text='Click me!', callback_data='button')
),
parse_mode=ParseMode.HTML,
)
def print_users(update: Update, context: CustomContext) -> None:
"""Show which users have been using this bot."""
update.message.reply_text(
'The following user IDs have used this bot: '
f'{", ".join(map(str, context.bot_user_ids))}'
)
def track_users(update: Update, context: CustomContext) -> None:
"""Store the user id of the incoming update, if any."""
if update.effective_user:
context.bot_user_ids.add(update.effective_user.id)
def main() -> None:
"""Run the bot."""
context_types = ContextTypes(context=CustomContext, chat_data=ChatData)
updater = Updater("TOKEN", context_types=context_types)
dispatcher = updater.dispatcher
# run track_users in its own group to not interfere with the user handlers
dispatcher.add_handler(TypeHandler(Update, track_users), group=-1)
dispatcher.add_handler(CommandHandler("start", start))
dispatcher.add_handler(CallbackQueryHandler(count_click))
dispatcher.add_handler(CommandHandler("print_users", print_users))
updater.start_polling()
updater.idle()
if __name__ == '__main__':
main()

View file

@ -44,6 +44,7 @@ omit =
[coverage:report] [coverage:report]
exclude_lines = exclude_lines =
if TYPE_CHECKING: if TYPE_CHECKING:
...
[mypy] [mypy]
warn_unused_ignores = True warn_unused_ignores = True

View file

@ -16,6 +16,7 @@
# #
# You should have received a copy of the GNU Lesser Public License # You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
# pylint: disable=C0413
"""Extensions over the Telegram Bot API to facilitate bot making""" """Extensions over the Telegram Bot API to facilitate bot making"""
from .basepersistence import BasePersistence from .basepersistence import BasePersistence
@ -23,7 +24,20 @@ from .picklepersistence import PicklePersistence
from .dictpersistence import DictPersistence from .dictpersistence import DictPersistence
from .handler import Handler from .handler import Handler
from .callbackcontext import CallbackContext from .callbackcontext import CallbackContext
from .contexttypes import ContextTypes
from .dispatcher import Dispatcher, DispatcherHandlerStop, run_async from .dispatcher import Dispatcher, DispatcherHandlerStop, run_async
# https://bugs.python.org/issue41451, fixed on 3.7+, doesn't actually remove slots
# try-except is just here in case the __init__ is called twice (like in the tests)
# this block is also the reason for the pylint-ignore at the top of the file
try:
del Dispatcher.__slots__ # type: ignore[has-type]
except AttributeError as exc:
if str(exc) == '__slots__':
pass
else:
raise exc
from .jobqueue import JobQueue, Job from .jobqueue import JobQueue, Job
from .updater import Updater from .updater import Updater
from .callbackqueryhandler import CallbackQueryHandler from .callbackqueryhandler import CallbackQueryHandler
@ -47,38 +61,39 @@ from .chatmemberhandler import ChatMemberHandler
from .defaults import Defaults from .defaults import Defaults
__all__ = ( __all__ = (
'Dispatcher', 'BaseFilter',
'JobQueue', 'BasePersistence',
'Job', 'CallbackContext',
'Updater',
'CallbackQueryHandler', 'CallbackQueryHandler',
'ChatMemberHandler',
'ChosenInlineResultHandler', 'ChosenInlineResultHandler',
'CommandHandler', 'CommandHandler',
'ContextTypes',
'ConversationHandler',
'Defaults',
'DelayQueue',
'DictPersistence',
'Dispatcher',
'DispatcherHandlerStop',
'Filters',
'Handler', 'Handler',
'InlineQueryHandler', 'InlineQueryHandler',
'MessageHandler', 'Job',
'BaseFilter', 'JobQueue',
'MessageFilter', 'MessageFilter',
'UpdateFilter', 'MessageHandler',
'Filters', 'MessageQueue',
'PicklePersistence',
'PollAnswerHandler',
'PollHandler',
'PreCheckoutQueryHandler',
'PrefixHandler',
'RegexHandler', 'RegexHandler',
'ShippingQueryHandler',
'StringCommandHandler', 'StringCommandHandler',
'StringRegexHandler', 'StringRegexHandler',
'TypeHandler', 'TypeHandler',
'ConversationHandler', 'UpdateFilter',
'PreCheckoutQueryHandler', 'Updater',
'ShippingQueryHandler',
'MessageQueue',
'DelayQueue',
'DispatcherHandlerStop',
'run_async', 'run_async',
'CallbackContext',
'BasePersistence',
'PicklePersistence',
'DictPersistence',
'PrefixHandler',
'PollAnswerHandler',
'PollHandler',
'ChatMemberHandler',
'Defaults',
) )

View file

@ -21,16 +21,17 @@ import warnings
from sys import version_info as py_ver from sys import version_info as py_ver
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import copy from copy import copy
from typing import DefaultDict, Dict, Optional, Tuple, cast, ClassVar from typing import Dict, Optional, Tuple, cast, ClassVar, Generic, DefaultDict
from telegram.utils.deprecate import set_new_attribute_deprecated from telegram.utils.deprecate import set_new_attribute_deprecated
from telegram import Bot from telegram import Bot
from telegram.utils.types import ConversationDict from telegram.utils.types import ConversationDict
from telegram.ext.utils.types import UD, CD, BD
class BasePersistence(ABC): class BasePersistence(Generic[UD, CD, BD], ABC):
"""Interface class for adding persistence to your bot. """Interface class for adding persistence to your bot.
Subclass this object for different implementations of a persistent bot. Subclass this object for different implementations of a persistent bot.
@ -38,16 +39,20 @@ class BasePersistence(ABC):
* :meth:`get_bot_data` * :meth:`get_bot_data`
* :meth:`update_bot_data` * :meth:`update_bot_data`
* :meth:`refresh_bot_data`
* :meth:`get_chat_data` * :meth:`get_chat_data`
* :meth:`update_chat_data` * :meth:`update_chat_data`
* :meth:`refresh_chat_data`
* :meth:`get_user_data` * :meth:`get_user_data`
* :meth:`update_user_data` * :meth:`update_user_data`
* :meth:`refresh_user_data`
* :meth:`get_conversations` * :meth:`get_conversations`
* :meth:`update_conversation` * :meth:`update_conversation`
* :meth:`flush` * :meth:`flush`
If you don't actually need one of those methods, a simple ``pass`` is enough. For example, if If you don't actually need one of those methods, a simple ``pass`` is enough. For example, if
``store_bot_data=False``, you don't need :meth:`get_bot_data` and :meth:`update_bot_data`. ``store_bot_data=False``, you don't need :meth:`get_bot_data`, :meth:`update_bot_data` or
:meth:`refresh_bot_data`.
Warning: Warning:
Persistence will try to replace :class:`telegram.Bot` instances by :attr:`REPLACED_BOT` and Persistence will try to replace :class:`telegram.Bot` instances by :attr:`REPLACED_BOT` and
@ -93,6 +98,10 @@ class BasePersistence(ABC):
def __new__( def __new__(
cls, *args: object, **kwargs: object # pylint: disable=W0613 cls, *args: object, **kwargs: object # pylint: disable=W0613
) -> 'BasePersistence': ) -> 'BasePersistence':
"""This overrides the get_* and update_* methods to use insert/replace_bot.
That has the side effect that we always pass deepcopied data to those methods, so in
Pickle/DictPersistence we don't have to worry about copying the data again.
"""
instance = super().__new__(cls) instance = super().__new__(cls)
get_user_data = instance.get_user_data get_user_data = instance.get_user_data
get_chat_data = instance.get_chat_data get_chat_data = instance.get_chat_data
@ -101,22 +110,22 @@ class BasePersistence(ABC):
update_chat_data = instance.update_chat_data update_chat_data = instance.update_chat_data
update_bot_data = instance.update_bot_data update_bot_data = instance.update_bot_data
def get_user_data_insert_bot() -> DefaultDict[int, Dict[object, object]]: def get_user_data_insert_bot() -> DefaultDict[int, UD]:
return instance.insert_bot(get_user_data()) return instance.insert_bot(get_user_data())
def get_chat_data_insert_bot() -> DefaultDict[int, Dict[object, object]]: def get_chat_data_insert_bot() -> DefaultDict[int, CD]:
return instance.insert_bot(get_chat_data()) return instance.insert_bot(get_chat_data())
def get_bot_data_insert_bot() -> Dict[object, object]: def get_bot_data_insert_bot() -> BD:
return instance.insert_bot(get_bot_data()) return instance.insert_bot(get_bot_data())
def update_user_data_replace_bot(user_id: int, data: Dict) -> None: def update_user_data_replace_bot(user_id: int, data: UD) -> None:
return update_user_data(user_id, instance.replace_bot(data)) return update_user_data(user_id, instance.replace_bot(data))
def update_chat_data_replace_bot(chat_id: int, data: Dict) -> None: def update_chat_data_replace_bot(chat_id: int, data: CD) -> None:
return update_chat_data(chat_id, instance.replace_bot(data)) return update_chat_data(chat_id, instance.replace_bot(data))
def update_bot_data_replace_bot(data: Dict) -> None: def update_bot_data_replace_bot(data: BD) -> None:
return update_bot_data(instance.replace_bot(data)) return update_bot_data(instance.replace_bot(data))
# We want to ignore TGDeprecation warnings so we use obj.__setattr__. Adds to __dict__ # We want to ignore TGDeprecation warnings so we use obj.__setattr__. Adds to __dict__
@ -334,33 +343,33 @@ class BasePersistence(ABC):
return obj return obj
@abstractmethod @abstractmethod
def get_user_data(self) -> DefaultDict[int, Dict[object, object]]: def get_user_data(self) -> DefaultDict[int, UD]:
"""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the ``user_data`` if stored, or an empty persistence object. It should return the ``user_data`` if stored, or an empty
``defaultdict(dict)``. :obj:`defaultdict(telegram.ext.utils.types.UD)` with integer keys.
Returns: Returns:
:obj:`defaultdict`: The restored user data. DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.UD`]: The restored user data.
""" """
@abstractmethod @abstractmethod
def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]: def get_chat_data(self) -> DefaultDict[int, CD]:
"""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the ``chat_data`` if stored, or an empty persistence object. It should return the ``chat_data`` if stored, or an empty
``defaultdict(dict)``. :obj:`defaultdict(telegram.ext.utils.types.CD)` with integer keys.
Returns: Returns:
:obj:`defaultdict`: The restored chat data. DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.CD`]: The restored chat data.
""" """
@abstractmethod @abstractmethod
def get_bot_data(self) -> Dict[object, object]: def get_bot_data(self) -> BD:
"""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the ``bot_data`` if stored, or an empty persistence object. It should return the ``bot_data`` if stored, or an empty
:obj:`dict`. :class:`telegram.ext.utils.types.BD`.
Returns: Returns:
:obj:`dict`: The restored bot data. :class:`telegram.ext.utils.types.BD`: The restored bot data.
""" """
@abstractmethod @abstractmethod
@ -391,32 +400,70 @@ class BasePersistence(ABC):
""" """
@abstractmethod @abstractmethod
def update_user_data(self, user_id: int, data: Dict) -> None: def update_user_data(self, user_id: int, data: UD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update. handled an update.
Args: Args:
user_id (:obj:`int`): The user the data might have been changed for. user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id]. data (:class:`telegram.ext.utils.types.UD`): The
:attr:`telegram.ext.dispatcher.user_data` ``[user_id]``.
""" """
@abstractmethod @abstractmethod
def update_chat_data(self, chat_id: int, data: Dict) -> None: def update_chat_data(self, chat_id: int, data: CD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update. handled an update.
Args: Args:
chat_id (:obj:`int`): The chat the data might have been changed for. chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id]. data (:class:`telegram.ext.utils.types.CD`): The
:attr:`telegram.ext.dispatcher.chat_data` ``[chat_id]``.
""" """
@abstractmethod @abstractmethod
def update_bot_data(self, data: Dict) -> None: def update_bot_data(self, data: BD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update. handled an update.
Args: Args:
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.bot_data` . data (:class:`telegram.ext.utils.types.BD`): The
:attr:`telegram.ext.dispatcher.bot_data`.
"""
def refresh_user_data(self, user_id: int, user_data: UD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` before passing the
:attr:`user_data` to a callback. Can be used to update data stored in :attr:`user_data`
from an external source.
.. versionadded:: 13.6
Args:
user_id (:obj:`int`): The user ID this :attr:`user_data` is associated with.
user_data (:class:`telegram.ext.utils.types.UD`): The ``user_data`` of a single user.
"""
def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` before passing the
:attr:`chat_data` to a callback. Can be used to update data stored in :attr:`chat_data`
from an external source.
.. versionadded:: 13.6
Args:
chat_id (:obj:`int`): The chat ID this :attr:`chat_data` is associated with.
chat_data (:class:`telegram.ext.utils.types.CD`): The ``chat_data`` of a single chat.
"""
def refresh_bot_data(self, bot_data: BD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` before passing the
:attr:`bot_data` to a callback. Can be used to update data stored in :attr:`bot_data`
from an external source.
.. versionadded:: 13.6
Args:
bot_data (:class:`telegram.ext.utils.types.BD`): The ``bot_data``.
""" """
def flush(self) -> None: def flush(self) -> None:

View file

@ -19,16 +19,31 @@
# pylint: disable=R0201 # pylint: disable=R0201
"""This module contains the CallbackContext class.""" """This module contains the CallbackContext class."""
from queue import Queue from queue import Queue
from typing import TYPE_CHECKING, Dict, List, Match, NoReturn, Optional, Tuple, Union from typing import (
TYPE_CHECKING,
Dict,
List,
Match,
NoReturn,
Optional,
Tuple,
Union,
Generic,
Type,
TypeVar,
)
from telegram import Update from telegram import Update
from telegram.ext.utils.types import UD, CD, BD
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram import Bot from telegram import Bot
from telegram.ext import Dispatcher, Job, JobQueue from telegram.ext import Dispatcher, Job, JobQueue
CC = TypeVar('CC', bound='CallbackContext')
class CallbackContext:
class CallbackContext(Generic[UD, CD, BD]):
""" """
This is a context object passed to the callback called by :class:`telegram.ext.Handler` This is a context object passed to the callback called by :class:`telegram.ext.Handler`
or by the :class:`telegram.ext.Dispatcher` in an error handler added by or by the :class:`telegram.ext.Dispatcher` in an error handler added by
@ -50,6 +65,9 @@ class CallbackContext:
almost certainly execute the callbacks for an update out of order, and the attributes almost certainly execute the callbacks for an update out of order, and the attributes
that you think you added will not be present. that you think you added will not be present.
Args:
dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher associated with this context.
Attributes: Attributes:
matches (List[:obj:`re match object`]): Optional. If the associated update originated from matches (List[:obj:`re match object`]): Optional. If the associated update originated from
a regex-supported handler or had a :class:`Filters.regex`, this will contain a list of a regex-supported handler or had a :class:`Filters.regex`, this will contain a list of
@ -75,9 +93,8 @@ class CallbackContext:
__slots__ = ( __slots__ = (
'_dispatcher', '_dispatcher',
'_bot_data', '_chat_id_and_data',
'_chat_data', '_user_id_and_data',
'_user_data',
'args', 'args',
'matches', 'matches',
'error', 'error',
@ -97,9 +114,8 @@ class CallbackContext:
'CallbackContext should not be used with a non context aware ' 'dispatcher!' 'CallbackContext should not be used with a non context aware ' 'dispatcher!'
) )
self._dispatcher = dispatcher self._dispatcher = dispatcher
self._bot_data = dispatcher.bot_data self._chat_id_and_data: Optional[Tuple[int, CD]] = None
self._chat_data: Optional[Dict[object, object]] = None self._user_id_and_data: Optional[Tuple[int, UD]] = None
self._user_data: Optional[Dict[object, object]] = None
self.args: Optional[List[str]] = None self.args: Optional[List[str]] = None
self.matches: Optional[List[Match]] = None self.matches: Optional[List[Match]] = None
self.error: Optional[Exception] = None self.error: Optional[Exception] = None
@ -113,11 +129,11 @@ class CallbackContext:
return self._dispatcher return self._dispatcher
@property @property
def bot_data(self) -> Dict: def bot_data(self) -> BD:
""":obj:`dict`: Optional. A dict that can be used to keep any data in. For each """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each
update it will be the same ``dict``. update it will be the same ``dict``.
""" """
return self._bot_data return self.dispatcher.bot_data
@bot_data.setter @bot_data.setter
def bot_data(self, value: object) -> NoReturn: def bot_data(self, value: object) -> NoReturn:
@ -126,7 +142,7 @@ class CallbackContext:
) )
@property @property
def chat_data(self) -> Optional[Dict]: def chat_data(self) -> Optional[CD]:
""":obj:`dict`: Optional. A dict that can be used to keep any data in. For each """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each
update from the same chat id it will be the same ``dict``. update from the same chat id it will be the same ``dict``.
@ -136,7 +152,9 @@ class CallbackContext:
<https://github.com/python-telegram-bot/python-telegram-bot/wiki/ <https://github.com/python-telegram-bot/python-telegram-bot/wiki/
Storing-bot,-user-and-chat-related-data#chat-migration>`_. Storing-bot,-user-and-chat-related-data#chat-migration>`_.
""" """
return self._chat_data if self._chat_id_and_data:
return self._chat_id_and_data[1]
return None
@chat_data.setter @chat_data.setter
def chat_data(self, value: object) -> NoReturn: def chat_data(self, value: object) -> NoReturn:
@ -145,11 +163,13 @@ class CallbackContext:
) )
@property @property
def user_data(self) -> Optional[Dict]: def user_data(self) -> Optional[UD]:
""":obj:`dict`: Optional. A dict that can be used to keep any data in. For each """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each
update from the same user it will be the same ``dict``. update from the same user it will be the same ``dict``.
""" """
return self._user_data if self._user_id_and_data:
return self._user_id_and_data[1]
return None
@user_data.setter @user_data.setter
def user_data(self, value: object) -> NoReturn: def user_data(self, value: object) -> NoReturn:
@ -157,15 +177,32 @@ class CallbackContext:
"You can not assign a new value to user_data, see https://git.io/Jt6ic" "You can not assign a new value to user_data, see https://git.io/Jt6ic"
) )
def refresh_data(self) -> None:
"""If :attr:`dispatcher` uses persistence, calls
:meth:`telegram.ext.BasePersistence.refresh_bot_data` on :attr:`bot_data`,
:meth:`telegram.ext.BasePersistence.refresh_chat_data` on :attr:`chat_data` and
:meth:`telegram.ext.BasePersistence.refresh_user_data` on :attr:`user_data`, if
appropriate.
.. versionadded:: 13.6
"""
if self.dispatcher.persistence:
if self.dispatcher.persistence.store_bot_data:
self.dispatcher.persistence.refresh_bot_data(self.bot_data)
if self.dispatcher.persistence.store_chat_data and self._chat_id_and_data is not None:
self.dispatcher.persistence.refresh_chat_data(*self._chat_id_and_data)
if self.dispatcher.persistence.store_user_data and self._user_id_and_data is not None:
self.dispatcher.persistence.refresh_user_data(*self._user_id_and_data)
@classmethod @classmethod
def from_error( def from_error(
cls, cls: Type[CC],
update: object, update: object,
error: Exception, error: Exception,
dispatcher: 'Dispatcher', dispatcher: 'Dispatcher',
async_args: Union[List, Tuple] = None, async_args: Union[List, Tuple] = None,
async_kwargs: Dict[str, object] = None, async_kwargs: Dict[str, object] = None,
) -> 'CallbackContext': ) -> CC:
""" """
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error
handlers. handlers.
@ -195,7 +232,7 @@ class CallbackContext:
return self return self
@classmethod @classmethod
def from_update(cls, update: object, dispatcher: 'Dispatcher') -> 'CallbackContext': def from_update(cls: Type[CC], update: object, dispatcher: 'Dispatcher') -> CC:
""" """
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the
handlers. handlers.
@ -217,13 +254,19 @@ class CallbackContext:
user = update.effective_user user = update.effective_user
if chat: if chat:
self._chat_data = dispatcher.chat_data[chat.id] # pylint: disable=W0212 self._chat_id_and_data = (
chat.id,
dispatcher.chat_data[chat.id], # pylint: disable=W0212
)
if user: if user:
self._user_data = dispatcher.user_data[user.id] # pylint: disable=W0212 self._user_id_and_data = (
user.id,
dispatcher.user_data[user.id], # pylint: disable=W0212
)
return self return self
@classmethod @classmethod
def from_job(cls, job: 'Job', dispatcher: 'Dispatcher') -> 'CallbackContext': def from_job(cls: Type[CC], job: 'Job', dispatcher: 'Dispatcher') -> CC:
""" """
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to a Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to a
job callback. job callback.

View file

@ -35,14 +35,15 @@ from telegram import Update
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler from .handler import Handler
from .utils.types import CCT
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher from telegram.ext import Dispatcher
RT = TypeVar('RT') RT = TypeVar('RT')
class CallbackQueryHandler(Handler[Update]): class CallbackQueryHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram callback queries. Optionally based on a regex. """Handler class to handle Telegram callback queries. Optionally based on a regex.
Read the documentation of the ``re`` module for more information. Read the documentation of the ``re`` module for more information.
@ -124,7 +125,7 @@ class CallbackQueryHandler(Handler[Update]):
def __init__( def __init__(
self, self,
callback: Callable[[Update, 'CallbackContext'], RT], callback: Callable[[Update, CCT], RT],
pass_update_queue: bool = False, pass_update_queue: bool = False,
pass_job_queue: bool = False, pass_job_queue: bool = False,
pattern: Union[str, Pattern] = None, pattern: Union[str, Pattern] = None,
@ -191,7 +192,7 @@ class CallbackQueryHandler(Handler[Update]):
def collect_additional_context( def collect_additional_context(
self, self,
context: 'CallbackContext', context: CCT,
update: Update, update: Update,
dispatcher: 'Dispatcher', dispatcher: 'Dispatcher',
check_result: Union[bool, Match], check_result: Union[bool, Match],

View file

@ -17,19 +17,17 @@
# You should have received a copy of the GNU Lesser Public License # You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the ChatMemberHandler classes.""" """This module contains the ChatMemberHandler classes."""
from typing import ClassVar, TypeVar, Union, Callable, TYPE_CHECKING from typing import ClassVar, TypeVar, Union, Callable
from telegram import Update from telegram import Update
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler from .handler import Handler
from .utils.types import CCT
if TYPE_CHECKING:
from telegram.ext import CallbackContext
RT = TypeVar('RT') RT = TypeVar('RT')
class ChatMemberHandler(Handler[Update]): class ChatMemberHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram updates that contain a chat member update. """Handler class to handle Telegram updates that contain a chat member update.
.. versionadded:: 13.4 .. versionadded:: 13.4
@ -107,7 +105,7 @@ class ChatMemberHandler(Handler[Update]):
def __init__( def __init__(
self, self,
callback: Callable[[Update, 'CallbackContext'], RT], callback: Callable[[Update, CCT], RT],
chat_member_types: int = MY_CHAT_MEMBER, chat_member_types: int = MY_CHAT_MEMBER,
pass_update_queue: bool = False, pass_update_queue: bool = False,
pass_job_queue: bool = False, pass_job_queue: bool = False,

View file

@ -24,6 +24,7 @@ from telegram import Update
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler from .handler import Handler
from .utils.types import CCT
RT = TypeVar('RT') RT = TypeVar('RT')
@ -31,7 +32,7 @@ if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher from telegram.ext import CallbackContext, Dispatcher
class ChosenInlineResultHandler(Handler[Update]): class ChosenInlineResultHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram updates that contain a chosen inline result. """Handler class to handle Telegram updates that contain a chosen inline result.
Note: Note:

View file

@ -27,15 +27,16 @@ from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.types import SLT from telegram.utils.types import SLT
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .utils.types import CCT
from .handler import Handler from .handler import Handler
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher from telegram.ext import Dispatcher
RT = TypeVar('RT') RT = TypeVar('RT')
class CommandHandler(Handler[Update]): class CommandHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram commands. """Handler class to handle Telegram commands.
Commands are Telegram messages that start with ``/``, optionally followed by an ``@`` and the Commands are Telegram messages that start with ``/``, optionally followed by an ``@`` and the
@ -134,7 +135,7 @@ class CommandHandler(Handler[Update]):
def __init__( def __init__(
self, self,
command: SLT[str], command: SLT[str],
callback: Callable[[Update, 'CallbackContext'], RT], callback: Callable[[Update, CCT], RT],
filters: BaseFilter = None, filters: BaseFilter = None,
allow_edited: bool = None, allow_edited: bool = None,
pass_args: bool = False, pass_args: bool = False,
@ -231,7 +232,7 @@ class CommandHandler(Handler[Update]):
def collect_additional_context( def collect_additional_context(
self, self,
context: 'CallbackContext', context: CCT,
update: Update, update: Update,
dispatcher: 'Dispatcher', dispatcher: 'Dispatcher',
check_result: Optional[Union[bool, Tuple[List[str], Optional[bool]]]], check_result: Optional[Union[bool, Tuple[List[str], Optional[bool]]]],
@ -359,7 +360,7 @@ class PrefixHandler(CommandHandler):
self, self,
prefix: SLT[str], prefix: SLT[str],
command: SLT[str], command: SLT[str],
callback: Callable[[Update, 'CallbackContext'], RT], callback: Callable[[Update, CCT], RT],
filters: BaseFilter = None, filters: BaseFilter = None,
pass_args: bool = False, pass_args: bool = False,
pass_update_queue: bool = False, pass_update_queue: bool = False,

View file

@ -0,0 +1,194 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2020
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
# pylint: disable=R0201
"""This module contains the auxiliary class ContextTypes."""
from typing import Type, Generic, overload, Dict # pylint: disable=W0611
from telegram.ext.callbackcontext import CallbackContext
from telegram.ext.utils.types import CCT, UD, CD, BD
class ContextTypes(Generic[CCT, UD, CD, BD]):
"""
Convenience class to gather customizable types of the :class:`telegram.ext.CallbackContext`
interface.
.. versionadded:: 13.6
Args:
context (:obj:`type`, optional): Determines the type of the ``context`` argument of all
(error-)handler callbacks and job callbacks. Must be a subclass of
:class:`telegram.ext.CallbackContext`. Defaults to
:class:`telegram.ext.CallbackContext`.
bot_data (:obj:`type`, optional): Determines the type of ``context.bot_data`` of all
(error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support
instantiating without arguments.
chat_data (:obj:`type`, optional): Determines the type of ``context.chat_data`` of all
(error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support
instantiating without arguments.
user_data (:obj:`type`, optional): Determines the type of ``context.user_data`` of all
(error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support
instantiating without arguments.
"""
__slots__ = ('_context', '_bot_data', '_chat_data', '_user_data')
@overload
def __init__(
self: 'ContextTypes[CallbackContext, Dict, Dict, Dict]',
):
...
@overload
def __init__(self: 'ContextTypes[CCT, Dict, Dict, Dict]', context: Type[CCT]):
...
@overload
def __init__(self: 'ContextTypes[CallbackContext, UD, Dict, Dict]', bot_data: Type[UD]):
...
@overload
def __init__(self: 'ContextTypes[CallbackContext, Dict, CD, Dict]', chat_data: Type[CD]):
...
@overload
def __init__(self: 'ContextTypes[CallbackContext, Dict, Dict, BD]', user_data: Type[BD]):
...
@overload
def __init__(
self: 'ContextTypes[CCT, UD, Dict, Dict]', context: Type[CCT], bot_data: Type[UD]
):
...
@overload
def __init__(
self: 'ContextTypes[CCT, Dict, CD, Dict]', context: Type[CCT], chat_data: Type[CD]
):
...
@overload
def __init__(
self: 'ContextTypes[CCT, Dict, Dict, BD]', context: Type[CCT], user_data: Type[BD]
):
...
@overload
def __init__(
self: 'ContextTypes[CallbackContext, UD, CD, Dict]',
bot_data: Type[UD],
chat_data: Type[CD],
):
...
@overload
def __init__(
self: 'ContextTypes[CallbackContext, UD, Dict, BD]',
bot_data: Type[UD],
user_data: Type[BD],
):
...
@overload
def __init__(
self: 'ContextTypes[CallbackContext, Dict, CD, BD]',
chat_data: Type[CD],
user_data: Type[BD],
):
...
@overload
def __init__(
self: 'ContextTypes[CCT, UD, CD, Dict]',
context: Type[CCT],
bot_data: Type[UD],
chat_data: Type[CD],
):
...
@overload
def __init__(
self: 'ContextTypes[CCT, UD, Dict, BD]',
context: Type[CCT],
bot_data: Type[UD],
user_data: Type[BD],
):
...
@overload
def __init__(
self: 'ContextTypes[CCT, Dict, CD, BD]',
context: Type[CCT],
chat_data: Type[CD],
user_data: Type[BD],
):
...
@overload
def __init__(
self: 'ContextTypes[CallbackContext, UD, CD, BD]',
bot_data: Type[UD],
chat_data: Type[CD],
user_data: Type[BD],
):
...
@overload
def __init__(
self: 'ContextTypes[CCT, UD, CD, BD]',
context: Type[CCT],
bot_data: Type[UD],
chat_data: Type[CD],
user_data: Type[BD],
):
...
def __init__( # type: ignore[no-untyped-def]
self,
context=CallbackContext,
bot_data=dict,
chat_data=dict,
user_data=dict,
):
if not issubclass(context, CallbackContext):
raise ValueError('context must be a subclass of CallbackContext.')
# We make all those only accessible via properties because we don't currently support
# changing this at runtime, so overriding the attributes doesn't make sense
self._context = context
self._bot_data = bot_data
self._chat_data = chat_data
self._user_data = user_data
@property
def context(self) -> Type[CCT]:
return self._context
@property
def bot_data(self) -> Type[BD]:
return self._bot_data
@property
def chat_data(self) -> Type[CD]:
return self._chat_data
@property
def user_data(self) -> Type[UD]:
return self._user_data

View file

@ -38,6 +38,7 @@ from telegram.ext import (
) )
from telegram.ext.utils.promise import Promise from telegram.ext.utils.promise import Promise
from telegram.utils.types import ConversationDict from telegram.utils.types import ConversationDict
from telegram.ext.utils.types import CCT
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import Dispatcher, Job from telegram.ext import Dispatcher, Job
@ -61,7 +62,7 @@ class _ConversationTimeoutContext:
self.callback_context = callback_context self.callback_context = callback_context
class ConversationHandler(Handler[Update]): class ConversationHandler(Handler[Update, CCT]):
""" """
A handler to hold a conversation with a single or multiple users through Telegram updates by A handler to hold a conversation with a single or multiple users through Telegram updates by
managing four collections of other handlers. managing four collections of other handlers.
@ -215,9 +216,9 @@ class ConversationHandler(Handler[Update]):
# pylint: disable=W0231 # pylint: disable=W0231
def __init__( def __init__(
self, self,
entry_points: List[Handler], entry_points: List[Handler[Update, CCT]],
states: Dict[object, List[Handler]], states: Dict[object, List[Handler[Update, CCT]]],
fallbacks: List[Handler], fallbacks: List[Handler[Update, CCT]],
allow_reentry: bool = False, allow_reentry: bool = False,
per_chat: bool = True, per_chat: bool = True,
per_user: bool = True, per_user: bool = True,

View file

@ -17,7 +17,6 @@
# You should have received a copy of the GNU Lesser Public License # You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the DictPersistence class.""" """This module contains the DictPersistence class."""
from copy import deepcopy
from typing import DefaultDict, Dict, Optional, Tuple from typing import DefaultDict, Dict, Optional, Tuple
from collections import defaultdict from collections import defaultdict
@ -37,7 +36,7 @@ except ImportError:
class DictPersistence(BasePersistence): class DictPersistence(BasePersistence):
"""Using python's dicts and json for making your bot persistent. """Using Python's :obj:`dict` and ``json`` for making your bot persistent.
Note: Note:
This class does *not* implement a :meth:`flush` method, meaning that data managed by This class does *not* implement a :meth:`flush` method, meaning that data managed by
@ -202,7 +201,7 @@ class DictPersistence(BasePersistence):
pass pass
else: else:
self._user_data = defaultdict(dict) self._user_data = defaultdict(dict)
return deepcopy(self.user_data) # type: ignore[arg-type] return self.user_data # type: ignore[return-value]
def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]: def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]:
"""Returns the chat_data created from the ``chat_data_json`` or an empty """Returns the chat_data created from the ``chat_data_json`` or an empty
@ -215,7 +214,7 @@ class DictPersistence(BasePersistence):
pass pass
else: else:
self._chat_data = defaultdict(dict) self._chat_data = defaultdict(dict)
return deepcopy(self.chat_data) # type: ignore[arg-type] return self.chat_data # type: ignore[return-value]
def get_bot_data(self) -> Dict[object, object]: def get_bot_data(self) -> Dict[object, object]:
"""Returns the bot_data created from the ``bot_data_json`` or an empty :obj:`dict`. """Returns the bot_data created from the ``bot_data_json`` or an empty :obj:`dict`.
@ -227,7 +226,7 @@ class DictPersistence(BasePersistence):
pass pass
else: else:
self._bot_data = {} self._bot_data = {}
return deepcopy(self.bot_data) # type: ignore[arg-type] return self.bot_data # type: ignore[return-value]
def get_conversations(self, name: str) -> ConversationDict: def get_conversations(self, name: str) -> ConversationDict:
"""Returns the conversations created from the ``conversations_json`` or an empty """Returns the conversations created from the ``conversations_json`` or an empty
@ -264,7 +263,7 @@ class DictPersistence(BasePersistence):
Args: Args:
user_id (:obj:`int`): The user the data might have been changed for. user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id]. data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` ``[user_id]``.
""" """
if self._user_data is None: if self._user_data is None:
self._user_data = defaultdict(dict) self._user_data = defaultdict(dict)
@ -278,7 +277,7 @@ class DictPersistence(BasePersistence):
Args: Args:
chat_id (:obj:`int`): The chat the data might have been changed for. chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id]. data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` ``[chat_id]``.
""" """
if self._chat_data is None: if self._chat_data is None:
self._chat_data = defaultdict(dict) self._chat_data = defaultdict(dict)
@ -295,5 +294,26 @@ class DictPersistence(BasePersistence):
""" """
if self._bot_data == data: if self._bot_data == data:
return return
self._bot_data = data.copy() self._bot_data = data
self._bot_data_json = None self._bot_data_json = None
def refresh_user_data(self, user_id: int, user_data: Dict) -> None:
"""Does nothing.
.. versionadded:: 13.6
.. seealso:: :meth:`telegram.ext.BasePersistence.refresh_user_data`
"""
def refresh_chat_data(self, chat_id: int, chat_data: Dict) -> None:
"""Does nothing.
.. versionadded:: 13.6
.. seealso:: :meth:`telegram.ext.BasePersistence.refresh_chat_data`
"""
def refresh_bot_data(self, bot_data: Dict) -> None:
"""Does nothing.
.. versionadded:: 13.6
.. seealso:: :meth:`telegram.ext.BasePersistence.refresh_bot_data`
"""

View file

@ -26,16 +26,30 @@ from functools import wraps
from queue import Empty, Queue from queue import Empty, Queue
from threading import BoundedSemaphore, Event, Lock, Thread, current_thread from threading import BoundedSemaphore, Event, Lock, Thread, current_thread
from time import sleep from time import sleep
from typing import TYPE_CHECKING, Callable, DefaultDict, Dict, List, Optional, Union, Set from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
Optional,
Set,
Union,
Generic,
TypeVar,
overload,
cast,
DefaultDict,
)
from uuid import uuid4 from uuid import uuid4
from telegram import TelegramError, Update from telegram import TelegramError, Update
from telegram.ext import BasePersistence from telegram.ext import BasePersistence, ContextTypes
from telegram.ext.callbackcontext import CallbackContext from telegram.ext.callbackcontext import CallbackContext
from telegram.ext.handler import Handler from telegram.ext.handler import Handler
from telegram.utils.deprecate import TelegramDeprecationWarning, set_new_attribute_deprecated from telegram.utils.deprecate import TelegramDeprecationWarning, set_new_attribute_deprecated
from telegram.ext.utils.promise import Promise from telegram.ext.utils.promise import Promise
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from telegram.ext.utils.types import CCT, UD, CD, BD
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram import Bot from telegram import Bot
@ -43,6 +57,8 @@ if TYPE_CHECKING:
DEFAULT_GROUP: int = 0 DEFAULT_GROUP: int = 0
UT = TypeVar('UT')
def run_async( def run_async(
func: Callable[[Update, CallbackContext], object] func: Callable[[Update, CallbackContext], object]
@ -105,7 +121,7 @@ class DispatcherHandlerStop(Exception):
self.state = state self.state = state
class Dispatcher: class Dispatcher(Generic[CCT, UD, CD, BD]):
"""This class dispatches all kinds of updates to its registered handlers. """This class dispatches all kinds of updates to its registered handlers.
Args: Args:
@ -120,6 +136,12 @@ class Dispatcher:
use_context (:obj:`bool`, optional): If set to :obj:`True` uses the context based callback use_context (:obj:`bool`, optional): If set to :obj:`True` uses the context based callback
API (ignored if `dispatcher` argument is used). Defaults to :obj:`True`. API (ignored if `dispatcher` argument is used). Defaults to :obj:`True`.
**New users**: set this to :obj:`True`. **New users**: set this to :obj:`True`.
context_types (:class:`telegram.ext.ContextTypes`, optional): Pass an instance
of :class:`telegram.ext.ContextTypes` to customize the types used in the
``context`` interface. If not passed, the defaults documented in
:class:`telegram.ext.ContextTypes` will be used.
.. versionadded:: 13.6
Attributes: Attributes:
bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers. bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers.
@ -133,6 +155,10 @@ class Dispatcher:
bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot. bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot.
persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to
store data that should be persistent over restarts. store data that should be persistent over restarts.
context_types (:class:`telegram.ext.ContextTypes`): Container for the types used
in the ``context`` interface.
.. versionadded:: 13.6
""" """
@ -158,6 +184,7 @@ class Dispatcher:
'bot', 'bot',
'__dict__', '__dict__',
'__weakref__', '__weakref__',
'context_types',
) )
__singleton_lock = Lock() __singleton_lock = Lock()
@ -165,6 +192,33 @@ class Dispatcher:
__singleton = None __singleton = None
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@overload
def __init__(
self: 'Dispatcher[CallbackContext[Dict, Dict, Dict], Dict, Dict, Dict]',
bot: 'Bot',
update_queue: Queue,
workers: int = 4,
exception_event: Event = None,
job_queue: 'JobQueue' = None,
persistence: BasePersistence = None,
use_context: bool = True,
):
...
@overload
def __init__(
self: 'Dispatcher[CCT, UD, CD, BD]',
bot: 'Bot',
update_queue: Queue,
workers: int = 4,
exception_event: Event = None,
job_queue: 'JobQueue' = None,
persistence: BasePersistence = None,
use_context: bool = True,
context_types: ContextTypes[CCT, UD, CD, BD] = None,
):
...
def __init__( def __init__(
self, self,
bot: 'Bot', bot: 'Bot',
@ -174,12 +228,14 @@ class Dispatcher:
job_queue: 'JobQueue' = None, job_queue: 'JobQueue' = None,
persistence: BasePersistence = None, persistence: BasePersistence = None,
use_context: bool = True, use_context: bool = True,
context_types: ContextTypes[CCT, UD, CD, BD] = None,
): ):
self.bot = bot self.bot = bot
self.update_queue = update_queue self.update_queue = update_queue
self.job_queue = job_queue self.job_queue = job_queue
self.workers = workers self.workers = workers
self.use_context = use_context self.use_context = use_context
self.context_types = cast(ContextTypes[CCT, UD, CD, BD], context_types or ContextTypes())
if not use_context: if not use_context:
warnings.warn( warnings.warn(
@ -193,9 +249,9 @@ class Dispatcher:
'Asynchronous callbacks can not be processed without at least one worker thread.' 'Asynchronous callbacks can not be processed without at least one worker thread.'
) )
self.user_data: DefaultDict[int, Dict[object, object]] = defaultdict(dict) self.user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data)
self.chat_data: DefaultDict[int, Dict[object, object]] = defaultdict(dict) self.chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data)
self.bot_data = {} self.bot_data = self.context_types.bot_data()
self.persistence: Optional[BasePersistence] = None self.persistence: Optional[BasePersistence] = None
self._update_persistence_lock = Lock() self._update_persistence_lock = Lock()
if persistence: if persistence:
@ -213,8 +269,11 @@ class Dispatcher:
raise ValueError("chat_data must be of type defaultdict") raise ValueError("chat_data must be of type defaultdict")
if self.persistence.store_bot_data: if self.persistence.store_bot_data:
self.bot_data = self.persistence.get_bot_data() self.bot_data = self.persistence.get_bot_data()
if not isinstance(self.bot_data, dict): if not isinstance(self.bot_data, self.context_types.bot_data):
raise ValueError("bot_data must be of type dict") raise ValueError(
f"bot_data must be of type {self.context_types.bot_data.__name__}"
)
else: else:
self.persistence = None self.persistence = None
@ -477,7 +536,8 @@ class Dispatcher:
check = handler.check_update(update) check = handler.check_update(update)
if check is not None and check is not False: if check is not None and check is not False:
if not context and self.use_context: if not context and self.use_context:
context = CallbackContext.from_update(update, self) context = self.context_types.context.from_update(update, self)
context.refresh_data()
handled = True handled = True
sync_modes.append(handler.run_async) sync_modes.append(handler.run_async)
handler.handle_update(update, self, check, context) handler.handle_update(update, self, check, context)
@ -510,7 +570,7 @@ class Dispatcher:
if not handled_only_async: if not handled_only_async:
self.update_persistence(update=update) self.update_persistence(update=update)
def add_handler(self, handler: Handler, group: int = DEFAULT_GROUP) -> None: def add_handler(self, handler: Handler[UT, CCT], group: int = DEFAULT_GROUP) -> None:
"""Register a handler. """Register a handler.
TL;DR: Order and priority counts. 0 or 1 handlers per group will be used. End handling of TL;DR: Order and priority counts. 0 or 1 handlers per group will be used. End handling of
@ -542,14 +602,22 @@ class Dispatcher:
raise TypeError(f'handler is not an instance of {Handler.__name__}') raise TypeError(f'handler is not an instance of {Handler.__name__}')
if not isinstance(group, int): if not isinstance(group, int):
raise TypeError('group is not int') raise TypeError('group is not int')
if isinstance(handler, ConversationHandler) and handler.persistent and handler.name: # For some reason MyPy infers the type of handler is <nothing> here,
# so for now we just ignore all the errors
if (
isinstance(handler, ConversationHandler)
and handler.persistent # type: ignore[attr-defined]
and handler.name # type: ignore[attr-defined]
):
if not self.persistence: if not self.persistence:
raise ValueError( raise ValueError(
f"ConversationHandler {handler.name} can not be persistent if dispatcher has " f"ConversationHandler {handler.name} " # type: ignore[attr-defined]
f"no persistence" f"can not be persistent if dispatcher has no persistence"
)
handler.persistence = self.persistence # type: ignore[attr-defined]
handler.conversations = ( # type: ignore[attr-defined]
self.persistence.get_conversations(handler.name) # type: ignore[attr-defined]
) )
handler.persistence = self.persistence
handler.conversations = self.persistence.get_conversations(handler.name)
if group not in self.handlers: if group not in self.handlers:
self.handlers[group] = [] self.handlers[group] = []
@ -643,7 +711,7 @@ class Dispatcher:
def add_error_handler( def add_error_handler(
self, self,
callback: Callable[[object, CallbackContext], None], callback: Callable[[object, CCT], None],
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, # pylint: disable=W0621 run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, # pylint: disable=W0621
) -> None: ) -> None:
"""Registers an error handler in the Dispatcher. This handler will receive every error """Registers an error handler in the Dispatcher. This handler will receive every error
@ -678,7 +746,7 @@ class Dispatcher:
self.error_handlers[callback] = run_async self.error_handlers[callback] = run_async
def remove_error_handler(self, callback: Callable[[object, CallbackContext], None]) -> None: def remove_error_handler(self, callback: Callable[[object, CCT], None]) -> None:
"""Removes an error handler. """Removes an error handler.
Args: Args:
@ -705,7 +773,7 @@ class Dispatcher:
if self.error_handlers: if self.error_handlers:
for callback, run_async in self.error_handlers.items(): # pylint: disable=W0621 for callback, run_async in self.error_handlers.items(): # pylint: disable=W0621
if self.use_context: if self.use_context:
context = CallbackContext.from_error( context = self.context_types.context.from_error(
update, error, self, async_args=async_args, async_kwargs=async_kwargs update, error, self, async_args=async_args, async_kwargs=async_kwargs
) )
if run_async: if run_async:

View file

@ -26,15 +26,16 @@ from telegram.utils.deprecate import set_new_attribute_deprecated
from telegram import Update from telegram import Update
from telegram.ext.utils.promise import Promise from telegram.ext.utils.promise import Promise
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from telegram.ext.utils.types import CCT
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher from telegram.ext import Dispatcher
RT = TypeVar('RT') RT = TypeVar('RT')
UT = TypeVar('UT') UT = TypeVar('UT')
class Handler(Generic[UT], ABC): class Handler(Generic[UT, CCT], ABC):
"""The base class for all update handlers. Create custom handlers by inheriting from it. """The base class for all update handlers. Create custom handlers by inheriting from it.
Note: Note:
@ -115,7 +116,7 @@ class Handler(Generic[UT], ABC):
def __init__( def __init__(
self, self,
callback: Callable[[UT, 'CallbackContext'], RT], callback: Callable[[UT, CCT], RT],
pass_update_queue: bool = False, pass_update_queue: bool = False,
pass_job_queue: bool = False, pass_job_queue: bool = False,
pass_user_data: bool = False, pass_user_data: bool = False,
@ -165,7 +166,7 @@ class Handler(Generic[UT], ABC):
update: UT, update: UT,
dispatcher: 'Dispatcher', dispatcher: 'Dispatcher',
check_result: object, check_result: object,
context: 'CallbackContext' = None, context: CCT = None,
) -> Union[RT, Promise]: ) -> Union[RT, Promise]:
""" """
This method is called if it was determined that an update should indeed This method is called if it was determined that an update should indeed
@ -205,7 +206,7 @@ class Handler(Generic[UT], ABC):
def collect_additional_context( def collect_additional_context(
self, self,
context: 'CallbackContext', context: CCT,
update: UT, update: UT,
dispatcher: 'Dispatcher', dispatcher: 'Dispatcher',
check_result: Any, check_result: Any,

View file

@ -35,14 +35,15 @@ from telegram import Update
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler from .handler import Handler
from .utils.types import CCT
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher from telegram.ext import Dispatcher
RT = TypeVar('RT') RT = TypeVar('RT')
class InlineQueryHandler(Handler[Update]): class InlineQueryHandler(Handler[Update, CCT]):
""" """
Handler class to handle Telegram inline queries. Optionally based on a regex. Read the Handler class to handle Telegram inline queries. Optionally based on a regex. Read the
documentation of the ``re`` module for more information. documentation of the ``re`` module for more information.
@ -133,7 +134,7 @@ class InlineQueryHandler(Handler[Update]):
def __init__( def __init__(
self, self,
callback: Callable[[Update, 'CallbackContext'], RT], callback: Callable[[Update, CCT], RT],
pass_update_queue: bool = False, pass_update_queue: bool = False,
pass_job_queue: bool = False, pass_job_queue: bool = False,
pattern: Union[str, Pattern] = None, pattern: Union[str, Pattern] = None,
@ -207,7 +208,7 @@ class InlineQueryHandler(Handler[Update]):
def collect_additional_context( def collect_additional_context(
self, self,
context: 'CallbackContext', context: CCT,
update: Update, update: Update,
dispatcher: 'Dispatcher', dispatcher: 'Dispatcher',
check_result: Optional[Union[bool, Match]], check_result: Optional[Union[bool, Match]],

View file

@ -72,7 +72,7 @@ class JobQueue:
def _build_args(self, job: 'Job') -> List[Union[CallbackContext, 'Bot', 'Job']]: def _build_args(self, job: 'Job') -> List[Union[CallbackContext, 'Bot', 'Job']]:
if self._dispatcher.use_context: if self._dispatcher.use_context:
return [CallbackContext.from_job(job, self._dispatcher)] return [self._dispatcher.context_types.context.from_job(job, self._dispatcher)]
return [self._dispatcher.bot, job] return [self._dispatcher.bot, job]
def _tz_now(self) -> datetime.datetime: def _tz_now(self) -> datetime.datetime:
@ -585,7 +585,7 @@ class Job:
"""Executes the callback function independently of the jobs schedule.""" """Executes the callback function independently of the jobs schedule."""
try: try:
if dispatcher.use_context: if dispatcher.use_context:
self.callback(CallbackContext.from_job(self, dispatcher)) self.callback(dispatcher.context_types.context.from_job(self, dispatcher))
else: else:
self.callback(dispatcher.bot, self) # type: ignore[arg-type,call-arg] self.callback(dispatcher.bot, self) # type: ignore[arg-type,call-arg]
except Exception as exc: except Exception as exc:

View file

@ -27,14 +27,15 @@ from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler from .handler import Handler
from .utils.types import CCT
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher from telegram.ext import Dispatcher
RT = TypeVar('RT') RT = TypeVar('RT')
class MessageHandler(Handler[Update]): class MessageHandler(Handler[Update, CCT]):
"""Handler class to handle telegram messages. They might contain text, media or status updates. """Handler class to handle telegram messages. They might contain text, media or status updates.
Note: Note:
@ -125,7 +126,7 @@ class MessageHandler(Handler[Update]):
def __init__( def __init__(
self, self,
filters: BaseFilter, filters: BaseFilter,
callback: Callable[[Update, 'CallbackContext'], RT], callback: Callable[[Update, CCT], RT],
pass_update_queue: bool = False, pass_update_queue: bool = False,
pass_job_queue: bool = False, pass_job_queue: bool = False,
pass_user_data: bool = False, pass_user_data: bool = False,
@ -197,7 +198,7 @@ class MessageHandler(Handler[Update]):
def collect_additional_context( def collect_additional_context(
self, self,
context: 'CallbackContext', context: CCT,
update: Update, update: Update,
dispatcher: 'Dispatcher', dispatcher: 'Dispatcher',
check_result: Optional[Union[bool, Dict[str, object]]], check_result: Optional[Union[bool, Dict[str, object]]],

View file

@ -19,14 +19,23 @@
"""This module contains the PicklePersistence class.""" """This module contains the PicklePersistence class."""
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from copy import deepcopy from typing import (
from typing import Any, DefaultDict, Dict, Optional, Tuple Any,
Dict,
Optional,
Tuple,
overload,
cast,
DefaultDict,
)
from telegram.ext import BasePersistence from telegram.ext import BasePersistence
from telegram.utils.types import ConversationDict from telegram.utils.types import ConversationDict # pylint: disable=W0611
from .utils.types import UD, CD, BD
from .contexttypes import ContextTypes
class PicklePersistence(BasePersistence): class PicklePersistence(BasePersistence[UD, CD, BD]):
"""Using python's builtin pickle for making you bot persistent. """Using python's builtin pickle for making you bot persistent.
Warning: Warning:
@ -54,6 +63,12 @@ class PicklePersistence(BasePersistence):
:meth:`flush` is called and keep data in memory until that happens. When :meth:`flush` is called and keep data in memory until that happens. When
:obj:`False` will store data on any transaction *and* on call to :meth:`flush`. :obj:`False` will store data on any transaction *and* on call to :meth:`flush`.
Default is :obj:`False`. Default is :obj:`False`.
context_types (:class:`telegram.ext.ContextTypes`, optional): Pass an instance
of :class:`telegram.ext.ContextTypes` to customize the types used in the
``context`` interface. If not passed, the defaults documented in
:class:`telegram.ext.ContextTypes` will be used.
.. versionadded:: 13.6
Attributes: Attributes:
filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file` filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file`
@ -71,6 +86,10 @@ class PicklePersistence(BasePersistence):
:meth:`flush` is called and keep data in memory until that happens. When :meth:`flush` is called and keep data in memory until that happens. When
:obj:`False` will store data on any transaction *and* on call to :meth:`flush`. :obj:`False` will store data on any transaction *and* on call to :meth:`flush`.
Default is :obj:`False`. Default is :obj:`False`.
context_types (:class:`telegram.ext.ContextTypes`): Container for the types used
in the ``context`` interface.
.. versionadded:: 13.6
""" """
__slots__ = ( __slots__ = (
@ -81,8 +100,34 @@ class PicklePersistence(BasePersistence):
'chat_data', 'chat_data',
'bot_data', 'bot_data',
'conversations', 'conversations',
'context_types',
) )
@overload
def __init__(
self: 'PicklePersistence[Dict, Dict, Dict]',
filename: str,
store_user_data: bool = True,
store_chat_data: bool = True,
store_bot_data: bool = True,
single_file: bool = True,
on_flush: bool = False,
):
...
@overload
def __init__(
self: 'PicklePersistence[UD, CD, BD]',
filename: str,
store_user_data: bool = True,
store_chat_data: bool = True,
store_bot_data: bool = True,
single_file: bool = True,
on_flush: bool = False,
context_types: ContextTypes[Any, UD, CD, BD] = None,
):
...
def __init__( def __init__(
self, self,
filename: str, filename: str,
@ -91,6 +136,7 @@ class PicklePersistence(BasePersistence):
store_bot_data: bool = True, store_bot_data: bool = True,
single_file: bool = True, single_file: bool = True,
on_flush: bool = False, on_flush: bool = False,
context_types: ContextTypes[Any, UD, CD, BD] = None,
): ):
super().__init__( super().__init__(
store_user_data=store_user_data, store_user_data=store_user_data,
@ -100,26 +146,27 @@ class PicklePersistence(BasePersistence):
self.filename = filename self.filename = filename
self.single_file = single_file self.single_file = single_file
self.on_flush = on_flush self.on_flush = on_flush
self.user_data: Optional[DefaultDict[int, Dict]] = None self.user_data: Optional[DefaultDict[int, UD]] = None
self.chat_data: Optional[DefaultDict[int, Dict]] = None self.chat_data: Optional[DefaultDict[int, CD]] = None
self.bot_data: Optional[Dict] = None self.bot_data: Optional[BD] = None
self.conversations: Optional[Dict[str, Dict[Tuple, object]]] = None self.conversations: Optional[Dict[str, Dict[Tuple, object]]] = None
self.context_types = cast(ContextTypes[Any, UD, CD, BD], context_types or ContextTypes())
def _load_singlefile(self) -> None: def _load_singlefile(self) -> None:
try: try:
filename = self.filename filename = self.filename
with open(self.filename, "rb") as file: with open(self.filename, "rb") as file:
data = pickle.load(file) data = pickle.load(file)
self.user_data = defaultdict(dict, data['user_data']) self.user_data = defaultdict(self.context_types.user_data, data['user_data'])
self.chat_data = defaultdict(dict, data['chat_data']) self.chat_data = defaultdict(self.context_types.chat_data, data['chat_data'])
# For backwards compatibility with files not containing bot data # For backwards compatibility with files not containing bot data
self.bot_data = data.get('bot_data', {}) self.bot_data = data.get('bot_data', self.context_types.bot_data())
self.conversations = data['conversations'] self.conversations = data['conversations']
except OSError: except OSError:
self.conversations = {} self.conversations = {}
self.user_data = defaultdict(dict) self.user_data = defaultdict(self.context_types.user_data)
self.chat_data = defaultdict(dict) self.chat_data = defaultdict(self.context_types.chat_data)
self.bot_data = {} self.bot_data = self.context_types.bot_data()
except pickle.UnpicklingError as exc: except pickle.UnpicklingError as exc:
raise TypeError(f"File {filename} does not contain valid pickle data") from exc raise TypeError(f"File {filename} does not contain valid pickle data") from exc
except Exception as exc: except Exception as exc:
@ -152,11 +199,11 @@ class PicklePersistence(BasePersistence):
with open(filename, "wb") as file: with open(filename, "wb") as file:
pickle.dump(data, file) pickle.dump(data, file)
def get_user_data(self) -> DefaultDict[int, Dict[object, object]]: def get_user_data(self) -> DefaultDict[int, UD]:
"""Returns the user_data from the pickle file if it exists or an empty :obj:`defaultdict`. """Returns the user_data from the pickle file if it exists or an empty :obj:`defaultdict`.
Returns: Returns:
:obj:`defaultdict`: The restored user data. DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.UD`]: The restored user data.
""" """
if self.user_data: if self.user_data:
pass pass
@ -164,19 +211,19 @@ class PicklePersistence(BasePersistence):
filename = f"{self.filename}_user_data" filename = f"{self.filename}_user_data"
data = self._load_file(filename) data = self._load_file(filename)
if not data: if not data:
data = defaultdict(dict) data = defaultdict(self.context_types.user_data)
else: else:
data = defaultdict(dict, data) data = defaultdict(self.context_types.user_data, data)
self.user_data = data self.user_data = data
else: else:
self._load_singlefile() self._load_singlefile()
return deepcopy(self.user_data) # type: ignore[arg-type] return self.user_data # type: ignore[return-value]
def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]: def get_chat_data(self) -> DefaultDict[int, CD]:
"""Returns the chat_data from the pickle file if it exists or an empty :obj:`defaultdict`. """Returns the chat_data from the pickle file if it exists or an empty :obj:`defaultdict`.
Returns: Returns:
:obj:`defaultdict`: The restored chat data. DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.CD`]: The restored chat data.
""" """
if self.chat_data: if self.chat_data:
pass pass
@ -184,19 +231,20 @@ class PicklePersistence(BasePersistence):
filename = f"{self.filename}_chat_data" filename = f"{self.filename}_chat_data"
data = self._load_file(filename) data = self._load_file(filename)
if not data: if not data:
data = defaultdict(dict) data = defaultdict(self.context_types.chat_data)
else: else:
data = defaultdict(dict, data) data = defaultdict(self.context_types.chat_data, data)
self.chat_data = data self.chat_data = data
else: else:
self._load_singlefile() self._load_singlefile()
return deepcopy(self.chat_data) # type: ignore[arg-type] return self.chat_data # type: ignore[return-value]
def get_bot_data(self) -> Dict[object, object]: def get_bot_data(self) -> BD:
"""Returns the bot_data from the pickle file if it exists or an empty :obj:`dict`. """Returns the bot_data from the pickle file if it exists or an empty object of type
:class:`telegram.ext.utils.types.BD`.
Returns: Returns:
:obj:`dict`: The restored bot data. :class:`telegram.ext.utils.types.BD`: The restored bot data.
""" """
if self.bot_data: if self.bot_data:
pass pass
@ -204,11 +252,11 @@ class PicklePersistence(BasePersistence):
filename = f"{self.filename}_bot_data" filename = f"{self.filename}_bot_data"
data = self._load_file(filename) data = self._load_file(filename)
if not data: if not data:
data = {} data = self.context_types.bot_data()
self.bot_data = data self.bot_data = data
else: else:
self._load_singlefile() self._load_singlefile()
return deepcopy(self.bot_data) # type: ignore[arg-type] return self.bot_data # type: ignore[return-value]
def get_conversations(self, name: str) -> ConversationDict: def get_conversations(self, name: str) -> ConversationDict:
"""Returns the conversations from the pickle file if it exsists or an empty dict. """Returns the conversations from the pickle file if it exsists or an empty dict.
@ -254,15 +302,16 @@ class PicklePersistence(BasePersistence):
else: else:
self._dump_singlefile() self._dump_singlefile()
def update_user_data(self, user_id: int, data: Dict) -> None: def update_user_data(self, user_id: int, data: UD) -> None:
"""Will update the user_data and depending on :attr:`on_flush` save the pickle file. """Will update the user_data and depending on :attr:`on_flush` save the pickle file.
Args: Args:
user_id (:obj:`int`): The user the data might have been changed for. user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id]. data (:class:`telegram.ext.utils.types.UD`): The
:attr:`telegram.ext.dispatcher.user_data` ``[user_id]``.
""" """
if self.user_data is None: if self.user_data is None:
self.user_data = defaultdict(dict) self.user_data = defaultdict(self.context_types.user_data)
if self.user_data.get(user_id) == data: if self.user_data.get(user_id) == data:
return return
self.user_data[user_id] = data self.user_data[user_id] = data
@ -273,15 +322,16 @@ class PicklePersistence(BasePersistence):
else: else:
self._dump_singlefile() self._dump_singlefile()
def update_chat_data(self, chat_id: int, data: Dict) -> None: def update_chat_data(self, chat_id: int, data: CD) -> None:
"""Will update the chat_data and depending on :attr:`on_flush` save the pickle file. """Will update the chat_data and depending on :attr:`on_flush` save the pickle file.
Args: Args:
chat_id (:obj:`int`): The chat the data might have been changed for. chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id]. data (:class:`telegram.ext.utils.types.CD`): The
:attr:`telegram.ext.dispatcher.chat_data` ``[chat_id]``.
""" """
if self.chat_data is None: if self.chat_data is None:
self.chat_data = defaultdict(dict) self.chat_data = defaultdict(self.context_types.chat_data)
if self.chat_data.get(chat_id) == data: if self.chat_data.get(chat_id) == data:
return return
self.chat_data[chat_id] = data self.chat_data[chat_id] = data
@ -292,15 +342,16 @@ class PicklePersistence(BasePersistence):
else: else:
self._dump_singlefile() self._dump_singlefile()
def update_bot_data(self, data: Dict) -> None: def update_bot_data(self, data: BD) -> None:
"""Will update the bot_data and depending on :attr:`on_flush` save the pickle file. """Will update the bot_data and depending on :attr:`on_flush` save the pickle file.
Args: Args:
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.bot_data`. data (:class:`telegram.ext.utils.types.BD`): The
:attr:`telegram.ext.dispatcher.bot_data`.
""" """
if self.bot_data == data: if self.bot_data == data:
return return
self.bot_data = data.copy() self.bot_data = data
if not self.on_flush: if not self.on_flush:
if not self.single_file: if not self.single_file:
filename = f"{self.filename}_bot_data" filename = f"{self.filename}_bot_data"
@ -308,6 +359,27 @@ class PicklePersistence(BasePersistence):
else: else:
self._dump_singlefile() self._dump_singlefile()
def refresh_user_data(self, user_id: int, user_data: UD) -> None:
"""Does nothing.
.. versionadded:: 13.6
.. seealso:: :meth:`telegram.ext.BasePersistence.refresh_user_data`
"""
def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None:
"""Does nothing.
.. versionadded:: 13.6
.. seealso:: :meth:`telegram.ext.BasePersistence.refresh_chat_data`
"""
def refresh_bot_data(self, bot_data: BD) -> None:
"""Does nothing.
.. versionadded:: 13.6
.. seealso:: :meth:`telegram.ext.BasePersistence.refresh_bot_data`
"""
def flush(self) -> None: def flush(self) -> None:
"""Will save all data in memory to pickle file(s).""" """Will save all data in memory to pickle file(s)."""
if self.single_file: if self.single_file:

View file

@ -22,9 +22,10 @@
from telegram import Update from telegram import Update
from .handler import Handler from .handler import Handler
from .utils.types import CCT
class PollAnswerHandler(Handler[Update]): class PollAnswerHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram updates that contain a poll answer. """Handler class to handle Telegram updates that contain a poll answer.
Note: Note:

View file

@ -22,9 +22,10 @@
from telegram import Update from telegram import Update
from .handler import Handler from .handler import Handler
from .utils.types import CCT
class PollHandler(Handler[Update]): class PollHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram updates that contain a poll. """Handler class to handle Telegram updates that contain a poll.
Note: Note:

View file

@ -22,9 +22,10 @@
from telegram import Update from telegram import Update
from .handler import Handler from .handler import Handler
from .utils.types import CCT
class PreCheckoutQueryHandler(Handler[Update]): class PreCheckoutQueryHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram PreCheckout callback queries. """Handler class to handle Telegram PreCheckout callback queries.
Note: Note:

View file

@ -26,9 +26,10 @@ from telegram import Update
from telegram.ext import Filters, MessageHandler from telegram.ext import Filters, MessageHandler
from telegram.utils.deprecate import TelegramDeprecationWarning from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from telegram.ext.utils.types import CCT
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher from telegram.ext import Dispatcher
RT = TypeVar('RT') RT = TypeVar('RT')
@ -113,7 +114,7 @@ class RegexHandler(MessageHandler):
def __init__( def __init__(
self, self,
pattern: Union[str, Pattern], pattern: Union[str, Pattern],
callback: Callable[[Update, 'CallbackContext'], RT], callback: Callable[[Update, CCT], RT],
pass_groups: bool = False, pass_groups: bool = False,
pass_groupdict: bool = False, pass_groupdict: bool = False,
pass_update_queue: bool = False, pass_update_queue: bool = False,

View file

@ -21,9 +21,10 @@
from telegram import Update from telegram import Update
from .handler import Handler from .handler import Handler
from .utils.types import CCT
class ShippingQueryHandler(Handler[Update]): class ShippingQueryHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram shipping callback queries. """Handler class to handle Telegram shipping callback queries.
Note: Note:

View file

@ -23,14 +23,15 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, Union
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler from .handler import Handler
from .utils.types import CCT
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher from telegram.ext import Dispatcher
RT = TypeVar('RT') RT = TypeVar('RT')
class StringCommandHandler(Handler[str]): class StringCommandHandler(Handler[str, CCT]):
"""Handler class to handle string commands. Commands are string updates that start with ``/``. """Handler class to handle string commands. Commands are string updates that start with ``/``.
The handler will add a ``list`` to the The handler will add a ``list`` to the
:class:`CallbackContext` named :attr:`CallbackContext.args`. It will contain a list of strings, :class:`CallbackContext` named :attr:`CallbackContext.args`. It will contain a list of strings,
@ -90,7 +91,7 @@ class StringCommandHandler(Handler[str]):
def __init__( def __init__(
self, self,
command: str, command: str,
callback: Callable[[str, 'CallbackContext'], RT], callback: Callable[[str, CCT], RT],
pass_args: bool = False, pass_args: bool = False,
pass_update_queue: bool = False, pass_update_queue: bool = False,
pass_job_queue: bool = False, pass_job_queue: bool = False,
@ -137,7 +138,7 @@ class StringCommandHandler(Handler[str]):
def collect_additional_context( def collect_additional_context(
self, self,
context: 'CallbackContext', context: CCT,
update: str, update: str,
dispatcher: 'Dispatcher', dispatcher: 'Dispatcher',
check_result: Optional[List[str]], check_result: Optional[List[str]],

View file

@ -24,14 +24,15 @@ from typing import TYPE_CHECKING, Callable, Dict, Match, Optional, Pattern, Type
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler from .handler import Handler
from .utils.types import CCT
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher from telegram.ext import Dispatcher
RT = TypeVar('RT') RT = TypeVar('RT')
class StringRegexHandler(Handler[str]): class StringRegexHandler(Handler[str, CCT]):
"""Handler class to handle string updates based on a regex which checks the update content. """Handler class to handle string updates based on a regex which checks the update content.
Read the documentation of the ``re`` module for more information. The ``re.match`` function is Read the documentation of the ``re`` module for more information. The ``re.match`` function is
@ -96,7 +97,7 @@ class StringRegexHandler(Handler[str]):
def __init__( def __init__(
self, self,
pattern: Union[str, Pattern], pattern: Union[str, Pattern],
callback: Callable[[str, 'CallbackContext'], RT], callback: Callable[[str, CCT], RT],
pass_groups: bool = False, pass_groups: bool = False,
pass_groupdict: bool = False, pass_groupdict: bool = False,
pass_update_queue: bool = False, pass_update_queue: bool = False,
@ -153,7 +154,7 @@ class StringRegexHandler(Handler[str]):
def collect_additional_context( def collect_additional_context(
self, self,
context: 'CallbackContext', context: CCT,
update: str, update: str,
dispatcher: 'Dispatcher', dispatcher: 'Dispatcher',
check_result: Optional[Match], check_result: Optional[Match],

View file

@ -18,19 +18,17 @@
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the TypeHandler class.""" """This module contains the TypeHandler class."""
from typing import TYPE_CHECKING, Callable, Type, TypeVar, Union from typing import Callable, Type, TypeVar, Union
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler from .handler import Handler
from .utils.types import CCT
if TYPE_CHECKING:
from telegram.ext import CallbackContext
RT = TypeVar('RT') RT = TypeVar('RT')
UT = TypeVar('UT') UT = TypeVar('UT')
class TypeHandler(Handler[UT]): class TypeHandler(Handler[UT, CCT]):
"""Handler class to handle updates of custom types. """Handler class to handle updates of custom types.
Warning: Warning:
@ -80,7 +78,7 @@ class TypeHandler(Handler[UT]):
def __init__( def __init__(
self, self,
type: Type[UT], # pylint: disable=W0622 type: Type[UT], # pylint: disable=W0622
callback: Callable[[UT, 'CallbackContext'], RT], callback: Callable[[UT, CCT], RT],
strict: bool = False, strict: bool = False,
pass_update_queue: bool = False, pass_update_queue: bool = False,
pass_job_queue: bool = False, pass_job_queue: bool = False,

View file

@ -25,21 +25,34 @@ from queue import Queue
from signal import SIGABRT, SIGINT, SIGTERM, signal from signal import SIGABRT, SIGINT, SIGTERM, signal
from threading import Event, Lock, Thread, current_thread from threading import Event, Lock, Thread, current_thread
from time import sleep from time import sleep
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, no_type_check from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
no_type_check,
Generic,
overload,
)
from telegram import Bot, TelegramError from telegram import Bot, TelegramError
from telegram.error import InvalidToken, RetryAfter, TimedOut, Unauthorized from telegram.error import InvalidToken, RetryAfter, TimedOut, Unauthorized
from telegram.ext import Dispatcher, JobQueue from telegram.ext import Dispatcher, JobQueue, ContextTypes
from telegram.utils.deprecate import TelegramDeprecationWarning, set_new_attribute_deprecated from telegram.utils.deprecate import TelegramDeprecationWarning, set_new_attribute_deprecated
from telegram.utils.helpers import get_signal_name from telegram.utils.helpers import get_signal_name
from telegram.utils.request import Request from telegram.utils.request import Request
from telegram.ext.utils.types import CCT, UD, CD, BD
from telegram.ext.utils.webhookhandler import WebhookAppClass, WebhookServer from telegram.ext.utils.webhookhandler import WebhookAppClass, WebhookServer
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram.ext import BasePersistence, Defaults from telegram.ext import BasePersistence, Defaults, CallbackContext
class Updater: class Updater(Generic[CCT, UD, CD, BD]):
""" """
This class, which employs the :class:`telegram.ext.Dispatcher`, provides a frontend to This class, which employs the :class:`telegram.ext.Dispatcher`, provides a frontend to
:class:`telegram.Bot` to the programmer, so they can focus on coding the bot. Its purpose is to :class:`telegram.Bot` to the programmer, so they can focus on coding the bot. Its purpose is to
@ -85,6 +98,12 @@ class Updater:
used). used).
defaults (:class:`telegram.ext.Defaults`, optional): An object containing default values to defaults (:class:`telegram.ext.Defaults`, optional): An object containing default values to
be used if not set explicitly in the bot methods. be used if not set explicitly in the bot methods.
context_types (:class:`telegram.ext.ContextTypes`, optional): Pass an instance
of :class:`telegram.ext.ContextTypes` to customize the types used in the
``context`` interface. If not passed, the defaults documented in
:class:`telegram.ext.ContextTypes` will be used.
.. versionadded:: 13.6
Raises: Raises:
ValueError: If both :attr:`token` and :attr:`bot` are passed or none of them. ValueError: If both :attr:`token` and :attr:`bot` are passed or none of them.
@ -124,7 +143,52 @@ class Updater:
'__dict__', '__dict__',
) )
@overload
def __init__( def __init__(
self: 'Updater[CallbackContext, dict, dict, dict]',
token: str = None,
base_url: str = None,
workers: int = 4,
bot: Bot = None,
private_key: bytes = None,
private_key_password: bytes = None,
user_sig_handler: Callable = None,
request_kwargs: Dict[str, Any] = None,
persistence: 'BasePersistence' = None, # pylint: disable=E0601
defaults: 'Defaults' = None,
use_context: bool = True,
base_file_url: str = None,
):
...
@overload
def __init__(
self: 'Updater[CCT, UD, CD, BD]',
token: str = None,
base_url: str = None,
workers: int = 4,
bot: Bot = None,
private_key: bytes = None,
private_key_password: bytes = None,
user_sig_handler: Callable = None,
request_kwargs: Dict[str, Any] = None,
persistence: 'BasePersistence' = None,
defaults: 'Defaults' = None,
use_context: bool = True,
base_file_url: str = None,
context_types: ContextTypes[CCT, UD, CD, BD] = None,
):
...
@overload
def __init__(
self: 'Updater[CCT, UD, CD, BD]',
user_sig_handler: Callable = None,
dispatcher: Dispatcher[CCT, UD, CD, BD] = None,
):
...
def __init__( # type: ignore[no-untyped-def,misc]
self, self,
token: str = None, token: str = None,
base_url: str = None, base_url: str = None,
@ -137,8 +201,9 @@ class Updater:
persistence: 'BasePersistence' = None, persistence: 'BasePersistence' = None,
defaults: 'Defaults' = None, defaults: 'Defaults' = None,
use_context: bool = True, use_context: bool = True,
dispatcher: Dispatcher = None, dispatcher=None,
base_file_url: str = None, base_file_url: str = None,
context_types: ContextTypes[CCT, UD, CD, BD] = None,
): ):
if defaults and bot: if defaults and bot:
@ -161,10 +226,12 @@ class Updater:
raise ValueError('`dispatcher` and `bot` are mutually exclusive') raise ValueError('`dispatcher` and `bot` are mutually exclusive')
if persistence is not None: if persistence is not None:
raise ValueError('`dispatcher` and `persistence` are mutually exclusive') raise ValueError('`dispatcher` and `persistence` are mutually exclusive')
if workers is not None:
raise ValueError('`dispatcher` and `workers` are mutually exclusive')
if use_context != dispatcher.use_context: if use_context != dispatcher.use_context:
raise ValueError('`dispatcher` and `use_context` are mutually exclusive') raise ValueError('`dispatcher` and `use_context` are mutually exclusive')
if context_types is not None:
raise ValueError('`dispatcher` and `context_types` are mutually exclusive')
if workers is not None:
raise ValueError('`dispatcher` and `workers` are mutually exclusive')
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self._request = None self._request = None
@ -212,6 +279,7 @@ class Updater:
exception_event=self.__exception_event, exception_event=self.__exception_event,
persistence=persistence, persistence=persistence,
use_context=use_context, use_context=use_context,
context_types=context_types,
) )
self.job_queue.set_dispatcher(self.dispatcher) self.job_queue.set_dispatcher(self.dispatcher)
else: else:

View file

@ -0,0 +1,47 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2021
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains custom typing aliases.
.. versionadded:: 13.6
"""
from typing import TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
from telegram.ext import CallbackContext # noqa: F401
CCT = TypeVar('CCT', bound='CallbackContext')
"""An instance of :class:`telegram.ext.CallbackContext` or a custom subclass.
.. versionadded:: 13.6
"""
UD = TypeVar('UD')
"""Type of the user data for a single user.
.. versionadded:: 13.6
"""
CD = TypeVar('CD')
"""Type of the chat data for a single user.
.. versionadded:: 13.6
"""
BD = TypeVar('BD')
"""Type of the bot data.
.. versionadded:: 13.6
"""

View file

@ -18,11 +18,21 @@
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains custom typing aliases.""" """This module contains custom typing aliases."""
from pathlib import Path from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypeVar, Union from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Tuple,
TypeVar,
Union,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from telegram import InputFile from telegram import InputFile # noqa: F401
from telegram.utils.helpers import DefaultValue from telegram.utils.helpers import DefaultValue # noqa: F401
FileLike = Union[IO, 'InputFile'] FileLike = Union[IO, 'InputFile']
"""Either an open file handler or a :class:`telegram.InputFile`.""" """Either an open file handler or a :class:`telegram.InputFile`."""

View file

@ -22,6 +22,10 @@ import pytest
from telegram import Update, Message, Chat, User, TelegramError from telegram import Update, Message, Chat, User, TelegramError
from telegram.ext import CallbackContext from telegram.ext import CallbackContext
"""
CallbackContext.refresh_data is tested in TestBasePersistence
"""
class TestCallbackContext: class TestCallbackContext:
def test_slot_behaviour(self, cdp, recwarn, mro_slots): def test_slot_behaviour(self, cdp, recwarn, mro_slots):

View file

@ -0,0 +1,55 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2021
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import pytest
from telegram.ext import ContextTypes, CallbackContext
class SubClass(CallbackContext):
pass
class TestContextTypes:
def test_slot_behaviour(self, mro_slots):
instance = ContextTypes()
for attr in instance.__slots__:
assert getattr(instance, attr, 'err') != 'err', f"got extra slot '{attr}'"
assert len(mro_slots(instance)) == len(set(mro_slots(instance))), "duplicate slot"
with pytest.raises(AttributeError):
instance.custom
def test_data_init(self):
ct = ContextTypes(SubClass, int, float, bool)
assert ct.context is SubClass
assert ct.bot_data is int
assert ct.chat_data is float
assert ct.user_data is bool
with pytest.raises(ValueError, match='subclass of CallbackContext'):
ContextTypes(context=bool)
def test_data_assignment(self):
ct = ContextTypes()
with pytest.raises(AttributeError):
ct.bot_data = bool
with pytest.raises(AttributeError):
ct.user_data = bool
with pytest.raises(AttributeError):
ct.chat_data = bool

View file

@ -32,6 +32,7 @@ from telegram.ext import (
CallbackContext, CallbackContext,
JobQueue, JobQueue,
BasePersistence, BasePersistence,
ContextTypes,
) )
from telegram.ext.dispatcher import run_async, Dispatcher, DispatcherHandlerStop from telegram.ext.dispatcher import run_async, Dispatcher, DispatcherHandlerStop
from telegram.utils.deprecate import TelegramDeprecationWarning from telegram.utils.deprecate import TelegramDeprecationWarning
@ -45,6 +46,10 @@ def dp2(bot):
yield from create_dp(bot) yield from create_dp(bot)
class CustomContext(CallbackContext):
pass
class TestDispatcher: class TestDispatcher:
message_update = Update( message_update = Update(
1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') 1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text')
@ -747,6 +752,15 @@ class TestDispatcher:
def update_conversation(self, name, key, new_state): def update_conversation(self, name, key, new_state):
pass pass
def refresh_bot_data(self, bot_data):
pass
def refresh_user_data(self, user_id, user_data):
pass
def refresh_chat_data(self, chat_id, chat_data):
pass
def callback(update, context): def callback(update, context):
pass pass
@ -807,6 +821,15 @@ class TestDispatcher:
def get_chat_data(self): def get_chat_data(self):
pass pass
def refresh_bot_data(self, bot_data):
pass
def refresh_user_data(self, user_id, user_data):
pass
def refresh_chat_data(self, chat_id, chat_data):
pass
def callback(update, context): def callback(update, context):
pass pass
@ -923,3 +946,62 @@ class TestDispatcher:
assert self.count == expected assert self.count == expected
finally: finally:
dp.bot.defaults = None dp.bot.defaults = None
def test_custom_context_init(self, bot):
cc = ContextTypes(
context=CustomContext,
user_data=int,
chat_data=float,
bot_data=complex,
)
dispatcher = Dispatcher(bot, Queue(), context_types=cc)
assert isinstance(dispatcher.user_data[1], int)
assert isinstance(dispatcher.chat_data[1], float)
assert isinstance(dispatcher.bot_data, complex)
def test_custom_context_error_handler(self, bot):
def error_handler(_, context):
self.received = (
type(context),
type(context.user_data),
type(context.chat_data),
type(context.bot_data),
)
dispatcher = Dispatcher(
bot,
Queue(),
context_types=ContextTypes(
context=CustomContext, bot_data=int, user_data=float, chat_data=complex
),
)
dispatcher.add_error_handler(error_handler)
dispatcher.add_handler(MessageHandler(Filters.all, self.callback_raise_error))
dispatcher.process_update(self.message_update)
sleep(0.1)
assert self.received == (CustomContext, float, complex, int)
def test_custom_context_handler_callback(self, bot):
def callback(_, context):
self.received = (
type(context),
type(context.user_data),
type(context.chat_data),
type(context.bot_data),
)
dispatcher = Dispatcher(
bot,
Queue(),
context_types=ContextTypes(
context=CustomContext, bot_data=int, user_data=float, chat_data=complex
),
)
dispatcher.add_handler(MessageHandler(Filters.all, callback))
dispatcher.process_update(self.message_update)
sleep(0.1)
assert self.received == (CustomContext, float, complex, int)

View file

@ -29,7 +29,11 @@ import pytest
import pytz import pytz
from apscheduler.schedulers import SchedulerNotRunningError from apscheduler.schedulers import SchedulerNotRunningError
from flaky import flaky from flaky import flaky
from telegram.ext import JobQueue, Updater, Job, CallbackContext from telegram.ext import JobQueue, Updater, Job, CallbackContext, Dispatcher, ContextTypes
class CustomContext(CallbackContext):
pass
@pytest.fixture(scope='function') @pytest.fixture(scope='function')
@ -519,3 +523,25 @@ class TestJobQueue:
assert len(caplog.records) == 1 assert len(caplog.records) == 1
rec = caplog.records[-1] rec = caplog.records[-1]
assert 'No error handlers are registered' in rec.getMessage() assert 'No error handlers are registered' in rec.getMessage()
def test_custom_context(self, bot, job_queue):
dispatcher = Dispatcher(
bot,
Queue(),
context_types=ContextTypes(
context=CustomContext, bot_data=int, user_data=float, chat_data=complex
),
)
job_queue.set_dispatcher(dispatcher)
def callback(context):
self.result = (
type(context),
context.user_data,
context.chat_data,
type(context.bot_data),
)
job_queue.run_once(callback, 0.1)
sleep(0.15)
assert self.result == (CustomContext, None, None, int)

View file

@ -46,6 +46,7 @@ from telegram.ext import (
DictPersistence, DictPersistence,
TypeHandler, TypeHandler,
JobQueue, JobQueue,
ContextTypes,
) )
@ -135,12 +136,16 @@ def bot_data():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def chat_data(): def chat_data():
return defaultdict(dict, {-12345: {'test1': 'test2'}, -67890: {3: 'test4'}}) return defaultdict(
dict, {-12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, -67890: {3: 'test4'}}
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def user_data(): def user_data():
return defaultdict(dict, {12345: {'test1': 'test2'}, 67890: {3: 'test4'}}) return defaultdict(
dict, {12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, 67890: {3: 'test4'}}
)
@pytest.fixture(scope='function') @pytest.fixture(scope='function')
@ -172,6 +177,12 @@ def job_queue(bot):
class TestBasePersistence: class TestBasePersistence:
test_flag = False
@pytest.fixture(scope='function', autouse=True)
def reset(self):
self.test_flag = False
def test_slot_behaviour(self, bot_persistence, mro_slots, recwarn): def test_slot_behaviour(self, bot_persistence, mro_slots, recwarn):
inst = bot_persistence inst = bot_persistence
for attr in inst.__slots__: for attr in inst.__slots__:
@ -254,8 +265,17 @@ class TestBasePersistence:
u.dispatcher.chat_data[442233]['test5'] = 'test6' u.dispatcher.chat_data[442233]['test5'] = 'test6'
assert u.dispatcher.chat_data[442233]['test5'] == 'test6' assert u.dispatcher.chat_data[442233]['test5'] == 'test6'
@pytest.mark.parametrize('run_async', [True, False], ids=['synchronous', 'run_async'])
def test_dispatcher_integration_handlers( def test_dispatcher_integration_handlers(
self, caplog, bot, base_persistence, chat_data, user_data, bot_data self,
cdp,
caplog,
bot,
base_persistence,
chat_data,
user_data,
bot_data,
run_async,
): ):
def get_user_data(): def get_user_data():
return user_data return user_data
@ -269,112 +289,10 @@ class TestBasePersistence:
base_persistence.get_user_data = get_user_data base_persistence.get_user_data = get_user_data
base_persistence.get_chat_data = get_chat_data base_persistence.get_chat_data = get_chat_data
base_persistence.get_bot_data = get_bot_data base_persistence.get_bot_data = get_bot_data
# base_persistence.update_chat_data = lambda x: x base_persistence.refresh_bot_data = lambda x: x
# base_persistence.update_user_data = lambda x: x base_persistence.refresh_chat_data = lambda x, y: x
updater = Updater(bot=bot, persistence=base_persistence, use_context=True) base_persistence.refresh_user_data = lambda x, y: x
dp = updater.dispatcher
def callback_known_user(update, context):
if not context.user_data['test1'] == 'test2':
pytest.fail('user_data corrupt')
if not context.bot_data == bot_data:
pytest.fail('bot_data corrupt')
def callback_known_chat(update, context):
if not context.chat_data['test3'] == 'test4':
pytest.fail('chat_data corrupt')
if not context.bot_data == bot_data:
pytest.fail('bot_data corrupt')
def callback_unknown_user_or_chat(update, context):
if not context.user_data == {}:
pytest.fail('user_data corrupt')
if not context.chat_data == {}:
pytest.fail('chat_data corrupt')
if not context.bot_data == bot_data:
pytest.fail('bot_data corrupt')
context.user_data[1] = 'test7'
context.chat_data[2] = 'test8'
context.bot_data['test0'] = 'test0'
known_user = MessageHandler(
Filters.user(user_id=12345),
callback_known_user,
pass_chat_data=True,
pass_user_data=True,
)
known_chat = MessageHandler(
Filters.chat(chat_id=-67890),
callback_known_chat,
pass_chat_data=True,
pass_user_data=True,
)
unknown = MessageHandler(
Filters.all, callback_unknown_user_or_chat, pass_chat_data=True, pass_user_data=True
)
dp.add_handler(known_user)
dp.add_handler(known_chat)
dp.add_handler(unknown)
user1 = User(id=12345, first_name='test user', is_bot=False)
user2 = User(id=54321, first_name='test user', is_bot=False)
chat1 = Chat(id=-67890, type='group')
chat2 = Chat(id=-987654, type='group')
m = Message(1, None, chat2, from_user=user1)
u = Update(0, m)
with caplog.at_level(logging.ERROR):
dp.process_update(u)
rec = caplog.records[-1]
assert rec.getMessage() == 'No error handlers are registered, logging exception.'
assert rec.levelname == 'ERROR'
rec = caplog.records[-2]
assert rec.getMessage() == 'No error handlers are registered, logging exception.'
assert rec.levelname == 'ERROR'
rec = caplog.records[-3]
assert rec.getMessage() == 'No error handlers are registered, logging exception.'
assert rec.levelname == 'ERROR'
m.from_user = user2
m.chat = chat1
u = Update(1, m)
dp.process_update(u)
m.chat = chat2
u = Update(2, m)
def save_bot_data(data):
if 'test0' not in data:
pytest.fail()
def save_chat_data(data):
if -987654 not in data:
pytest.fail()
def save_user_data(data):
if 54321 not in data:
pytest.fail()
base_persistence.update_chat_data = save_chat_data
base_persistence.update_user_data = save_user_data
base_persistence.update_bot_data = save_bot_data
dp.process_update(u)
assert dp.user_data[54321][1] == 'test7'
assert dp.chat_data[-987654][2] == 'test8'
assert dp.bot_data['test0'] == 'test0'
def test_dispatcher_integration_handlers_run_async(
self, cdp, caplog, bot, base_persistence, chat_data, user_data, bot_data
):
def get_user_data():
return user_data
def get_chat_data():
return chat_data
def get_bot_data():
return bot_data
base_persistence.get_user_data = get_user_data
base_persistence.get_chat_data = get_chat_data
base_persistence.get_bot_data = get_bot_data
cdp.persistence = base_persistence cdp.persistence = base_persistence
cdp.user_data = user_data cdp.user_data = user_data
cdp.chat_data = chat_data cdp.chat_data = chat_data
@ -408,21 +326,21 @@ class TestBasePersistence:
callback_known_user, callback_known_user,
pass_chat_data=True, pass_chat_data=True,
pass_user_data=True, pass_user_data=True,
run_async=True, run_async=run_async,
) )
known_chat = MessageHandler( known_chat = MessageHandler(
Filters.chat(chat_id=-67890), Filters.chat(chat_id=-67890),
callback_known_chat, callback_known_chat,
pass_chat_data=True, pass_chat_data=True,
pass_user_data=True, pass_user_data=True,
run_async=True, run_async=run_async,
) )
unknown = MessageHandler( unknown = MessageHandler(
Filters.all, Filters.all,
callback_unknown_user_or_chat, callback_unknown_user_or_chat,
pass_chat_data=True, pass_chat_data=True,
pass_user_data=True, pass_user_data=True,
run_async=True, run_async=run_async,
) )
cdp.add_handler(known_user) cdp.add_handler(known_user)
cdp.add_handler(known_chat) cdp.add_handler(known_chat)
@ -437,12 +355,16 @@ class TestBasePersistence:
cdp.process_update(u) cdp.process_update(u)
sleep(0.1) sleep(0.1)
rec = caplog.records[-1]
assert rec.getMessage() == 'No error handlers are registered, logging exception.' # In base_persistence.update_*_data we currently just raise NotImplementedError
assert rec.levelname == 'ERROR' # This makes sure that this doesn't break the processing and is properly handled by
rec = caplog.records[-2] # the error handler
# We override `update_*_data` further below.
assert len(caplog.records) == 3
for rec in caplog.records:
assert rec.getMessage() == 'No error handlers are registered, logging exception.' assert rec.getMessage() == 'No error handlers are registered, logging exception.'
assert rec.levelname == 'ERROR' assert rec.levelname == 'ERROR'
m.from_user = user2 m.from_user = user2
m.chat = chat1 m.chat = chat1
u = Update(1, m) u = Update(1, m)
@ -473,6 +395,105 @@ class TestBasePersistence:
assert cdp.chat_data[-987654][2] == 'test8' assert cdp.chat_data[-987654][2] == 'test8'
assert cdp.bot_data['test0'] == 'test0' assert cdp.bot_data['test0'] == 'test0'
@pytest.mark.parametrize(
'store_user_data', [True, False], ids=['store_user_data-True', 'store_user_data-False']
)
@pytest.mark.parametrize(
'store_chat_data', [True, False], ids=['store_chat_data-True', 'store_chat_data-False']
)
@pytest.mark.parametrize(
'store_bot_data', [True, False], ids=['store_bot_data-True', 'store_bot_data-False']
)
@pytest.mark.parametrize('run_async', [True, False], ids=['synchronous', 'run_async'])
def test_persistence_dispatcher_integration_refresh_data(
self,
cdp,
base_persistence,
chat_data,
bot_data,
user_data,
store_bot_data,
store_chat_data,
store_user_data,
run_async,
):
base_persistence.refresh_bot_data = lambda x: x.setdefault(
'refreshed', x.get('refreshed', 0) + 1
)
# x is the user/chat_id
base_persistence.refresh_chat_data = lambda x, y: y.setdefault('refreshed', x)
base_persistence.refresh_user_data = lambda x, y: y.setdefault('refreshed', x)
base_persistence.store_bot_data = store_bot_data
base_persistence.store_chat_data = store_chat_data
base_persistence.store_user_data = store_user_data
cdp.persistence = base_persistence
self.test_flag = True
def callback_with_user_and_chat(update, context):
if store_user_data:
if context.user_data.get('refreshed') != update.effective_user.id:
self.test_flag = 'user_data was not refreshed'
else:
if 'refreshed' in context.user_data:
self.test_flag = 'user_data was wrongly refreshed'
if store_chat_data:
if context.chat_data.get('refreshed') != update.effective_chat.id:
self.test_flag = 'chat_data was not refreshed'
else:
if 'refreshed' in context.chat_data:
self.test_flag = 'chat_data was wrongly refreshed'
if store_bot_data:
if context.bot_data.get('refreshed') != 1:
self.test_flag = 'bot_data was not refreshed'
else:
if 'refreshed' in context.bot_data:
self.test_flag = 'bot_data was wrongly refreshed'
def callback_without_user_and_chat(_, context):
if store_bot_data:
if context.bot_data.get('refreshed') != 1:
self.test_flag = 'bot_data was not refreshed'
else:
if 'refreshed' in context.bot_data:
self.test_flag = 'bot_data was wrongly refreshed'
with_user_and_chat = MessageHandler(
Filters.user(user_id=12345),
callback_with_user_and_chat,
pass_chat_data=True,
pass_user_data=True,
run_async=run_async,
)
without_user_and_chat = MessageHandler(
Filters.all,
callback_without_user_and_chat,
pass_chat_data=True,
pass_user_data=True,
run_async=run_async,
)
cdp.add_handler(with_user_and_chat)
cdp.add_handler(without_user_and_chat)
user = User(id=12345, first_name='test user', is_bot=False)
chat = Chat(id=-987654, type='group')
m = Message(1, None, chat, from_user=user)
# has user and chat
u = Update(0, m)
cdp.process_update(u)
assert self.test_flag is True
# has neither user nor hat
m.from_user = None
m.chat = None
u = Update(1, m)
cdp.process_update(u)
assert self.test_flag is True
sleep(0.1)
def test_persistence_dispatcher_arbitrary_update_types(self, dp, base_persistence, caplog): def test_persistence_dispatcher_arbitrary_update_types(self, dp, base_persistence, caplog):
# Updates used with TypeHandler doesn't necessarily have the proper attributes for # Updates used with TypeHandler doesn't necessarily have the proper attributes for
# persistence, makes sure it works anyways # persistence, makes sure it works anyways
@ -816,6 +837,10 @@ def update(bot):
return Update(0, message=message) return Update(0, message=message)
class CustomMapping(defaultdict):
pass
class TestPicklePersistence: class TestPicklePersistence:
def test_slot_behaviour(self, mro_slots, recwarn, pickle_persistence): def test_slot_behaviour(self, mro_slots, recwarn, pickle_persistence):
inst = pickle_persistence inst = pickle_persistence
@ -986,25 +1011,34 @@ class TestPicklePersistence:
def test_updating_multi_file(self, pickle_persistence, good_pickle_files): def test_updating_multi_file(self, pickle_persistence, good_pickle_files):
user_data = pickle_persistence.get_user_data() user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10' user_data[12345]['test3']['test4'] = 'test6'
assert not pickle_persistence.user_data == user_data assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321]) pickle_persistence.update_user_data(12345, user_data[12345])
user_data[12345]['test3']['test4'] = 'test7'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(12345, user_data[12345])
assert pickle_persistence.user_data == user_data assert pickle_persistence.user_data == user_data
with open('pickletest_user_data', 'rb') as f: with open('pickletest_user_data', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)) user_data_test = defaultdict(dict, pickle.load(f))
assert user_data_test == user_data assert user_data_test == user_data
chat_data = pickle_persistence.get_chat_data() chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10' chat_data[-12345]['test3']['test4'] = 'test6'
assert not pickle_persistence.chat_data == chat_data assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321]) pickle_persistence.update_chat_data(-12345, chat_data[-12345])
chat_data[-12345]['test3']['test4'] = 'test7'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(-12345, chat_data[-12345])
assert pickle_persistence.chat_data == chat_data assert pickle_persistence.chat_data == chat_data
with open('pickletest_chat_data', 'rb') as f: with open('pickletest_chat_data', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)) chat_data_test = defaultdict(dict, pickle.load(f))
assert chat_data_test == chat_data assert chat_data_test == chat_data
bot_data = pickle_persistence.get_bot_data() bot_data = pickle_persistence.get_bot_data()
bot_data['test6'] = 'test 7' bot_data['test3']['test4'] = 'test6'
assert not pickle_persistence.bot_data == bot_data
pickle_persistence.update_bot_data(bot_data)
bot_data['test3']['test4'] = 'test7'
assert not pickle_persistence.bot_data == bot_data assert not pickle_persistence.bot_data == bot_data
pickle_persistence.update_bot_data(bot_data) pickle_persistence.update_bot_data(bot_data)
assert pickle_persistence.bot_data == bot_data assert pickle_persistence.bot_data == bot_data
@ -1031,25 +1065,34 @@ class TestPicklePersistence:
pickle_persistence.single_file = True pickle_persistence.single_file = True
user_data = pickle_persistence.get_user_data() user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10' user_data[12345]['test3']['test4'] = 'test6'
assert not pickle_persistence.user_data == user_data assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321]) pickle_persistence.update_user_data(12345, user_data[12345])
user_data[12345]['test3']['test4'] = 'test7'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(12345, user_data[12345])
assert pickle_persistence.user_data == user_data assert pickle_persistence.user_data == user_data
with open('pickletest', 'rb') as f: with open('pickletest', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)['user_data']) user_data_test = defaultdict(dict, pickle.load(f)['user_data'])
assert user_data_test == user_data assert user_data_test == user_data
chat_data = pickle_persistence.get_chat_data() chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10' chat_data[-12345]['test3']['test4'] = 'test6'
assert not pickle_persistence.chat_data == chat_data assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321]) pickle_persistence.update_chat_data(-12345, chat_data[-12345])
chat_data[-12345]['test3']['test4'] = 'test7'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(-12345, chat_data[-12345])
assert pickle_persistence.chat_data == chat_data assert pickle_persistence.chat_data == chat_data
with open('pickletest', 'rb') as f: with open('pickletest', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)['chat_data']) chat_data_test = defaultdict(dict, pickle.load(f)['chat_data'])
assert chat_data_test == chat_data assert chat_data_test == chat_data
bot_data = pickle_persistence.get_bot_data() bot_data = pickle_persistence.get_bot_data()
bot_data['test6'] = 'test 7' bot_data['test3']['test4'] = 'test6'
assert not pickle_persistence.bot_data == bot_data
pickle_persistence.update_bot_data(bot_data)
bot_data['test3']['test4'] = 'test7'
assert not pickle_persistence.bot_data == bot_data assert not pickle_persistence.bot_data == bot_data
pickle_persistence.update_bot_data(bot_data) pickle_persistence.update_bot_data(bot_data)
assert pickle_persistence.bot_data == bot_data assert pickle_persistence.bot_data == bot_data
@ -1418,6 +1461,39 @@ class TestPicklePersistence:
user_data = pickle_persistence.get_user_data() user_data = pickle_persistence.get_user_data()
assert user_data[789] == {'test3': '123'} assert user_data[789] == {'test3': '123'}
@pytest.mark.parametrize('singlefile', [True, False])
@pytest.mark.parametrize('ud', [int, float, complex])
@pytest.mark.parametrize('cd', [int, float, complex])
@pytest.mark.parametrize('bd', [int, float, complex])
def test_with_context_types(self, ud, cd, bd, singlefile):
cc = ContextTypes(user_data=ud, chat_data=cd, bot_data=bd)
persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc)
assert isinstance(persistence.get_user_data()[1], ud)
assert persistence.get_user_data()[1] == 0
assert isinstance(persistence.get_chat_data()[1], cd)
assert persistence.get_chat_data()[1] == 0
assert isinstance(persistence.get_bot_data(), bd)
assert persistence.get_bot_data() == 0
persistence.user_data = None
persistence.chat_data = None
persistence.update_user_data(1, ud(1))
persistence.update_chat_data(1, cd(1))
persistence.update_bot_data(bd(1))
assert persistence.get_user_data()[1] == 1
assert persistence.get_chat_data()[1] == 1
assert persistence.get_bot_data() == 1
persistence.flush()
persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc)
assert isinstance(persistence.get_user_data()[1], ud)
assert persistence.get_user_data()[1] == 1
assert isinstance(persistence.get_chat_data()[1], cd)
assert persistence.get_chat_data()[1] == 1
assert isinstance(persistence.get_bot_data(), bd)
assert persistence.get_bot_data() == 1
@pytest.fixture(scope='function') @pytest.fixture(scope='function')
def user_data_json(user_data): def user_data_json(user_data):
@ -1560,7 +1636,7 @@ class TestDictPersistence:
assert dict_persistence.bot_data_json == bot_data_json assert dict_persistence.bot_data_json == bot_data_json
assert dict_persistence.conversations_json == conversations_json assert dict_persistence.conversations_json == conversations_json
def test_json_changes( def test_updating(
self, self,
user_data, user_data,
user_data_json, user_data_json,
@ -1577,35 +1653,59 @@ class TestDictPersistence:
bot_data_json=bot_data_json, bot_data_json=bot_data_json,
conversations_json=conversations_json, conversations_json=conversations_json,
) )
user_data_two = user_data.copy()
user_data_two.update({4: {5: 6}})
dict_persistence.update_user_data(4, {5: 6})
assert dict_persistence.user_data == user_data_two
assert dict_persistence.user_data_json != user_data_json
assert dict_persistence.user_data_json == json.dumps(user_data_two)
chat_data_two = chat_data.copy() user_data = dict_persistence.get_user_data()
chat_data_two.update({7: {8: 9}}) user_data[12345]['test3']['test4'] = 'test6'
dict_persistence.update_chat_data(7, {8: 9}) assert not dict_persistence.user_data == user_data
assert dict_persistence.chat_data == chat_data_two assert not dict_persistence.user_data_json == json.dumps(user_data)
assert dict_persistence.chat_data_json != chat_data_json dict_persistence.update_user_data(12345, user_data[12345])
assert dict_persistence.chat_data_json == json.dumps(chat_data_two) user_data[12345]['test3']['test4'] = 'test7'
assert not dict_persistence.user_data == user_data
assert not dict_persistence.user_data_json == json.dumps(user_data)
dict_persistence.update_user_data(12345, user_data[12345])
assert dict_persistence.user_data == user_data
assert dict_persistence.user_data_json == json.dumps(user_data)
bot_data_two = bot_data.copy() chat_data = dict_persistence.get_chat_data()
bot_data_two.update({'7': {'8': '9'}}) chat_data[-12345]['test3']['test4'] = 'test6'
bot_data['7'] = {'8': '9'} assert not dict_persistence.chat_data == chat_data
assert not dict_persistence.chat_data_json == json.dumps(chat_data)
dict_persistence.update_chat_data(-12345, chat_data[-12345])
chat_data[-12345]['test3']['test4'] = 'test7'
assert not dict_persistence.chat_data == chat_data
assert not dict_persistence.chat_data_json == json.dumps(chat_data)
dict_persistence.update_chat_data(-12345, chat_data[-12345])
assert dict_persistence.chat_data == chat_data
assert dict_persistence.chat_data_json == json.dumps(chat_data)
bot_data = dict_persistence.get_bot_data()
bot_data['test3']['test4'] = 'test6'
assert not dict_persistence.bot_data == bot_data
assert not dict_persistence.bot_data_json == json.dumps(bot_data)
dict_persistence.update_bot_data(bot_data) dict_persistence.update_bot_data(bot_data)
assert dict_persistence.bot_data == bot_data_two bot_data['test3']['test4'] = 'test7'
assert dict_persistence.bot_data_json != bot_data_json assert not dict_persistence.bot_data == bot_data
assert dict_persistence.bot_data_json == json.dumps(bot_data_two) assert not dict_persistence.bot_data_json == json.dumps(bot_data)
dict_persistence.update_bot_data(bot_data)
assert dict_persistence.bot_data == bot_data
assert dict_persistence.bot_data_json == json.dumps(bot_data)
conversations_two = conversations.copy() conversation1 = dict_persistence.get_conversations('name1')
conversations_two.update({'name4': {(1, 2): 3}}) conversation1[(123, 123)] = 5
dict_persistence.update_conversation('name4', (1, 2), 3) assert not dict_persistence.conversations['name1'] == conversation1
assert dict_persistence.conversations == conversations_two dict_persistence.update_conversation('name1', (123, 123), 5)
assert dict_persistence.conversations_json != conversations_json assert dict_persistence.conversations['name1'] == conversation1
print(dict_persistence.conversations_json)
conversations['name1'][(123, 123)] = 5
assert dict_persistence.conversations_json == encode_conversations_to_json(conversations)
assert dict_persistence.get_conversations('name1') == conversation1
dict_persistence._conversations = None
dict_persistence.update_conversation('name1', (123, 123), 5)
assert dict_persistence.conversations['name1'] == {(123, 123): 5}
assert dict_persistence.get_conversations('name1') == {(123, 123): 5}
assert dict_persistence.conversations_json == encode_conversations_to_json( assert dict_persistence.conversations_json == encode_conversations_to_json(
conversations_two {"name1": {(123, 123): 5}}
) )
def test_with_handler(self, bot, update): def test_with_handler(self, bot, update):

View file

@ -31,6 +31,7 @@ excluded = {
'Days', 'Days',
'telegram.deprecate', 'telegram.deprecate',
'TelegramDecryptionError', 'TelegramDecryptionError',
'ContextTypes',
} # These modules/classes intentionally don't have __dict__. } # These modules/classes intentionally don't have __dict__.

View file

@ -613,6 +613,11 @@ class TestUpdater:
with pytest.raises(ValueError): with pytest.raises(ValueError):
Updater(dispatcher=dispatcher, use_context=use_context) Updater(dispatcher=dispatcher, use_context=use_context)
def test_mutual_exclude_custom_context_dispatcher(self):
dispatcher = Dispatcher(None, None)
with pytest.raises(ValueError):
Updater(dispatcher=dispatcher, context_types=True)
def test_defaults_warning(self, bot): def test_defaults_warning(self, bot):
with pytest.warns(TelegramDeprecationWarning, match='no effect when a Bot is passed'): with pytest.warns(TelegramDeprecationWarning, match='no effect when a Bot is passed'):
Updater(bot=bot, defaults=Defaults()) Updater(bot=bot, defaults=Defaults())