diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index ba3c919ef..f9dbe6885 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -60,7 +60,7 @@ jobs:
shell: bash --noprofile --norc {0}
- name: Submit coverage
- uses: codecov/codecov-action@v1.0.13
+ uses: codecov/codecov-action@v1
with:
env_vars: OS,PYTHON
name: ${{ matrix.os }}-${{ matrix.python-version }}
@@ -79,7 +79,7 @@ jobs:
run:
git submodule update --init --recursive
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
@@ -108,7 +108,7 @@ jobs:
run:
git submodule update --init --recursive
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
diff --git a/docs/source/telegram.ext.contexttypes.rst b/docs/source/telegram.ext.contexttypes.rst
new file mode 100644
index 000000000..d0cc0a29a
--- /dev/null
+++ b/docs/source/telegram.ext.contexttypes.rst
@@ -0,0 +1,8 @@
+:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/contexttypes.py
+
+telegram.ext.ContextTypes
+=========================
+
+.. autoclass:: telegram.ext.ContextTypes
+ :members:
+ :show-inheritance:
diff --git a/docs/source/telegram.ext.rst b/docs/source/telegram.ext.rst
index 9bd855a87..316950446 100644
--- a/docs/source/telegram.ext.rst
+++ b/docs/source/telegram.ext.rst
@@ -7,11 +7,12 @@ telegram.ext package
telegram.ext.dispatcher
telegram.ext.dispatcherhandlerstop
telegram.ext.callbackcontext
- telegram.ext.defaults
telegram.ext.job
telegram.ext.jobqueue
telegram.ext.messagequeue
telegram.ext.delayqueue
+ telegram.ext.contexttypes
+ telegram.ext.defaults
Handlers
--------
@@ -51,4 +52,5 @@ utils
.. toctree::
- telegram.ext.utils.promise
\ No newline at end of file
+ telegram.ext.utils.promise
+ telegram.ext.utils.types
\ No newline at end of file
diff --git a/docs/source/telegram.ext.utils.types.rst b/docs/source/telegram.ext.utils.types.rst
new file mode 100644
index 000000000..5c501ecf8
--- /dev/null
+++ b/docs/source/telegram.ext.utils.types.rst
@@ -0,0 +1,8 @@
+:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/utils/types.py
+
+telegram.ext.utils.types Module
+================================
+
+.. automodule:: telegram.ext.utils.types
+ :members:
+ :show-inheritance:
diff --git a/examples/README.md b/examples/README.md
index 5b05c53ef..7d8f19225 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -47,7 +47,10 @@ A basic example of a bot that can accept payments. Don't forget to enable and co
A basic example on how to set up a custom error handler.
### [`chatmemberbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/chatmemberbot.py)
-A basic example on how `(my_)chat_member` updates can be used.
+A basic example on how `(my_)chat_member` updates can be used.
+
+### [`contexttypesbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/contexttypesbot.py)
+This example showcases how `telegram.ext.ContextTypes` can be used to customize the `context` argument of handler and job callbacks.
## Pure API
The [`rawapibot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/rawapibot.py) example uses only the pure, "bare-metal" API wrapper.
diff --git a/examples/contexttypesbot.py b/examples/contexttypesbot.py
new file mode 100644
index 000000000..cfe485a61
--- /dev/null
+++ b/examples/contexttypesbot.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python
+# pylint: disable=C0116,W0613
+# This program is dedicated to the public domain under the CC0 license.
+
+"""
+Simple Bot to showcase `telegram.ext.ContextTypes`.
+
+Usage:
+Press Ctrl-C on the command line or send a signal to the process to stop the
+bot.
+"""
+
+from collections import defaultdict
+from typing import DefaultDict, Optional, Set
+
+from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, ParseMode
+from telegram.ext import (
+ Updater,
+ CommandHandler,
+ CallbackContext,
+ ContextTypes,
+ CallbackQueryHandler,
+ TypeHandler,
+ Dispatcher,
+)
+
+
+class ChatData:
+ """Custom class for chat_data. Here we store data per message."""
+
+ def __init__(self) -> None:
+ self.clicks_per_message: DefaultDict[int, int] = defaultdict(int)
+
+
+# The [dict, ChatData, dict] is for type checkers like mypy
+class CustomContext(CallbackContext[dict, ChatData, dict]):
+ """Custom class for context."""
+
+ def __init__(self, dispatcher: Dispatcher):
+ super().__init__(dispatcher=dispatcher)
+ self._message_id: Optional[int] = None
+
+ @property
+ def bot_user_ids(self) -> Set[int]:
+ """Custom shortcut to access a value stored in the bot_data dict"""
+ return self.bot_data.setdefault('user_ids', set())
+
+ @property
+ def message_clicks(self) -> Optional[int]:
+ """Access the number of clicks for the message this context object was built for."""
+ if self._message_id:
+ return self.chat_data.clicks_per_message[self._message_id]
+ return None
+
+ @message_clicks.setter
+ def message_clicks(self, value: int) -> None:
+ """Allow to change the count"""
+ if not self._message_id:
+ raise RuntimeError('There is no message associated with this context obejct.')
+ self.chat_data.clicks_per_message[self._message_id] = value
+
+ @classmethod
+ def from_update(cls, update: object, dispatcher: 'Dispatcher') -> 'CustomContext':
+ """Override from_update to set _message_id."""
+ # Make sure to call super()
+ context = super().from_update(update, dispatcher)
+
+ if context.chat_data and isinstance(update, Update) and update.effective_message:
+ context._message_id = update.effective_message.message_id # pylint: disable=W0212
+
+ # Remember to return the object
+ return context
+
+
+def start(update: Update, context: CustomContext) -> None:
+ """Display a message with a button."""
+ update.message.reply_html(
+ 'This button was clicked 0 times.',
+ reply_markup=InlineKeyboardMarkup.from_button(
+ InlineKeyboardButton(text='Click me!', callback_data='button')
+ ),
+ )
+
+
+def count_click(update: Update, context: CustomContext) -> None:
+ """Update the click count for the message."""
+ context.message_clicks += 1
+ update.callback_query.answer()
+ update.effective_message.edit_text(
+ f'This button was clicked {context.message_clicks} times.',
+ reply_markup=InlineKeyboardMarkup.from_button(
+ InlineKeyboardButton(text='Click me!', callback_data='button')
+ ),
+ parse_mode=ParseMode.HTML,
+ )
+
+
+def print_users(update: Update, context: CustomContext) -> None:
+ """Show which users have been using this bot."""
+ update.message.reply_text(
+ 'The following user IDs have used this bot: '
+ f'{", ".join(map(str, context.bot_user_ids))}'
+ )
+
+
+def track_users(update: Update, context: CustomContext) -> None:
+ """Store the user id of the incoming update, if any."""
+ if update.effective_user:
+ context.bot_user_ids.add(update.effective_user.id)
+
+
+def main() -> None:
+ """Run the bot."""
+ context_types = ContextTypes(context=CustomContext, chat_data=ChatData)
+ updater = Updater("TOKEN", context_types=context_types)
+
+ dispatcher = updater.dispatcher
+ # run track_users in its own group to not interfere with the user handlers
+ dispatcher.add_handler(TypeHandler(Update, track_users), group=-1)
+ dispatcher.add_handler(CommandHandler("start", start))
+ dispatcher.add_handler(CallbackQueryHandler(count_click))
+ dispatcher.add_handler(CommandHandler("print_users", print_users))
+
+ updater.start_polling()
+ updater.idle()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/setup.cfg b/setup.cfg
index dd6d88012..a38acfd36 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,6 +44,7 @@ omit =
[coverage:report]
exclude_lines =
if TYPE_CHECKING:
+ ...
[mypy]
warn_unused_ignores = True
diff --git a/telegram/ext/__init__.py b/telegram/ext/__init__.py
index f536c2be2..93f561514 100644
--- a/telegram/ext/__init__.py
+++ b/telegram/ext/__init__.py
@@ -16,6 +16,7 @@
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
+# pylint: disable=C0413
"""Extensions over the Telegram Bot API to facilitate bot making"""
from .basepersistence import BasePersistence
@@ -23,7 +24,20 @@ from .picklepersistence import PicklePersistence
from .dictpersistence import DictPersistence
from .handler import Handler
from .callbackcontext import CallbackContext
+from .contexttypes import ContextTypes
from .dispatcher import Dispatcher, DispatcherHandlerStop, run_async
+
+# https://bugs.python.org/issue41451, fixed on 3.7+, doesn't actually remove slots
+# try-except is just here in case the __init__ is called twice (like in the tests)
+# this block is also the reason for the pylint-ignore at the top of the file
+try:
+ del Dispatcher.__slots__ # type: ignore[has-type]
+except AttributeError as exc:
+ if str(exc) == '__slots__':
+ pass
+ else:
+ raise exc
+
from .jobqueue import JobQueue, Job
from .updater import Updater
from .callbackqueryhandler import CallbackQueryHandler
@@ -47,38 +61,39 @@ from .chatmemberhandler import ChatMemberHandler
from .defaults import Defaults
__all__ = (
- 'Dispatcher',
- 'JobQueue',
- 'Job',
- 'Updater',
+ 'BaseFilter',
+ 'BasePersistence',
+ 'CallbackContext',
'CallbackQueryHandler',
+ 'ChatMemberHandler',
'ChosenInlineResultHandler',
'CommandHandler',
+ 'ContextTypes',
+ 'ConversationHandler',
+ 'Defaults',
+ 'DelayQueue',
+ 'DictPersistence',
+ 'Dispatcher',
+ 'DispatcherHandlerStop',
+ 'Filters',
'Handler',
'InlineQueryHandler',
- 'MessageHandler',
- 'BaseFilter',
+ 'Job',
+ 'JobQueue',
'MessageFilter',
- 'UpdateFilter',
- 'Filters',
+ 'MessageHandler',
+ 'MessageQueue',
+ 'PicklePersistence',
+ 'PollAnswerHandler',
+ 'PollHandler',
+ 'PreCheckoutQueryHandler',
+ 'PrefixHandler',
'RegexHandler',
+ 'ShippingQueryHandler',
'StringCommandHandler',
'StringRegexHandler',
'TypeHandler',
- 'ConversationHandler',
- 'PreCheckoutQueryHandler',
- 'ShippingQueryHandler',
- 'MessageQueue',
- 'DelayQueue',
- 'DispatcherHandlerStop',
+ 'UpdateFilter',
+ 'Updater',
'run_async',
- 'CallbackContext',
- 'BasePersistence',
- 'PicklePersistence',
- 'DictPersistence',
- 'PrefixHandler',
- 'PollAnswerHandler',
- 'PollHandler',
- 'ChatMemberHandler',
- 'Defaults',
)
diff --git a/telegram/ext/basepersistence.py b/telegram/ext/basepersistence.py
index c0c248e5c..94453bec5 100644
--- a/telegram/ext/basepersistence.py
+++ b/telegram/ext/basepersistence.py
@@ -21,16 +21,17 @@ import warnings
from sys import version_info as py_ver
from abc import ABC, abstractmethod
from copy import copy
-from typing import DefaultDict, Dict, Optional, Tuple, cast, ClassVar
+from typing import Dict, Optional, Tuple, cast, ClassVar, Generic, DefaultDict
from telegram.utils.deprecate import set_new_attribute_deprecated
from telegram import Bot
from telegram.utils.types import ConversationDict
+from telegram.ext.utils.types import UD, CD, BD
-class BasePersistence(ABC):
+class BasePersistence(Generic[UD, CD, BD], ABC):
"""Interface class for adding persistence to your bot.
Subclass this object for different implementations of a persistent bot.
@@ -38,16 +39,20 @@ class BasePersistence(ABC):
* :meth:`get_bot_data`
* :meth:`update_bot_data`
+ * :meth:`refresh_bot_data`
* :meth:`get_chat_data`
* :meth:`update_chat_data`
+ * :meth:`refresh_chat_data`
* :meth:`get_user_data`
* :meth:`update_user_data`
+ * :meth:`refresh_user_data`
* :meth:`get_conversations`
* :meth:`update_conversation`
* :meth:`flush`
If you don't actually need one of those methods, a simple ``pass`` is enough. For example, if
- ``store_bot_data=False``, you don't need :meth:`get_bot_data` and :meth:`update_bot_data`.
+ ``store_bot_data=False``, you don't need :meth:`get_bot_data`, :meth:`update_bot_data` or
+ :meth:`refresh_bot_data`.
Warning:
Persistence will try to replace :class:`telegram.Bot` instances by :attr:`REPLACED_BOT` and
@@ -93,6 +98,10 @@ class BasePersistence(ABC):
def __new__(
cls, *args: object, **kwargs: object # pylint: disable=W0613
) -> 'BasePersistence':
+ """This overrides the get_* and update_* methods to use insert/replace_bot.
+ That has the side effect that we always pass deepcopied data to those methods, so in
+ Pickle/DictPersistence we don't have to worry about copying the data again.
+ """
instance = super().__new__(cls)
get_user_data = instance.get_user_data
get_chat_data = instance.get_chat_data
@@ -101,22 +110,22 @@ class BasePersistence(ABC):
update_chat_data = instance.update_chat_data
update_bot_data = instance.update_bot_data
- def get_user_data_insert_bot() -> DefaultDict[int, Dict[object, object]]:
+ def get_user_data_insert_bot() -> DefaultDict[int, UD]:
return instance.insert_bot(get_user_data())
- def get_chat_data_insert_bot() -> DefaultDict[int, Dict[object, object]]:
+ def get_chat_data_insert_bot() -> DefaultDict[int, CD]:
return instance.insert_bot(get_chat_data())
- def get_bot_data_insert_bot() -> Dict[object, object]:
+ def get_bot_data_insert_bot() -> BD:
return instance.insert_bot(get_bot_data())
- def update_user_data_replace_bot(user_id: int, data: Dict) -> None:
+ def update_user_data_replace_bot(user_id: int, data: UD) -> None:
return update_user_data(user_id, instance.replace_bot(data))
- def update_chat_data_replace_bot(chat_id: int, data: Dict) -> None:
+ def update_chat_data_replace_bot(chat_id: int, data: CD) -> None:
return update_chat_data(chat_id, instance.replace_bot(data))
- def update_bot_data_replace_bot(data: Dict) -> None:
+ def update_bot_data_replace_bot(data: BD) -> None:
return update_bot_data(instance.replace_bot(data))
# We want to ignore TGDeprecation warnings so we use obj.__setattr__. Adds to __dict__
@@ -334,33 +343,33 @@ class BasePersistence(ABC):
return obj
@abstractmethod
- def get_user_data(self) -> DefaultDict[int, Dict[object, object]]:
+ def get_user_data(self) -> DefaultDict[int, UD]:
"""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the ``user_data`` if stored, or an empty
- ``defaultdict(dict)``.
+ :obj:`defaultdict(telegram.ext.utils.types.UD)` with integer keys.
Returns:
- :obj:`defaultdict`: The restored user data.
+ DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.UD`]: The restored user data.
"""
@abstractmethod
- def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]:
+ def get_chat_data(self) -> DefaultDict[int, CD]:
"""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the ``chat_data`` if stored, or an empty
- ``defaultdict(dict)``.
+ :obj:`defaultdict(telegram.ext.utils.types.CD)` with integer keys.
Returns:
- :obj:`defaultdict`: The restored chat data.
+ DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.CD`]: The restored chat data.
"""
@abstractmethod
- def get_bot_data(self) -> Dict[object, object]:
+ def get_bot_data(self) -> BD:
"""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the ``bot_data`` if stored, or an empty
- :obj:`dict`.
+ :class:`telegram.ext.utils.types.BD`.
Returns:
- :obj:`dict`: The restored bot data.
+ :class:`telegram.ext.utils.types.BD`: The restored bot data.
"""
@abstractmethod
@@ -391,32 +400,70 @@ class BasePersistence(ABC):
"""
@abstractmethod
- def update_user_data(self, user_id: int, data: Dict) -> None:
+ def update_user_data(self, user_id: int, data: UD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
Args:
user_id (:obj:`int`): The user the data might have been changed for.
- data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id].
+ data (:class:`telegram.ext.utils.types.UD`): The
+ :attr:`telegram.ext.dispatcher.user_data` ``[user_id]``.
"""
@abstractmethod
- def update_chat_data(self, chat_id: int, data: Dict) -> None:
+ def update_chat_data(self, chat_id: int, data: CD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
Args:
chat_id (:obj:`int`): The chat the data might have been changed for.
- data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id].
+ data (:class:`telegram.ext.utils.types.CD`): The
+ :attr:`telegram.ext.dispatcher.chat_data` ``[chat_id]``.
"""
@abstractmethod
- def update_bot_data(self, data: Dict) -> None:
+ def update_bot_data(self, data: BD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
Args:
- data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.bot_data` .
+ data (:class:`telegram.ext.utils.types.BD`): The
+ :attr:`telegram.ext.dispatcher.bot_data`.
+ """
+
+ def refresh_user_data(self, user_id: int, user_data: UD) -> None:
+ """Will be called by the :class:`telegram.ext.Dispatcher` before passing the
+ :attr:`user_data` to a callback. Can be used to update data stored in :attr:`user_data`
+ from an external source.
+
+ .. versionadded:: 13.6
+
+ Args:
+ user_id (:obj:`int`): The user ID this :attr:`user_data` is associated with.
+ user_data (:class:`telegram.ext.utils.types.UD`): The ``user_data`` of a single user.
+ """
+
+ def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None:
+ """Will be called by the :class:`telegram.ext.Dispatcher` before passing the
+ :attr:`chat_data` to a callback. Can be used to update data stored in :attr:`chat_data`
+ from an external source.
+
+ .. versionadded:: 13.6
+
+ Args:
+ chat_id (:obj:`int`): The chat ID this :attr:`chat_data` is associated with.
+ chat_data (:class:`telegram.ext.utils.types.CD`): The ``chat_data`` of a single chat.
+ """
+
+ def refresh_bot_data(self, bot_data: BD) -> None:
+ """Will be called by the :class:`telegram.ext.Dispatcher` before passing the
+ :attr:`bot_data` to a callback. Can be used to update data stored in :attr:`bot_data`
+ from an external source.
+
+ .. versionadded:: 13.6
+
+ Args:
+ bot_data (:class:`telegram.ext.utils.types.BD`): The ``bot_data``.
"""
def flush(self) -> None:
diff --git a/telegram/ext/callbackcontext.py b/telegram/ext/callbackcontext.py
index d27840b2e..626af5f83 100644
--- a/telegram/ext/callbackcontext.py
+++ b/telegram/ext/callbackcontext.py
@@ -19,16 +19,31 @@
# pylint: disable=R0201
"""This module contains the CallbackContext class."""
from queue import Queue
-from typing import TYPE_CHECKING, Dict, List, Match, NoReturn, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ List,
+ Match,
+ NoReturn,
+ Optional,
+ Tuple,
+ Union,
+ Generic,
+ Type,
+ TypeVar,
+)
from telegram import Update
+from telegram.ext.utils.types import UD, CD, BD
if TYPE_CHECKING:
from telegram import Bot
from telegram.ext import Dispatcher, Job, JobQueue
+CC = TypeVar('CC', bound='CallbackContext')
-class CallbackContext:
+
+class CallbackContext(Generic[UD, CD, BD]):
"""
This is a context object passed to the callback called by :class:`telegram.ext.Handler`
or by the :class:`telegram.ext.Dispatcher` in an error handler added by
@@ -50,6 +65,9 @@ class CallbackContext:
almost certainly execute the callbacks for an update out of order, and the attributes
that you think you added will not be present.
+ Args:
+ dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher associated with this context.
+
Attributes:
matches (List[:obj:`re match object`]): Optional. If the associated update originated from
a regex-supported handler or had a :class:`Filters.regex`, this will contain a list of
@@ -75,9 +93,8 @@ class CallbackContext:
__slots__ = (
'_dispatcher',
- '_bot_data',
- '_chat_data',
- '_user_data',
+ '_chat_id_and_data',
+ '_user_id_and_data',
'args',
'matches',
'error',
@@ -97,9 +114,8 @@ class CallbackContext:
'CallbackContext should not be used with a non context aware ' 'dispatcher!'
)
self._dispatcher = dispatcher
- self._bot_data = dispatcher.bot_data
- self._chat_data: Optional[Dict[object, object]] = None
- self._user_data: Optional[Dict[object, object]] = None
+ self._chat_id_and_data: Optional[Tuple[int, CD]] = None
+ self._user_id_and_data: Optional[Tuple[int, UD]] = None
self.args: Optional[List[str]] = None
self.matches: Optional[List[Match]] = None
self.error: Optional[Exception] = None
@@ -113,11 +129,11 @@ class CallbackContext:
return self._dispatcher
@property
- def bot_data(self) -> Dict:
+ def bot_data(self) -> BD:
""":obj:`dict`: Optional. A dict that can be used to keep any data in. For each
update it will be the same ``dict``.
"""
- return self._bot_data
+ return self.dispatcher.bot_data
@bot_data.setter
def bot_data(self, value: object) -> NoReturn:
@@ -126,7 +142,7 @@ class CallbackContext:
)
@property
- def chat_data(self) -> Optional[Dict]:
+ def chat_data(self) -> Optional[CD]:
""":obj:`dict`: Optional. A dict that can be used to keep any data in. For each
update from the same chat id it will be the same ``dict``.
@@ -136,7 +152,9 @@ class CallbackContext:
`_.
"""
- return self._chat_data
+ if self._chat_id_and_data:
+ return self._chat_id_and_data[1]
+ return None
@chat_data.setter
def chat_data(self, value: object) -> NoReturn:
@@ -145,11 +163,13 @@ class CallbackContext:
)
@property
- def user_data(self) -> Optional[Dict]:
+ def user_data(self) -> Optional[UD]:
""":obj:`dict`: Optional. A dict that can be used to keep any data in. For each
update from the same user it will be the same ``dict``.
"""
- return self._user_data
+ if self._user_id_and_data:
+ return self._user_id_and_data[1]
+ return None
@user_data.setter
def user_data(self, value: object) -> NoReturn:
@@ -157,15 +177,32 @@ class CallbackContext:
"You can not assign a new value to user_data, see https://git.io/Jt6ic"
)
+ def refresh_data(self) -> None:
+ """If :attr:`dispatcher` uses persistence, calls
+ :meth:`telegram.ext.BasePersistence.refresh_bot_data` on :attr:`bot_data`,
+ :meth:`telegram.ext.BasePersistence.refresh_chat_data` on :attr:`chat_data` and
+ :meth:`telegram.ext.BasePersistence.refresh_user_data` on :attr:`user_data`, if
+ appropriate.
+
+ .. versionadded:: 13.6
+ """
+ if self.dispatcher.persistence:
+ if self.dispatcher.persistence.store_bot_data:
+ self.dispatcher.persistence.refresh_bot_data(self.bot_data)
+ if self.dispatcher.persistence.store_chat_data and self._chat_id_and_data is not None:
+ self.dispatcher.persistence.refresh_chat_data(*self._chat_id_and_data)
+ if self.dispatcher.persistence.store_user_data and self._user_id_and_data is not None:
+ self.dispatcher.persistence.refresh_user_data(*self._user_id_and_data)
+
@classmethod
def from_error(
- cls,
+ cls: Type[CC],
update: object,
error: Exception,
dispatcher: 'Dispatcher',
async_args: Union[List, Tuple] = None,
async_kwargs: Dict[str, object] = None,
- ) -> 'CallbackContext':
+ ) -> CC:
"""
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error
handlers.
@@ -195,7 +232,7 @@ class CallbackContext:
return self
@classmethod
- def from_update(cls, update: object, dispatcher: 'Dispatcher') -> 'CallbackContext':
+ def from_update(cls: Type[CC], update: object, dispatcher: 'Dispatcher') -> CC:
"""
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the
handlers.
@@ -217,13 +254,19 @@ class CallbackContext:
user = update.effective_user
if chat:
- self._chat_data = dispatcher.chat_data[chat.id] # pylint: disable=W0212
+ self._chat_id_and_data = (
+ chat.id,
+ dispatcher.chat_data[chat.id], # pylint: disable=W0212
+ )
if user:
- self._user_data = dispatcher.user_data[user.id] # pylint: disable=W0212
+ self._user_id_and_data = (
+ user.id,
+ dispatcher.user_data[user.id], # pylint: disable=W0212
+ )
return self
@classmethod
- def from_job(cls, job: 'Job', dispatcher: 'Dispatcher') -> 'CallbackContext':
+ def from_job(cls: Type[CC], job: 'Job', dispatcher: 'Dispatcher') -> CC:
"""
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to a
job callback.
diff --git a/telegram/ext/callbackqueryhandler.py b/telegram/ext/callbackqueryhandler.py
index 95eb8afb1..452578049 100644
--- a/telegram/ext/callbackqueryhandler.py
+++ b/telegram/ext/callbackqueryhandler.py
@@ -35,14 +35,15 @@ from telegram import Update
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
+from .utils.types import CCT
if TYPE_CHECKING:
- from telegram.ext import CallbackContext, Dispatcher
+ from telegram.ext import Dispatcher
RT = TypeVar('RT')
-class CallbackQueryHandler(Handler[Update]):
+class CallbackQueryHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram callback queries. Optionally based on a regex.
Read the documentation of the ``re`` module for more information.
@@ -124,7 +125,7 @@ class CallbackQueryHandler(Handler[Update]):
def __init__(
self,
- callback: Callable[[Update, 'CallbackContext'], RT],
+ callback: Callable[[Update, CCT], RT],
pass_update_queue: bool = False,
pass_job_queue: bool = False,
pattern: Union[str, Pattern] = None,
@@ -191,7 +192,7 @@ class CallbackQueryHandler(Handler[Update]):
def collect_additional_context(
self,
- context: 'CallbackContext',
+ context: CCT,
update: Update,
dispatcher: 'Dispatcher',
check_result: Union[bool, Match],
diff --git a/telegram/ext/chatmemberhandler.py b/telegram/ext/chatmemberhandler.py
index c098db904..9499cfd24 100644
--- a/telegram/ext/chatmemberhandler.py
+++ b/telegram/ext/chatmemberhandler.py
@@ -17,19 +17,17 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the ChatMemberHandler classes."""
-from typing import ClassVar, TypeVar, Union, Callable, TYPE_CHECKING
+from typing import ClassVar, TypeVar, Union, Callable
from telegram import Update
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
-
-if TYPE_CHECKING:
- from telegram.ext import CallbackContext
+from .utils.types import CCT
RT = TypeVar('RT')
-class ChatMemberHandler(Handler[Update]):
+class ChatMemberHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram updates that contain a chat member update.
.. versionadded:: 13.4
@@ -107,7 +105,7 @@ class ChatMemberHandler(Handler[Update]):
def __init__(
self,
- callback: Callable[[Update, 'CallbackContext'], RT],
+ callback: Callable[[Update, CCT], RT],
chat_member_types: int = MY_CHAT_MEMBER,
pass_update_queue: bool = False,
pass_job_queue: bool = False,
diff --git a/telegram/ext/choseninlineresulthandler.py b/telegram/ext/choseninlineresulthandler.py
index 2ae2814f6..ec3528945 100644
--- a/telegram/ext/choseninlineresulthandler.py
+++ b/telegram/ext/choseninlineresulthandler.py
@@ -24,6 +24,7 @@ from telegram import Update
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
+from .utils.types import CCT
RT = TypeVar('RT')
@@ -31,7 +32,7 @@ if TYPE_CHECKING:
from telegram.ext import CallbackContext, Dispatcher
-class ChosenInlineResultHandler(Handler[Update]):
+class ChosenInlineResultHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram updates that contain a chosen inline result.
Note:
diff --git a/telegram/ext/commandhandler.py b/telegram/ext/commandhandler.py
index 6d6a1b6e6..1f0a32118 100644
--- a/telegram/ext/commandhandler.py
+++ b/telegram/ext/commandhandler.py
@@ -27,15 +27,16 @@ from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.types import SLT
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
+from .utils.types import CCT
from .handler import Handler
if TYPE_CHECKING:
- from telegram.ext import CallbackContext, Dispatcher
+ from telegram.ext import Dispatcher
RT = TypeVar('RT')
-class CommandHandler(Handler[Update]):
+class CommandHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram commands.
Commands are Telegram messages that start with ``/``, optionally followed by an ``@`` and the
@@ -134,7 +135,7 @@ class CommandHandler(Handler[Update]):
def __init__(
self,
command: SLT[str],
- callback: Callable[[Update, 'CallbackContext'], RT],
+ callback: Callable[[Update, CCT], RT],
filters: BaseFilter = None,
allow_edited: bool = None,
pass_args: bool = False,
@@ -231,7 +232,7 @@ class CommandHandler(Handler[Update]):
def collect_additional_context(
self,
- context: 'CallbackContext',
+ context: CCT,
update: Update,
dispatcher: 'Dispatcher',
check_result: Optional[Union[bool, Tuple[List[str], Optional[bool]]]],
@@ -359,7 +360,7 @@ class PrefixHandler(CommandHandler):
self,
prefix: SLT[str],
command: SLT[str],
- callback: Callable[[Update, 'CallbackContext'], RT],
+ callback: Callable[[Update, CCT], RT],
filters: BaseFilter = None,
pass_args: bool = False,
pass_update_queue: bool = False,
diff --git a/telegram/ext/contexttypes.py b/telegram/ext/contexttypes.py
new file mode 100644
index 000000000..2156e7f62
--- /dev/null
+++ b/telegram/ext/contexttypes.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python
+#
+# A library that provides a Python interface to the Telegram Bot API
+# Copyright (C) 2020
+# Leandro Toledo de Souza
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser Public License for more details.
+#
+# You should have received a copy of the GNU Lesser Public License
+# along with this program. If not, see [http://www.gnu.org/licenses/].
+# pylint: disable=R0201
+"""This module contains the auxiliary class ContextTypes."""
+from typing import Type, Generic, overload, Dict # pylint: disable=W0611
+
+from telegram.ext.callbackcontext import CallbackContext
+from telegram.ext.utils.types import CCT, UD, CD, BD
+
+
+class ContextTypes(Generic[CCT, UD, CD, BD]):
+ """
+ Convenience class to gather customizable types of the :class:`telegram.ext.CallbackContext`
+ interface.
+
+ .. versionadded:: 13.6
+
+ Args:
+ context (:obj:`type`, optional): Determines the type of the ``context`` argument of all
+ (error-)handler callbacks and job callbacks. Must be a subclass of
+ :class:`telegram.ext.CallbackContext`. Defaults to
+ :class:`telegram.ext.CallbackContext`.
+ bot_data (:obj:`type`, optional): Determines the type of ``context.bot_data`` of all
+ (error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support
+ instantiating without arguments.
+ chat_data (:obj:`type`, optional): Determines the type of ``context.chat_data`` of all
+ (error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support
+ instantiating without arguments.
+ user_data (:obj:`type`, optional): Determines the type of ``context.user_data`` of all
+ (error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support
+ instantiating without arguments.
+
+ """
+
+ __slots__ = ('_context', '_bot_data', '_chat_data', '_user_data')
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CallbackContext, Dict, Dict, Dict]',
+ ):
+ ...
+
+ @overload
+ def __init__(self: 'ContextTypes[CCT, Dict, Dict, Dict]', context: Type[CCT]):
+ ...
+
+ @overload
+ def __init__(self: 'ContextTypes[CallbackContext, UD, Dict, Dict]', bot_data: Type[UD]):
+ ...
+
+ @overload
+ def __init__(self: 'ContextTypes[CallbackContext, Dict, CD, Dict]', chat_data: Type[CD]):
+ ...
+
+ @overload
+ def __init__(self: 'ContextTypes[CallbackContext, Dict, Dict, BD]', user_data: Type[BD]):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CCT, UD, Dict, Dict]', context: Type[CCT], bot_data: Type[UD]
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CCT, Dict, CD, Dict]', context: Type[CCT], chat_data: Type[CD]
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CCT, Dict, Dict, BD]', context: Type[CCT], user_data: Type[BD]
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CallbackContext, UD, CD, Dict]',
+ bot_data: Type[UD],
+ chat_data: Type[CD],
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CallbackContext, UD, Dict, BD]',
+ bot_data: Type[UD],
+ user_data: Type[BD],
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CallbackContext, Dict, CD, BD]',
+ chat_data: Type[CD],
+ user_data: Type[BD],
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CCT, UD, CD, Dict]',
+ context: Type[CCT],
+ bot_data: Type[UD],
+ chat_data: Type[CD],
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CCT, UD, Dict, BD]',
+ context: Type[CCT],
+ bot_data: Type[UD],
+ user_data: Type[BD],
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CCT, Dict, CD, BD]',
+ context: Type[CCT],
+ chat_data: Type[CD],
+ user_data: Type[BD],
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CallbackContext, UD, CD, BD]',
+ bot_data: Type[UD],
+ chat_data: Type[CD],
+ user_data: Type[BD],
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'ContextTypes[CCT, UD, CD, BD]',
+ context: Type[CCT],
+ bot_data: Type[UD],
+ chat_data: Type[CD],
+ user_data: Type[BD],
+ ):
+ ...
+
+ def __init__( # type: ignore[no-untyped-def]
+ self,
+ context=CallbackContext,
+ bot_data=dict,
+ chat_data=dict,
+ user_data=dict,
+ ):
+ if not issubclass(context, CallbackContext):
+ raise ValueError('context must be a subclass of CallbackContext.')
+
+ # We make all those only accessible via properties because we don't currently support
+ # changing this at runtime, so overriding the attributes doesn't make sense
+ self._context = context
+ self._bot_data = bot_data
+ self._chat_data = chat_data
+ self._user_data = user_data
+
+ @property
+ def context(self) -> Type[CCT]:
+ return self._context
+
+ @property
+ def bot_data(self) -> Type[BD]:
+ return self._bot_data
+
+ @property
+ def chat_data(self) -> Type[CD]:
+ return self._chat_data
+
+ @property
+ def user_data(self) -> Type[UD]:
+ return self._user_data
diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py
index ffaaf969c..081e10f95 100644
--- a/telegram/ext/conversationhandler.py
+++ b/telegram/ext/conversationhandler.py
@@ -38,6 +38,7 @@ from telegram.ext import (
)
from telegram.ext.utils.promise import Promise
from telegram.utils.types import ConversationDict
+from telegram.ext.utils.types import CCT
if TYPE_CHECKING:
from telegram.ext import Dispatcher, Job
@@ -61,7 +62,7 @@ class _ConversationTimeoutContext:
self.callback_context = callback_context
-class ConversationHandler(Handler[Update]):
+class ConversationHandler(Handler[Update, CCT]):
"""
A handler to hold a conversation with a single or multiple users through Telegram updates by
managing four collections of other handlers.
@@ -215,9 +216,9 @@ class ConversationHandler(Handler[Update]):
# pylint: disable=W0231
def __init__(
self,
- entry_points: List[Handler],
- states: Dict[object, List[Handler]],
- fallbacks: List[Handler],
+ entry_points: List[Handler[Update, CCT]],
+ states: Dict[object, List[Handler[Update, CCT]]],
+ fallbacks: List[Handler[Update, CCT]],
allow_reentry: bool = False,
per_chat: bool = True,
per_user: bool = True,
diff --git a/telegram/ext/dictpersistence.py b/telegram/ext/dictpersistence.py
index e5df61605..ad9360442 100644
--- a/telegram/ext/dictpersistence.py
+++ b/telegram/ext/dictpersistence.py
@@ -17,7 +17,6 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the DictPersistence class."""
-from copy import deepcopy
from typing import DefaultDict, Dict, Optional, Tuple
from collections import defaultdict
@@ -37,7 +36,7 @@ except ImportError:
class DictPersistence(BasePersistence):
- """Using python's dicts and json for making your bot persistent.
+ """Using Python's :obj:`dict` and ``json`` for making your bot persistent.
Note:
This class does *not* implement a :meth:`flush` method, meaning that data managed by
@@ -202,7 +201,7 @@ class DictPersistence(BasePersistence):
pass
else:
self._user_data = defaultdict(dict)
- return deepcopy(self.user_data) # type: ignore[arg-type]
+ return self.user_data # type: ignore[return-value]
def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]:
"""Returns the chat_data created from the ``chat_data_json`` or an empty
@@ -215,7 +214,7 @@ class DictPersistence(BasePersistence):
pass
else:
self._chat_data = defaultdict(dict)
- return deepcopy(self.chat_data) # type: ignore[arg-type]
+ return self.chat_data # type: ignore[return-value]
def get_bot_data(self) -> Dict[object, object]:
"""Returns the bot_data created from the ``bot_data_json`` or an empty :obj:`dict`.
@@ -227,7 +226,7 @@ class DictPersistence(BasePersistence):
pass
else:
self._bot_data = {}
- return deepcopy(self.bot_data) # type: ignore[arg-type]
+ return self.bot_data # type: ignore[return-value]
def get_conversations(self, name: str) -> ConversationDict:
"""Returns the conversations created from the ``conversations_json`` or an empty
@@ -264,7 +263,7 @@ class DictPersistence(BasePersistence):
Args:
user_id (:obj:`int`): The user the data might have been changed for.
- data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id].
+ data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` ``[user_id]``.
"""
if self._user_data is None:
self._user_data = defaultdict(dict)
@@ -278,7 +277,7 @@ class DictPersistence(BasePersistence):
Args:
chat_id (:obj:`int`): The chat the data might have been changed for.
- data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id].
+ data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` ``[chat_id]``.
"""
if self._chat_data is None:
self._chat_data = defaultdict(dict)
@@ -295,5 +294,26 @@ class DictPersistence(BasePersistence):
"""
if self._bot_data == data:
return
- self._bot_data = data.copy()
+ self._bot_data = data
self._bot_data_json = None
+
+ def refresh_user_data(self, user_id: int, user_data: Dict) -> None:
+ """Does nothing.
+
+ .. versionadded:: 13.6
+ .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_user_data`
+ """
+
+ def refresh_chat_data(self, chat_id: int, chat_data: Dict) -> None:
+ """Does nothing.
+
+ .. versionadded:: 13.6
+ .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_chat_data`
+ """
+
+ def refresh_bot_data(self, bot_data: Dict) -> None:
+ """Does nothing.
+
+ .. versionadded:: 13.6
+ .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_bot_data`
+ """
diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py
index 4dbd2382d..db5f0958a 100644
--- a/telegram/ext/dispatcher.py
+++ b/telegram/ext/dispatcher.py
@@ -26,16 +26,30 @@ from functools import wraps
from queue import Empty, Queue
from threading import BoundedSemaphore, Event, Lock, Thread, current_thread
from time import sleep
-from typing import TYPE_CHECKING, Callable, DefaultDict, Dict, List, Optional, Union, Set
+from typing import (
+ TYPE_CHECKING,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Union,
+ Generic,
+ TypeVar,
+ overload,
+ cast,
+ DefaultDict,
+)
from uuid import uuid4
from telegram import TelegramError, Update
-from telegram.ext import BasePersistence
+from telegram.ext import BasePersistence, ContextTypes
from telegram.ext.callbackcontext import CallbackContext
from telegram.ext.handler import Handler
from telegram.utils.deprecate import TelegramDeprecationWarning, set_new_attribute_deprecated
from telegram.ext.utils.promise import Promise
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
+from telegram.ext.utils.types import CCT, UD, CD, BD
if TYPE_CHECKING:
from telegram import Bot
@@ -43,6 +57,8 @@ if TYPE_CHECKING:
DEFAULT_GROUP: int = 0
+UT = TypeVar('UT')
+
def run_async(
func: Callable[[Update, CallbackContext], object]
@@ -105,7 +121,7 @@ class DispatcherHandlerStop(Exception):
self.state = state
-class Dispatcher:
+class Dispatcher(Generic[CCT, UD, CD, BD]):
"""This class dispatches all kinds of updates to its registered handlers.
Args:
@@ -120,6 +136,12 @@ class Dispatcher:
use_context (:obj:`bool`, optional): If set to :obj:`True` uses the context based callback
API (ignored if `dispatcher` argument is used). Defaults to :obj:`True`.
**New users**: set this to :obj:`True`.
+ context_types (:class:`telegram.ext.ContextTypes`, optional): Pass an instance
+ of :class:`telegram.ext.ContextTypes` to customize the types used in the
+ ``context`` interface. If not passed, the defaults documented in
+ :class:`telegram.ext.ContextTypes` will be used.
+
+ .. versionadded:: 13.6
Attributes:
bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers.
@@ -133,6 +155,10 @@ class Dispatcher:
bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot.
persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to
store data that should be persistent over restarts.
+ context_types (:class:`telegram.ext.ContextTypes`): Container for the types used
+ in the ``context`` interface.
+
+ .. versionadded:: 13.6
"""
@@ -158,6 +184,7 @@ class Dispatcher:
'bot',
'__dict__',
'__weakref__',
+ 'context_types',
)
__singleton_lock = Lock()
@@ -165,6 +192,33 @@ class Dispatcher:
__singleton = None
logger = logging.getLogger(__name__)
+ @overload
+ def __init__(
+ self: 'Dispatcher[CallbackContext[Dict, Dict, Dict], Dict, Dict, Dict]',
+ bot: 'Bot',
+ update_queue: Queue,
+ workers: int = 4,
+ exception_event: Event = None,
+ job_queue: 'JobQueue' = None,
+ persistence: BasePersistence = None,
+ use_context: bool = True,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'Dispatcher[CCT, UD, CD, BD]',
+ bot: 'Bot',
+ update_queue: Queue,
+ workers: int = 4,
+ exception_event: Event = None,
+ job_queue: 'JobQueue' = None,
+ persistence: BasePersistence = None,
+ use_context: bool = True,
+ context_types: ContextTypes[CCT, UD, CD, BD] = None,
+ ):
+ ...
+
def __init__(
self,
bot: 'Bot',
@@ -174,12 +228,14 @@ class Dispatcher:
job_queue: 'JobQueue' = None,
persistence: BasePersistence = None,
use_context: bool = True,
+ context_types: ContextTypes[CCT, UD, CD, BD] = None,
):
self.bot = bot
self.update_queue = update_queue
self.job_queue = job_queue
self.workers = workers
self.use_context = use_context
+ self.context_types = cast(ContextTypes[CCT, UD, CD, BD], context_types or ContextTypes())
if not use_context:
warnings.warn(
@@ -193,9 +249,9 @@ class Dispatcher:
'Asynchronous callbacks can not be processed without at least one worker thread.'
)
- self.user_data: DefaultDict[int, Dict[object, object]] = defaultdict(dict)
- self.chat_data: DefaultDict[int, Dict[object, object]] = defaultdict(dict)
- self.bot_data = {}
+ self.user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data)
+ self.chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data)
+ self.bot_data = self.context_types.bot_data()
self.persistence: Optional[BasePersistence] = None
self._update_persistence_lock = Lock()
if persistence:
@@ -213,8 +269,11 @@ class Dispatcher:
raise ValueError("chat_data must be of type defaultdict")
if self.persistence.store_bot_data:
self.bot_data = self.persistence.get_bot_data()
- if not isinstance(self.bot_data, dict):
- raise ValueError("bot_data must be of type dict")
+ if not isinstance(self.bot_data, self.context_types.bot_data):
+ raise ValueError(
+ f"bot_data must be of type {self.context_types.bot_data.__name__}"
+ )
+
else:
self.persistence = None
@@ -477,7 +536,8 @@ class Dispatcher:
check = handler.check_update(update)
if check is not None and check is not False:
if not context and self.use_context:
- context = CallbackContext.from_update(update, self)
+ context = self.context_types.context.from_update(update, self)
+ context.refresh_data()
handled = True
sync_modes.append(handler.run_async)
handler.handle_update(update, self, check, context)
@@ -510,7 +570,7 @@ class Dispatcher:
if not handled_only_async:
self.update_persistence(update=update)
- def add_handler(self, handler: Handler, group: int = DEFAULT_GROUP) -> None:
+ def add_handler(self, handler: Handler[UT, CCT], group: int = DEFAULT_GROUP) -> None:
"""Register a handler.
TL;DR: Order and priority counts. 0 or 1 handlers per group will be used. End handling of
@@ -542,14 +602,22 @@ class Dispatcher:
raise TypeError(f'handler is not an instance of {Handler.__name__}')
if not isinstance(group, int):
raise TypeError('group is not int')
- if isinstance(handler, ConversationHandler) and handler.persistent and handler.name:
+ # For some reason MyPy infers the type of handler is here,
+ # so for now we just ignore all the errors
+ if (
+ isinstance(handler, ConversationHandler)
+ and handler.persistent # type: ignore[attr-defined]
+ and handler.name # type: ignore[attr-defined]
+ ):
if not self.persistence:
raise ValueError(
- f"ConversationHandler {handler.name} can not be persistent if dispatcher has "
- f"no persistence"
+ f"ConversationHandler {handler.name} " # type: ignore[attr-defined]
+ f"can not be persistent if dispatcher has no persistence"
)
- handler.persistence = self.persistence
- handler.conversations = self.persistence.get_conversations(handler.name)
+ handler.persistence = self.persistence # type: ignore[attr-defined]
+ handler.conversations = ( # type: ignore[attr-defined]
+ self.persistence.get_conversations(handler.name) # type: ignore[attr-defined]
+ )
if group not in self.handlers:
self.handlers[group] = []
@@ -643,7 +711,7 @@ class Dispatcher:
def add_error_handler(
self,
- callback: Callable[[object, CallbackContext], None],
+ callback: Callable[[object, CCT], None],
run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, # pylint: disable=W0621
) -> None:
"""Registers an error handler in the Dispatcher. This handler will receive every error
@@ -678,7 +746,7 @@ class Dispatcher:
self.error_handlers[callback] = run_async
- def remove_error_handler(self, callback: Callable[[object, CallbackContext], None]) -> None:
+ def remove_error_handler(self, callback: Callable[[object, CCT], None]) -> None:
"""Removes an error handler.
Args:
@@ -705,7 +773,7 @@ class Dispatcher:
if self.error_handlers:
for callback, run_async in self.error_handlers.items(): # pylint: disable=W0621
if self.use_context:
- context = CallbackContext.from_error(
+ context = self.context_types.context.from_error(
update, error, self, async_args=async_args, async_kwargs=async_kwargs
)
if run_async:
diff --git a/telegram/ext/handler.py b/telegram/ext/handler.py
index 17a86d166..befaf4139 100644
--- a/telegram/ext/handler.py
+++ b/telegram/ext/handler.py
@@ -26,15 +26,16 @@ from telegram.utils.deprecate import set_new_attribute_deprecated
from telegram import Update
from telegram.ext.utils.promise import Promise
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
+from telegram.ext.utils.types import CCT
if TYPE_CHECKING:
- from telegram.ext import CallbackContext, Dispatcher
+ from telegram.ext import Dispatcher
RT = TypeVar('RT')
UT = TypeVar('UT')
-class Handler(Generic[UT], ABC):
+class Handler(Generic[UT, CCT], ABC):
"""The base class for all update handlers. Create custom handlers by inheriting from it.
Note:
@@ -115,7 +116,7 @@ class Handler(Generic[UT], ABC):
def __init__(
self,
- callback: Callable[[UT, 'CallbackContext'], RT],
+ callback: Callable[[UT, CCT], RT],
pass_update_queue: bool = False,
pass_job_queue: bool = False,
pass_user_data: bool = False,
@@ -165,7 +166,7 @@ class Handler(Generic[UT], ABC):
update: UT,
dispatcher: 'Dispatcher',
check_result: object,
- context: 'CallbackContext' = None,
+ context: CCT = None,
) -> Union[RT, Promise]:
"""
This method is called if it was determined that an update should indeed
@@ -205,7 +206,7 @@ class Handler(Generic[UT], ABC):
def collect_additional_context(
self,
- context: 'CallbackContext',
+ context: CCT,
update: UT,
dispatcher: 'Dispatcher',
check_result: Any,
diff --git a/telegram/ext/inlinequeryhandler.py b/telegram/ext/inlinequeryhandler.py
index 3216345c3..11103e71f 100644
--- a/telegram/ext/inlinequeryhandler.py
+++ b/telegram/ext/inlinequeryhandler.py
@@ -35,14 +35,15 @@ from telegram import Update
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
+from .utils.types import CCT
if TYPE_CHECKING:
- from telegram.ext import CallbackContext, Dispatcher
+ from telegram.ext import Dispatcher
RT = TypeVar('RT')
-class InlineQueryHandler(Handler[Update]):
+class InlineQueryHandler(Handler[Update, CCT]):
"""
Handler class to handle Telegram inline queries. Optionally based on a regex. Read the
documentation of the ``re`` module for more information.
@@ -133,7 +134,7 @@ class InlineQueryHandler(Handler[Update]):
def __init__(
self,
- callback: Callable[[Update, 'CallbackContext'], RT],
+ callback: Callable[[Update, CCT], RT],
pass_update_queue: bool = False,
pass_job_queue: bool = False,
pattern: Union[str, Pattern] = None,
@@ -207,7 +208,7 @@ class InlineQueryHandler(Handler[Update]):
def collect_additional_context(
self,
- context: 'CallbackContext',
+ context: CCT,
update: Update,
dispatcher: 'Dispatcher',
check_result: Optional[Union[bool, Match]],
diff --git a/telegram/ext/jobqueue.py b/telegram/ext/jobqueue.py
index 2ea2fec5c..837cac561 100644
--- a/telegram/ext/jobqueue.py
+++ b/telegram/ext/jobqueue.py
@@ -72,7 +72,7 @@ class JobQueue:
def _build_args(self, job: 'Job') -> List[Union[CallbackContext, 'Bot', 'Job']]:
if self._dispatcher.use_context:
- return [CallbackContext.from_job(job, self._dispatcher)]
+ return [self._dispatcher.context_types.context.from_job(job, self._dispatcher)]
return [self._dispatcher.bot, job]
def _tz_now(self) -> datetime.datetime:
@@ -585,7 +585,7 @@ class Job:
"""Executes the callback function independently of the jobs schedule."""
try:
if dispatcher.use_context:
- self.callback(CallbackContext.from_job(self, dispatcher))
+ self.callback(dispatcher.context_types.context.from_job(self, dispatcher))
else:
self.callback(dispatcher.bot, self) # type: ignore[arg-type,call-arg]
except Exception as exc:
diff --git a/telegram/ext/messagehandler.py b/telegram/ext/messagehandler.py
index ed1078124..c3f0c015c 100644
--- a/telegram/ext/messagehandler.py
+++ b/telegram/ext/messagehandler.py
@@ -27,14 +27,15 @@ from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
+from .utils.types import CCT
if TYPE_CHECKING:
- from telegram.ext import CallbackContext, Dispatcher
+ from telegram.ext import Dispatcher
RT = TypeVar('RT')
-class MessageHandler(Handler[Update]):
+class MessageHandler(Handler[Update, CCT]):
"""Handler class to handle telegram messages. They might contain text, media or status updates.
Note:
@@ -125,7 +126,7 @@ class MessageHandler(Handler[Update]):
def __init__(
self,
filters: BaseFilter,
- callback: Callable[[Update, 'CallbackContext'], RT],
+ callback: Callable[[Update, CCT], RT],
pass_update_queue: bool = False,
pass_job_queue: bool = False,
pass_user_data: bool = False,
@@ -197,7 +198,7 @@ class MessageHandler(Handler[Update]):
def collect_additional_context(
self,
- context: 'CallbackContext',
+ context: CCT,
update: Update,
dispatcher: 'Dispatcher',
check_result: Optional[Union[bool, Dict[str, object]]],
diff --git a/telegram/ext/picklepersistence.py b/telegram/ext/picklepersistence.py
index 3127d3baf..d015924b7 100644
--- a/telegram/ext/picklepersistence.py
+++ b/telegram/ext/picklepersistence.py
@@ -19,14 +19,23 @@
"""This module contains the PicklePersistence class."""
import pickle
from collections import defaultdict
-from copy import deepcopy
-from typing import Any, DefaultDict, Dict, Optional, Tuple
+from typing import (
+ Any,
+ Dict,
+ Optional,
+ Tuple,
+ overload,
+ cast,
+ DefaultDict,
+)
from telegram.ext import BasePersistence
-from telegram.utils.types import ConversationDict
+from telegram.utils.types import ConversationDict # pylint: disable=W0611
+from .utils.types import UD, CD, BD
+from .contexttypes import ContextTypes
-class PicklePersistence(BasePersistence):
+class PicklePersistence(BasePersistence[UD, CD, BD]):
"""Using python's builtin pickle for making you bot persistent.
Warning:
@@ -54,6 +63,12 @@ class PicklePersistence(BasePersistence):
:meth:`flush` is called and keep data in memory until that happens. When
:obj:`False` will store data on any transaction *and* on call to :meth:`flush`.
Default is :obj:`False`.
+ context_types (:class:`telegram.ext.ContextTypes`, optional): Pass an instance
+ of :class:`telegram.ext.ContextTypes` to customize the types used in the
+ ``context`` interface. If not passed, the defaults documented in
+ :class:`telegram.ext.ContextTypes` will be used.
+
+ .. versionadded:: 13.6
Attributes:
filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file`
@@ -71,6 +86,10 @@ class PicklePersistence(BasePersistence):
:meth:`flush` is called and keep data in memory until that happens. When
:obj:`False` will store data on any transaction *and* on call to :meth:`flush`.
Default is :obj:`False`.
+ context_types (:class:`telegram.ext.ContextTypes`): Container for the types used
+ in the ``context`` interface.
+
+ .. versionadded:: 13.6
"""
__slots__ = (
@@ -81,8 +100,34 @@ class PicklePersistence(BasePersistence):
'chat_data',
'bot_data',
'conversations',
+ 'context_types',
)
+ @overload
+ def __init__(
+ self: 'PicklePersistence[Dict, Dict, Dict]',
+ filename: str,
+ store_user_data: bool = True,
+ store_chat_data: bool = True,
+ store_bot_data: bool = True,
+ single_file: bool = True,
+ on_flush: bool = False,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'PicklePersistence[UD, CD, BD]',
+ filename: str,
+ store_user_data: bool = True,
+ store_chat_data: bool = True,
+ store_bot_data: bool = True,
+ single_file: bool = True,
+ on_flush: bool = False,
+ context_types: ContextTypes[Any, UD, CD, BD] = None,
+ ):
+ ...
+
def __init__(
self,
filename: str,
@@ -91,6 +136,7 @@ class PicklePersistence(BasePersistence):
store_bot_data: bool = True,
single_file: bool = True,
on_flush: bool = False,
+ context_types: ContextTypes[Any, UD, CD, BD] = None,
):
super().__init__(
store_user_data=store_user_data,
@@ -100,26 +146,27 @@ class PicklePersistence(BasePersistence):
self.filename = filename
self.single_file = single_file
self.on_flush = on_flush
- self.user_data: Optional[DefaultDict[int, Dict]] = None
- self.chat_data: Optional[DefaultDict[int, Dict]] = None
- self.bot_data: Optional[Dict] = None
+ self.user_data: Optional[DefaultDict[int, UD]] = None
+ self.chat_data: Optional[DefaultDict[int, CD]] = None
+ self.bot_data: Optional[BD] = None
self.conversations: Optional[Dict[str, Dict[Tuple, object]]] = None
+ self.context_types = cast(ContextTypes[Any, UD, CD, BD], context_types or ContextTypes())
def _load_singlefile(self) -> None:
try:
filename = self.filename
with open(self.filename, "rb") as file:
data = pickle.load(file)
- self.user_data = defaultdict(dict, data['user_data'])
- self.chat_data = defaultdict(dict, data['chat_data'])
+ self.user_data = defaultdict(self.context_types.user_data, data['user_data'])
+ self.chat_data = defaultdict(self.context_types.chat_data, data['chat_data'])
# For backwards compatibility with files not containing bot data
- self.bot_data = data.get('bot_data', {})
+ self.bot_data = data.get('bot_data', self.context_types.bot_data())
self.conversations = data['conversations']
except OSError:
self.conversations = {}
- self.user_data = defaultdict(dict)
- self.chat_data = defaultdict(dict)
- self.bot_data = {}
+ self.user_data = defaultdict(self.context_types.user_data)
+ self.chat_data = defaultdict(self.context_types.chat_data)
+ self.bot_data = self.context_types.bot_data()
except pickle.UnpicklingError as exc:
raise TypeError(f"File {filename} does not contain valid pickle data") from exc
except Exception as exc:
@@ -152,11 +199,11 @@ class PicklePersistence(BasePersistence):
with open(filename, "wb") as file:
pickle.dump(data, file)
- def get_user_data(self) -> DefaultDict[int, Dict[object, object]]:
+ def get_user_data(self) -> DefaultDict[int, UD]:
"""Returns the user_data from the pickle file if it exists or an empty :obj:`defaultdict`.
Returns:
- :obj:`defaultdict`: The restored user data.
+ DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.UD`]: The restored user data.
"""
if self.user_data:
pass
@@ -164,19 +211,19 @@ class PicklePersistence(BasePersistence):
filename = f"{self.filename}_user_data"
data = self._load_file(filename)
if not data:
- data = defaultdict(dict)
+ data = defaultdict(self.context_types.user_data)
else:
- data = defaultdict(dict, data)
+ data = defaultdict(self.context_types.user_data, data)
self.user_data = data
else:
self._load_singlefile()
- return deepcopy(self.user_data) # type: ignore[arg-type]
+ return self.user_data # type: ignore[return-value]
- def get_chat_data(self) -> DefaultDict[int, Dict[object, object]]:
+ def get_chat_data(self) -> DefaultDict[int, CD]:
"""Returns the chat_data from the pickle file if it exists or an empty :obj:`defaultdict`.
Returns:
- :obj:`defaultdict`: The restored chat data.
+ DefaultDict[:obj:`int`, :class:`telegram.ext.utils.types.CD`]: The restored chat data.
"""
if self.chat_data:
pass
@@ -184,19 +231,20 @@ class PicklePersistence(BasePersistence):
filename = f"{self.filename}_chat_data"
data = self._load_file(filename)
if not data:
- data = defaultdict(dict)
+ data = defaultdict(self.context_types.chat_data)
else:
- data = defaultdict(dict, data)
+ data = defaultdict(self.context_types.chat_data, data)
self.chat_data = data
else:
self._load_singlefile()
- return deepcopy(self.chat_data) # type: ignore[arg-type]
+ return self.chat_data # type: ignore[return-value]
- def get_bot_data(self) -> Dict[object, object]:
- """Returns the bot_data from the pickle file if it exists or an empty :obj:`dict`.
+ def get_bot_data(self) -> BD:
+ """Returns the bot_data from the pickle file if it exists or an empty object of type
+ :class:`telegram.ext.utils.types.BD`.
Returns:
- :obj:`dict`: The restored bot data.
+ :class:`telegram.ext.utils.types.BD`: The restored bot data.
"""
if self.bot_data:
pass
@@ -204,11 +252,11 @@ class PicklePersistence(BasePersistence):
filename = f"{self.filename}_bot_data"
data = self._load_file(filename)
if not data:
- data = {}
+ data = self.context_types.bot_data()
self.bot_data = data
else:
self._load_singlefile()
- return deepcopy(self.bot_data) # type: ignore[arg-type]
+ return self.bot_data # type: ignore[return-value]
def get_conversations(self, name: str) -> ConversationDict:
"""Returns the conversations from the pickle file if it exsists or an empty dict.
@@ -254,15 +302,16 @@ class PicklePersistence(BasePersistence):
else:
self._dump_singlefile()
- def update_user_data(self, user_id: int, data: Dict) -> None:
+ def update_user_data(self, user_id: int, data: UD) -> None:
"""Will update the user_data and depending on :attr:`on_flush` save the pickle file.
Args:
user_id (:obj:`int`): The user the data might have been changed for.
- data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id].
+ data (:class:`telegram.ext.utils.types.UD`): The
+ :attr:`telegram.ext.dispatcher.user_data` ``[user_id]``.
"""
if self.user_data is None:
- self.user_data = defaultdict(dict)
+ self.user_data = defaultdict(self.context_types.user_data)
if self.user_data.get(user_id) == data:
return
self.user_data[user_id] = data
@@ -273,15 +322,16 @@ class PicklePersistence(BasePersistence):
else:
self._dump_singlefile()
- def update_chat_data(self, chat_id: int, data: Dict) -> None:
+ def update_chat_data(self, chat_id: int, data: CD) -> None:
"""Will update the chat_data and depending on :attr:`on_flush` save the pickle file.
Args:
chat_id (:obj:`int`): The chat the data might have been changed for.
- data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id].
+ data (:class:`telegram.ext.utils.types.CD`): The
+ :attr:`telegram.ext.dispatcher.chat_data` ``[chat_id]``.
"""
if self.chat_data is None:
- self.chat_data = defaultdict(dict)
+ self.chat_data = defaultdict(self.context_types.chat_data)
if self.chat_data.get(chat_id) == data:
return
self.chat_data[chat_id] = data
@@ -292,15 +342,16 @@ class PicklePersistence(BasePersistence):
else:
self._dump_singlefile()
- def update_bot_data(self, data: Dict) -> None:
+ def update_bot_data(self, data: BD) -> None:
"""Will update the bot_data and depending on :attr:`on_flush` save the pickle file.
Args:
- data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.bot_data`.
+ data (:class:`telegram.ext.utils.types.BD`): The
+ :attr:`telegram.ext.dispatcher.bot_data`.
"""
if self.bot_data == data:
return
- self.bot_data = data.copy()
+ self.bot_data = data
if not self.on_flush:
if not self.single_file:
filename = f"{self.filename}_bot_data"
@@ -308,6 +359,27 @@ class PicklePersistence(BasePersistence):
else:
self._dump_singlefile()
+ def refresh_user_data(self, user_id: int, user_data: UD) -> None:
+ """Does nothing.
+
+ .. versionadded:: 13.6
+ .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_user_data`
+ """
+
+ def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None:
+ """Does nothing.
+
+ .. versionadded:: 13.6
+ .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_chat_data`
+ """
+
+ def refresh_bot_data(self, bot_data: BD) -> None:
+ """Does nothing.
+
+ .. versionadded:: 13.6
+ .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_bot_data`
+ """
+
def flush(self) -> None:
"""Will save all data in memory to pickle file(s)."""
if self.single_file:
diff --git a/telegram/ext/pollanswerhandler.py b/telegram/ext/pollanswerhandler.py
index 73cd36bce..199bcb3ad 100644
--- a/telegram/ext/pollanswerhandler.py
+++ b/telegram/ext/pollanswerhandler.py
@@ -22,9 +22,10 @@
from telegram import Update
from .handler import Handler
+from .utils.types import CCT
-class PollAnswerHandler(Handler[Update]):
+class PollAnswerHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram updates that contain a poll answer.
Note:
diff --git a/telegram/ext/pollhandler.py b/telegram/ext/pollhandler.py
index c0719d5e5..7b67e76ff 100644
--- a/telegram/ext/pollhandler.py
+++ b/telegram/ext/pollhandler.py
@@ -22,9 +22,10 @@
from telegram import Update
from .handler import Handler
+from .utils.types import CCT
-class PollHandler(Handler[Update]):
+class PollHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram updates that contain a poll.
Note:
diff --git a/telegram/ext/precheckoutqueryhandler.py b/telegram/ext/precheckoutqueryhandler.py
index 1a93cccca..3a2eee30d 100644
--- a/telegram/ext/precheckoutqueryhandler.py
+++ b/telegram/ext/precheckoutqueryhandler.py
@@ -22,9 +22,10 @@
from telegram import Update
from .handler import Handler
+from .utils.types import CCT
-class PreCheckoutQueryHandler(Handler[Update]):
+class PreCheckoutQueryHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram PreCheckout callback queries.
Note:
diff --git a/telegram/ext/regexhandler.py b/telegram/ext/regexhandler.py
index 50b82a0e8..399e4df7d 100644
--- a/telegram/ext/regexhandler.py
+++ b/telegram/ext/regexhandler.py
@@ -26,9 +26,10 @@ from telegram import Update
from telegram.ext import Filters, MessageHandler
from telegram.utils.deprecate import TelegramDeprecationWarning
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
+from telegram.ext.utils.types import CCT
if TYPE_CHECKING:
- from telegram.ext import CallbackContext, Dispatcher
+ from telegram.ext import Dispatcher
RT = TypeVar('RT')
@@ -113,7 +114,7 @@ class RegexHandler(MessageHandler):
def __init__(
self,
pattern: Union[str, Pattern],
- callback: Callable[[Update, 'CallbackContext'], RT],
+ callback: Callable[[Update, CCT], RT],
pass_groups: bool = False,
pass_groupdict: bool = False,
pass_update_queue: bool = False,
diff --git a/telegram/ext/shippingqueryhandler.py b/telegram/ext/shippingqueryhandler.py
index 53ea28494..e4229ceb7 100644
--- a/telegram/ext/shippingqueryhandler.py
+++ b/telegram/ext/shippingqueryhandler.py
@@ -21,9 +21,10 @@
from telegram import Update
from .handler import Handler
+from .utils.types import CCT
-class ShippingQueryHandler(Handler[Update]):
+class ShippingQueryHandler(Handler[Update, CCT]):
"""Handler class to handle Telegram shipping callback queries.
Note:
diff --git a/telegram/ext/stringcommandhandler.py b/telegram/ext/stringcommandhandler.py
index 98aa6518b..1d84892e4 100644
--- a/telegram/ext/stringcommandhandler.py
+++ b/telegram/ext/stringcommandhandler.py
@@ -23,14 +23,15 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, Union
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
+from .utils.types import CCT
if TYPE_CHECKING:
- from telegram.ext import CallbackContext, Dispatcher
+ from telegram.ext import Dispatcher
RT = TypeVar('RT')
-class StringCommandHandler(Handler[str]):
+class StringCommandHandler(Handler[str, CCT]):
"""Handler class to handle string commands. Commands are string updates that start with ``/``.
The handler will add a ``list`` to the
:class:`CallbackContext` named :attr:`CallbackContext.args`. It will contain a list of strings,
@@ -90,7 +91,7 @@ class StringCommandHandler(Handler[str]):
def __init__(
self,
command: str,
- callback: Callable[[str, 'CallbackContext'], RT],
+ callback: Callable[[str, CCT], RT],
pass_args: bool = False,
pass_update_queue: bool = False,
pass_job_queue: bool = False,
@@ -137,7 +138,7 @@ class StringCommandHandler(Handler[str]):
def collect_additional_context(
self,
- context: 'CallbackContext',
+ context: CCT,
update: str,
dispatcher: 'Dispatcher',
check_result: Optional[List[str]],
diff --git a/telegram/ext/stringregexhandler.py b/telegram/ext/stringregexhandler.py
index db3ee8440..282c48ad7 100644
--- a/telegram/ext/stringregexhandler.py
+++ b/telegram/ext/stringregexhandler.py
@@ -24,14 +24,15 @@ from typing import TYPE_CHECKING, Callable, Dict, Match, Optional, Pattern, Type
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
+from .utils.types import CCT
if TYPE_CHECKING:
- from telegram.ext import CallbackContext, Dispatcher
+ from telegram.ext import Dispatcher
RT = TypeVar('RT')
-class StringRegexHandler(Handler[str]):
+class StringRegexHandler(Handler[str, CCT]):
"""Handler class to handle string updates based on a regex which checks the update content.
Read the documentation of the ``re`` module for more information. The ``re.match`` function is
@@ -96,7 +97,7 @@ class StringRegexHandler(Handler[str]):
def __init__(
self,
pattern: Union[str, Pattern],
- callback: Callable[[str, 'CallbackContext'], RT],
+ callback: Callable[[str, CCT], RT],
pass_groups: bool = False,
pass_groupdict: bool = False,
pass_update_queue: bool = False,
@@ -153,7 +154,7 @@ class StringRegexHandler(Handler[str]):
def collect_additional_context(
self,
- context: 'CallbackContext',
+ context: CCT,
update: str,
dispatcher: 'Dispatcher',
check_result: Optional[Match],
diff --git a/telegram/ext/typehandler.py b/telegram/ext/typehandler.py
index a4f7d7319..531d10c30 100644
--- a/telegram/ext/typehandler.py
+++ b/telegram/ext/typehandler.py
@@ -18,19 +18,17 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the TypeHandler class."""
-from typing import TYPE_CHECKING, Callable, Type, TypeVar, Union
+from typing import Callable, Type, TypeVar, Union
from telegram.utils.helpers import DefaultValue, DEFAULT_FALSE
from .handler import Handler
-
-if TYPE_CHECKING:
- from telegram.ext import CallbackContext
+from .utils.types import CCT
RT = TypeVar('RT')
UT = TypeVar('UT')
-class TypeHandler(Handler[UT]):
+class TypeHandler(Handler[UT, CCT]):
"""Handler class to handle updates of custom types.
Warning:
@@ -80,7 +78,7 @@ class TypeHandler(Handler[UT]):
def __init__(
self,
type: Type[UT], # pylint: disable=W0622
- callback: Callable[[UT, 'CallbackContext'], RT],
+ callback: Callable[[UT, CCT], RT],
strict: bool = False,
pass_update_queue: bool = False,
pass_job_queue: bool = False,
diff --git a/telegram/ext/updater.py b/telegram/ext/updater.py
index e9975d86f..30bb0f889 100644
--- a/telegram/ext/updater.py
+++ b/telegram/ext/updater.py
@@ -25,21 +25,34 @@ from queue import Queue
from signal import SIGABRT, SIGINT, SIGTERM, signal
from threading import Event, Lock, Thread, current_thread
from time import sleep
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, no_type_check
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ no_type_check,
+ Generic,
+ overload,
+)
from telegram import Bot, TelegramError
from telegram.error import InvalidToken, RetryAfter, TimedOut, Unauthorized
-from telegram.ext import Dispatcher, JobQueue
+from telegram.ext import Dispatcher, JobQueue, ContextTypes
from telegram.utils.deprecate import TelegramDeprecationWarning, set_new_attribute_deprecated
from telegram.utils.helpers import get_signal_name
from telegram.utils.request import Request
+from telegram.ext.utils.types import CCT, UD, CD, BD
from telegram.ext.utils.webhookhandler import WebhookAppClass, WebhookServer
if TYPE_CHECKING:
- from telegram.ext import BasePersistence, Defaults
+ from telegram.ext import BasePersistence, Defaults, CallbackContext
-class Updater:
+class Updater(Generic[CCT, UD, CD, BD]):
"""
This class, which employs the :class:`telegram.ext.Dispatcher`, provides a frontend to
:class:`telegram.Bot` to the programmer, so they can focus on coding the bot. Its purpose is to
@@ -85,6 +98,12 @@ class Updater:
used).
defaults (:class:`telegram.ext.Defaults`, optional): An object containing default values to
be used if not set explicitly in the bot methods.
+ context_types (:class:`telegram.ext.ContextTypes`, optional): Pass an instance
+ of :class:`telegram.ext.ContextTypes` to customize the types used in the
+ ``context`` interface. If not passed, the defaults documented in
+ :class:`telegram.ext.ContextTypes` will be used.
+
+ .. versionadded:: 13.6
Raises:
ValueError: If both :attr:`token` and :attr:`bot` are passed or none of them.
@@ -124,7 +143,52 @@ class Updater:
'__dict__',
)
+ @overload
def __init__(
+ self: 'Updater[CallbackContext, dict, dict, dict]',
+ token: str = None,
+ base_url: str = None,
+ workers: int = 4,
+ bot: Bot = None,
+ private_key: bytes = None,
+ private_key_password: bytes = None,
+ user_sig_handler: Callable = None,
+ request_kwargs: Dict[str, Any] = None,
+ persistence: 'BasePersistence' = None, # pylint: disable=E0601
+ defaults: 'Defaults' = None,
+ use_context: bool = True,
+ base_file_url: str = None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'Updater[CCT, UD, CD, BD]',
+ token: str = None,
+ base_url: str = None,
+ workers: int = 4,
+ bot: Bot = None,
+ private_key: bytes = None,
+ private_key_password: bytes = None,
+ user_sig_handler: Callable = None,
+ request_kwargs: Dict[str, Any] = None,
+ persistence: 'BasePersistence' = None,
+ defaults: 'Defaults' = None,
+ use_context: bool = True,
+ base_file_url: str = None,
+ context_types: ContextTypes[CCT, UD, CD, BD] = None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: 'Updater[CCT, UD, CD, BD]',
+ user_sig_handler: Callable = None,
+ dispatcher: Dispatcher[CCT, UD, CD, BD] = None,
+ ):
+ ...
+
+ def __init__( # type: ignore[no-untyped-def,misc]
self,
token: str = None,
base_url: str = None,
@@ -137,8 +201,9 @@ class Updater:
persistence: 'BasePersistence' = None,
defaults: 'Defaults' = None,
use_context: bool = True,
- dispatcher: Dispatcher = None,
+ dispatcher=None,
base_file_url: str = None,
+ context_types: ContextTypes[CCT, UD, CD, BD] = None,
):
if defaults and bot:
@@ -161,10 +226,12 @@ class Updater:
raise ValueError('`dispatcher` and `bot` are mutually exclusive')
if persistence is not None:
raise ValueError('`dispatcher` and `persistence` are mutually exclusive')
- if workers is not None:
- raise ValueError('`dispatcher` and `workers` are mutually exclusive')
if use_context != dispatcher.use_context:
raise ValueError('`dispatcher` and `use_context` are mutually exclusive')
+ if context_types is not None:
+ raise ValueError('`dispatcher` and `context_types` are mutually exclusive')
+ if workers is not None:
+ raise ValueError('`dispatcher` and `workers` are mutually exclusive')
self.logger = logging.getLogger(__name__)
self._request = None
@@ -212,6 +279,7 @@ class Updater:
exception_event=self.__exception_event,
persistence=persistence,
use_context=use_context,
+ context_types=context_types,
)
self.job_queue.set_dispatcher(self.dispatcher)
else:
diff --git a/telegram/ext/utils/types.py b/telegram/ext/utils/types.py
new file mode 100644
index 000000000..fbaedd165
--- /dev/null
+++ b/telegram/ext/utils/types.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+#
+# A library that provides a Python interface to the Telegram Bot API
+# Copyright (C) 2021
+# Leandro Toledo de Souza
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser Public License for more details.
+#
+# You should have received a copy of the GNU Lesser Public License
+# along with this program. If not, see [http://www.gnu.org/licenses/].
+"""This module contains custom typing aliases.
+
+.. versionadded:: 13.6
+"""
+from typing import TypeVar, TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from telegram.ext import CallbackContext # noqa: F401
+
+CCT = TypeVar('CCT', bound='CallbackContext')
+"""An instance of :class:`telegram.ext.CallbackContext` or a custom subclass.
+
+.. versionadded:: 13.6
+"""
+UD = TypeVar('UD')
+"""Type of the user data for a single user.
+
+.. versionadded:: 13.6
+"""
+CD = TypeVar('CD')
+"""Type of the chat data for a single user.
+
+.. versionadded:: 13.6
+"""
+BD = TypeVar('BD')
+"""Type of the bot data.
+
+.. versionadded:: 13.6
+"""
diff --git a/telegram/utils/types.py b/telegram/utils/types.py
index 1ab5f4df2..1ffcb2e44 100644
--- a/telegram/utils/types.py
+++ b/telegram/utils/types.py
@@ -18,11 +18,21 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains custom typing aliases."""
from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+)
if TYPE_CHECKING:
- from telegram import InputFile
- from telegram.utils.helpers import DefaultValue
+ from telegram import InputFile # noqa: F401
+ from telegram.utils.helpers import DefaultValue # noqa: F401
FileLike = Union[IO, 'InputFile']
"""Either an open file handler or a :class:`telegram.InputFile`."""
diff --git a/tests/test_callbackcontext.py b/tests/test_callbackcontext.py
index cb9bc8034..ad4cfb387 100644
--- a/tests/test_callbackcontext.py
+++ b/tests/test_callbackcontext.py
@@ -22,6 +22,10 @@ import pytest
from telegram import Update, Message, Chat, User, TelegramError
from telegram.ext import CallbackContext
+"""
+CallbackContext.refresh_data is tested in TestBasePersistence
+"""
+
class TestCallbackContext:
def test_slot_behaviour(self, cdp, recwarn, mro_slots):
diff --git a/tests/test_contexttypes.py b/tests/test_contexttypes.py
new file mode 100644
index 000000000..20dd405f9
--- /dev/null
+++ b/tests/test_contexttypes.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python
+#
+# A library that provides a Python interface to the Telegram Bot API
+# Copyright (C) 2021
+# Leandro Toledo de Souza
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser Public License for more details.
+#
+# You should have received a copy of the GNU Lesser Public License
+# along with this program. If not, see [http://www.gnu.org/licenses/].
+import pytest
+
+from telegram.ext import ContextTypes, CallbackContext
+
+
+class SubClass(CallbackContext):
+ pass
+
+
+class TestContextTypes:
+ def test_slot_behaviour(self, mro_slots):
+ instance = ContextTypes()
+ for attr in instance.__slots__:
+ assert getattr(instance, attr, 'err') != 'err', f"got extra slot '{attr}'"
+ assert len(mro_slots(instance)) == len(set(mro_slots(instance))), "duplicate slot"
+ with pytest.raises(AttributeError):
+ instance.custom
+
+ def test_data_init(self):
+ ct = ContextTypes(SubClass, int, float, bool)
+ assert ct.context is SubClass
+ assert ct.bot_data is int
+ assert ct.chat_data is float
+ assert ct.user_data is bool
+
+ with pytest.raises(ValueError, match='subclass of CallbackContext'):
+ ContextTypes(context=bool)
+
+ def test_data_assignment(self):
+ ct = ContextTypes()
+
+ with pytest.raises(AttributeError):
+ ct.bot_data = bool
+ with pytest.raises(AttributeError):
+ ct.user_data = bool
+ with pytest.raises(AttributeError):
+ ct.chat_data = bool
diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py
index 672250999..bcadadcd5 100644
--- a/tests/test_dispatcher.py
+++ b/tests/test_dispatcher.py
@@ -32,6 +32,7 @@ from telegram.ext import (
CallbackContext,
JobQueue,
BasePersistence,
+ ContextTypes,
)
from telegram.ext.dispatcher import run_async, Dispatcher, DispatcherHandlerStop
from telegram.utils.deprecate import TelegramDeprecationWarning
@@ -45,6 +46,10 @@ def dp2(bot):
yield from create_dp(bot)
+class CustomContext(CallbackContext):
+ pass
+
+
class TestDispatcher:
message_update = Update(
1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text')
@@ -747,6 +752,15 @@ class TestDispatcher:
def update_conversation(self, name, key, new_state):
pass
+ def refresh_bot_data(self, bot_data):
+ pass
+
+ def refresh_user_data(self, user_id, user_data):
+ pass
+
+ def refresh_chat_data(self, chat_id, chat_data):
+ pass
+
def callback(update, context):
pass
@@ -807,6 +821,15 @@ class TestDispatcher:
def get_chat_data(self):
pass
+ def refresh_bot_data(self, bot_data):
+ pass
+
+ def refresh_user_data(self, user_id, user_data):
+ pass
+
+ def refresh_chat_data(self, chat_id, chat_data):
+ pass
+
def callback(update, context):
pass
@@ -923,3 +946,62 @@ class TestDispatcher:
assert self.count == expected
finally:
dp.bot.defaults = None
+
+ def test_custom_context_init(self, bot):
+ cc = ContextTypes(
+ context=CustomContext,
+ user_data=int,
+ chat_data=float,
+ bot_data=complex,
+ )
+
+ dispatcher = Dispatcher(bot, Queue(), context_types=cc)
+
+ assert isinstance(dispatcher.user_data[1], int)
+ assert isinstance(dispatcher.chat_data[1], float)
+ assert isinstance(dispatcher.bot_data, complex)
+
+ def test_custom_context_error_handler(self, bot):
+ def error_handler(_, context):
+ self.received = (
+ type(context),
+ type(context.user_data),
+ type(context.chat_data),
+ type(context.bot_data),
+ )
+
+ dispatcher = Dispatcher(
+ bot,
+ Queue(),
+ context_types=ContextTypes(
+ context=CustomContext, bot_data=int, user_data=float, chat_data=complex
+ ),
+ )
+ dispatcher.add_error_handler(error_handler)
+ dispatcher.add_handler(MessageHandler(Filters.all, self.callback_raise_error))
+
+ dispatcher.process_update(self.message_update)
+ sleep(0.1)
+ assert self.received == (CustomContext, float, complex, int)
+
+ def test_custom_context_handler_callback(self, bot):
+ def callback(_, context):
+ self.received = (
+ type(context),
+ type(context.user_data),
+ type(context.chat_data),
+ type(context.bot_data),
+ )
+
+ dispatcher = Dispatcher(
+ bot,
+ Queue(),
+ context_types=ContextTypes(
+ context=CustomContext, bot_data=int, user_data=float, chat_data=complex
+ ),
+ )
+ dispatcher.add_handler(MessageHandler(Filters.all, callback))
+
+ dispatcher.process_update(self.message_update)
+ sleep(0.1)
+ assert self.received == (CustomContext, float, complex, int)
diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py
index d0214a069..2851827dc 100644
--- a/tests/test_jobqueue.py
+++ b/tests/test_jobqueue.py
@@ -29,7 +29,11 @@ import pytest
import pytz
from apscheduler.schedulers import SchedulerNotRunningError
from flaky import flaky
-from telegram.ext import JobQueue, Updater, Job, CallbackContext
+from telegram.ext import JobQueue, Updater, Job, CallbackContext, Dispatcher, ContextTypes
+
+
+class CustomContext(CallbackContext):
+ pass
@pytest.fixture(scope='function')
@@ -519,3 +523,25 @@ class TestJobQueue:
assert len(caplog.records) == 1
rec = caplog.records[-1]
assert 'No error handlers are registered' in rec.getMessage()
+
+ def test_custom_context(self, bot, job_queue):
+ dispatcher = Dispatcher(
+ bot,
+ Queue(),
+ context_types=ContextTypes(
+ context=CustomContext, bot_data=int, user_data=float, chat_data=complex
+ ),
+ )
+ job_queue.set_dispatcher(dispatcher)
+
+ def callback(context):
+ self.result = (
+ type(context),
+ context.user_data,
+ context.chat_data,
+ type(context.bot_data),
+ )
+
+ job_queue.run_once(callback, 0.1)
+ sleep(0.15)
+ assert self.result == (CustomContext, None, None, int)
diff --git a/tests/test_persistence.py b/tests/test_persistence.py
index 10bf010cb..0abe68c37 100644
--- a/tests/test_persistence.py
+++ b/tests/test_persistence.py
@@ -46,6 +46,7 @@ from telegram.ext import (
DictPersistence,
TypeHandler,
JobQueue,
+ ContextTypes,
)
@@ -135,12 +136,16 @@ def bot_data():
@pytest.fixture(scope="function")
def chat_data():
- return defaultdict(dict, {-12345: {'test1': 'test2'}, -67890: {3: 'test4'}})
+ return defaultdict(
+ dict, {-12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, -67890: {3: 'test4'}}
+ )
@pytest.fixture(scope="function")
def user_data():
- return defaultdict(dict, {12345: {'test1': 'test2'}, 67890: {3: 'test4'}})
+ return defaultdict(
+ dict, {12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, 67890: {3: 'test4'}}
+ )
@pytest.fixture(scope='function')
@@ -172,6 +177,12 @@ def job_queue(bot):
class TestBasePersistence:
+ test_flag = False
+
+ @pytest.fixture(scope='function', autouse=True)
+ def reset(self):
+ self.test_flag = False
+
def test_slot_behaviour(self, bot_persistence, mro_slots, recwarn):
inst = bot_persistence
for attr in inst.__slots__:
@@ -254,8 +265,17 @@ class TestBasePersistence:
u.dispatcher.chat_data[442233]['test5'] = 'test6'
assert u.dispatcher.chat_data[442233]['test5'] == 'test6'
+ @pytest.mark.parametrize('run_async', [True, False], ids=['synchronous', 'run_async'])
def test_dispatcher_integration_handlers(
- self, caplog, bot, base_persistence, chat_data, user_data, bot_data
+ self,
+ cdp,
+ caplog,
+ bot,
+ base_persistence,
+ chat_data,
+ user_data,
+ bot_data,
+ run_async,
):
def get_user_data():
return user_data
@@ -269,112 +289,10 @@ class TestBasePersistence:
base_persistence.get_user_data = get_user_data
base_persistence.get_chat_data = get_chat_data
base_persistence.get_bot_data = get_bot_data
- # base_persistence.update_chat_data = lambda x: x
- # base_persistence.update_user_data = lambda x: x
- updater = Updater(bot=bot, persistence=base_persistence, use_context=True)
- dp = updater.dispatcher
+ base_persistence.refresh_bot_data = lambda x: x
+ base_persistence.refresh_chat_data = lambda x, y: x
+ base_persistence.refresh_user_data = lambda x, y: x
- def callback_known_user(update, context):
- if not context.user_data['test1'] == 'test2':
- pytest.fail('user_data corrupt')
- if not context.bot_data == bot_data:
- pytest.fail('bot_data corrupt')
-
- def callback_known_chat(update, context):
- if not context.chat_data['test3'] == 'test4':
- pytest.fail('chat_data corrupt')
- if not context.bot_data == bot_data:
- pytest.fail('bot_data corrupt')
-
- def callback_unknown_user_or_chat(update, context):
- if not context.user_data == {}:
- pytest.fail('user_data corrupt')
- if not context.chat_data == {}:
- pytest.fail('chat_data corrupt')
- if not context.bot_data == bot_data:
- pytest.fail('bot_data corrupt')
- context.user_data[1] = 'test7'
- context.chat_data[2] = 'test8'
- context.bot_data['test0'] = 'test0'
-
- known_user = MessageHandler(
- Filters.user(user_id=12345),
- callback_known_user,
- pass_chat_data=True,
- pass_user_data=True,
- )
- known_chat = MessageHandler(
- Filters.chat(chat_id=-67890),
- callback_known_chat,
- pass_chat_data=True,
- pass_user_data=True,
- )
- unknown = MessageHandler(
- Filters.all, callback_unknown_user_or_chat, pass_chat_data=True, pass_user_data=True
- )
- dp.add_handler(known_user)
- dp.add_handler(known_chat)
- dp.add_handler(unknown)
- user1 = User(id=12345, first_name='test user', is_bot=False)
- user2 = User(id=54321, first_name='test user', is_bot=False)
- chat1 = Chat(id=-67890, type='group')
- chat2 = Chat(id=-987654, type='group')
- m = Message(1, None, chat2, from_user=user1)
- u = Update(0, m)
- with caplog.at_level(logging.ERROR):
- dp.process_update(u)
- rec = caplog.records[-1]
- assert rec.getMessage() == 'No error handlers are registered, logging exception.'
- assert rec.levelname == 'ERROR'
- rec = caplog.records[-2]
- assert rec.getMessage() == 'No error handlers are registered, logging exception.'
- assert rec.levelname == 'ERROR'
- rec = caplog.records[-3]
- assert rec.getMessage() == 'No error handlers are registered, logging exception.'
- assert rec.levelname == 'ERROR'
- m.from_user = user2
- m.chat = chat1
- u = Update(1, m)
- dp.process_update(u)
- m.chat = chat2
- u = Update(2, m)
-
- def save_bot_data(data):
- if 'test0' not in data:
- pytest.fail()
-
- def save_chat_data(data):
- if -987654 not in data:
- pytest.fail()
-
- def save_user_data(data):
- if 54321 not in data:
- pytest.fail()
-
- base_persistence.update_chat_data = save_chat_data
- base_persistence.update_user_data = save_user_data
- base_persistence.update_bot_data = save_bot_data
- dp.process_update(u)
-
- assert dp.user_data[54321][1] == 'test7'
- assert dp.chat_data[-987654][2] == 'test8'
- assert dp.bot_data['test0'] == 'test0'
-
- def test_dispatcher_integration_handlers_run_async(
- self, cdp, caplog, bot, base_persistence, chat_data, user_data, bot_data
- ):
- def get_user_data():
- return user_data
-
- def get_chat_data():
- return chat_data
-
- def get_bot_data():
- return bot_data
-
- base_persistence.get_user_data = get_user_data
- base_persistence.get_chat_data = get_chat_data
- base_persistence.get_bot_data = get_bot_data
cdp.persistence = base_persistence
cdp.user_data = user_data
cdp.chat_data = chat_data
@@ -408,21 +326,21 @@ class TestBasePersistence:
callback_known_user,
pass_chat_data=True,
pass_user_data=True,
- run_async=True,
+ run_async=run_async,
)
known_chat = MessageHandler(
Filters.chat(chat_id=-67890),
callback_known_chat,
pass_chat_data=True,
pass_user_data=True,
- run_async=True,
+ run_async=run_async,
)
unknown = MessageHandler(
Filters.all,
callback_unknown_user_or_chat,
pass_chat_data=True,
pass_user_data=True,
- run_async=True,
+ run_async=run_async,
)
cdp.add_handler(known_user)
cdp.add_handler(known_chat)
@@ -437,12 +355,16 @@ class TestBasePersistence:
cdp.process_update(u)
sleep(0.1)
- rec = caplog.records[-1]
- assert rec.getMessage() == 'No error handlers are registered, logging exception.'
- assert rec.levelname == 'ERROR'
- rec = caplog.records[-2]
- assert rec.getMessage() == 'No error handlers are registered, logging exception.'
- assert rec.levelname == 'ERROR'
+
+ # In base_persistence.update_*_data we currently just raise NotImplementedError
+ # This makes sure that this doesn't break the processing and is properly handled by
+ # the error handler
+ # We override `update_*_data` further below.
+ assert len(caplog.records) == 3
+ for rec in caplog.records:
+ assert rec.getMessage() == 'No error handlers are registered, logging exception.'
+ assert rec.levelname == 'ERROR'
+
m.from_user = user2
m.chat = chat1
u = Update(1, m)
@@ -473,6 +395,105 @@ class TestBasePersistence:
assert cdp.chat_data[-987654][2] == 'test8'
assert cdp.bot_data['test0'] == 'test0'
+ @pytest.mark.parametrize(
+ 'store_user_data', [True, False], ids=['store_user_data-True', 'store_user_data-False']
+ )
+ @pytest.mark.parametrize(
+ 'store_chat_data', [True, False], ids=['store_chat_data-True', 'store_chat_data-False']
+ )
+ @pytest.mark.parametrize(
+ 'store_bot_data', [True, False], ids=['store_bot_data-True', 'store_bot_data-False']
+ )
+ @pytest.mark.parametrize('run_async', [True, False], ids=['synchronous', 'run_async'])
+ def test_persistence_dispatcher_integration_refresh_data(
+ self,
+ cdp,
+ base_persistence,
+ chat_data,
+ bot_data,
+ user_data,
+ store_bot_data,
+ store_chat_data,
+ store_user_data,
+ run_async,
+ ):
+ base_persistence.refresh_bot_data = lambda x: x.setdefault(
+ 'refreshed', x.get('refreshed', 0) + 1
+ )
+ # x is the user/chat_id
+ base_persistence.refresh_chat_data = lambda x, y: y.setdefault('refreshed', x)
+ base_persistence.refresh_user_data = lambda x, y: y.setdefault('refreshed', x)
+ base_persistence.store_bot_data = store_bot_data
+ base_persistence.store_chat_data = store_chat_data
+ base_persistence.store_user_data = store_user_data
+ cdp.persistence = base_persistence
+
+ self.test_flag = True
+
+ def callback_with_user_and_chat(update, context):
+ if store_user_data:
+ if context.user_data.get('refreshed') != update.effective_user.id:
+ self.test_flag = 'user_data was not refreshed'
+ else:
+ if 'refreshed' in context.user_data:
+ self.test_flag = 'user_data was wrongly refreshed'
+ if store_chat_data:
+ if context.chat_data.get('refreshed') != update.effective_chat.id:
+ self.test_flag = 'chat_data was not refreshed'
+ else:
+ if 'refreshed' in context.chat_data:
+ self.test_flag = 'chat_data was wrongly refreshed'
+ if store_bot_data:
+ if context.bot_data.get('refreshed') != 1:
+ self.test_flag = 'bot_data was not refreshed'
+ else:
+ if 'refreshed' in context.bot_data:
+ self.test_flag = 'bot_data was wrongly refreshed'
+
+ def callback_without_user_and_chat(_, context):
+ if store_bot_data:
+ if context.bot_data.get('refreshed') != 1:
+ self.test_flag = 'bot_data was not refreshed'
+ else:
+ if 'refreshed' in context.bot_data:
+ self.test_flag = 'bot_data was wrongly refreshed'
+
+ with_user_and_chat = MessageHandler(
+ Filters.user(user_id=12345),
+ callback_with_user_and_chat,
+ pass_chat_data=True,
+ pass_user_data=True,
+ run_async=run_async,
+ )
+ without_user_and_chat = MessageHandler(
+ Filters.all,
+ callback_without_user_and_chat,
+ pass_chat_data=True,
+ pass_user_data=True,
+ run_async=run_async,
+ )
+ cdp.add_handler(with_user_and_chat)
+ cdp.add_handler(without_user_and_chat)
+ user = User(id=12345, first_name='test user', is_bot=False)
+ chat = Chat(id=-987654, type='group')
+ m = Message(1, None, chat, from_user=user)
+
+ # has user and chat
+ u = Update(0, m)
+ cdp.process_update(u)
+
+ assert self.test_flag is True
+
+ # has neither user nor hat
+ m.from_user = None
+ m.chat = None
+ u = Update(1, m)
+ cdp.process_update(u)
+
+ assert self.test_flag is True
+
+ sleep(0.1)
+
def test_persistence_dispatcher_arbitrary_update_types(self, dp, base_persistence, caplog):
# Updates used with TypeHandler doesn't necessarily have the proper attributes for
# persistence, makes sure it works anyways
@@ -816,6 +837,10 @@ def update(bot):
return Update(0, message=message)
+class CustomMapping(defaultdict):
+ pass
+
+
class TestPicklePersistence:
def test_slot_behaviour(self, mro_slots, recwarn, pickle_persistence):
inst = pickle_persistence
@@ -986,25 +1011,34 @@ class TestPicklePersistence:
def test_updating_multi_file(self, pickle_persistence, good_pickle_files):
user_data = pickle_persistence.get_user_data()
- user_data[54321]['test9'] = 'test 10'
+ user_data[12345]['test3']['test4'] = 'test6'
assert not pickle_persistence.user_data == user_data
- pickle_persistence.update_user_data(54321, user_data[54321])
+ pickle_persistence.update_user_data(12345, user_data[12345])
+ user_data[12345]['test3']['test4'] = 'test7'
+ assert not pickle_persistence.user_data == user_data
+ pickle_persistence.update_user_data(12345, user_data[12345])
assert pickle_persistence.user_data == user_data
with open('pickletest_user_data', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f))
assert user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
- chat_data[54321]['test9'] = 'test 10'
+ chat_data[-12345]['test3']['test4'] = 'test6'
assert not pickle_persistence.chat_data == chat_data
- pickle_persistence.update_chat_data(54321, chat_data[54321])
+ pickle_persistence.update_chat_data(-12345, chat_data[-12345])
+ chat_data[-12345]['test3']['test4'] = 'test7'
+ assert not pickle_persistence.chat_data == chat_data
+ pickle_persistence.update_chat_data(-12345, chat_data[-12345])
assert pickle_persistence.chat_data == chat_data
with open('pickletest_chat_data', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f))
assert chat_data_test == chat_data
bot_data = pickle_persistence.get_bot_data()
- bot_data['test6'] = 'test 7'
+ bot_data['test3']['test4'] = 'test6'
+ assert not pickle_persistence.bot_data == bot_data
+ pickle_persistence.update_bot_data(bot_data)
+ bot_data['test3']['test4'] = 'test7'
assert not pickle_persistence.bot_data == bot_data
pickle_persistence.update_bot_data(bot_data)
assert pickle_persistence.bot_data == bot_data
@@ -1031,25 +1065,34 @@ class TestPicklePersistence:
pickle_persistence.single_file = True
user_data = pickle_persistence.get_user_data()
- user_data[54321]['test9'] = 'test 10'
+ user_data[12345]['test3']['test4'] = 'test6'
assert not pickle_persistence.user_data == user_data
- pickle_persistence.update_user_data(54321, user_data[54321])
+ pickle_persistence.update_user_data(12345, user_data[12345])
+ user_data[12345]['test3']['test4'] = 'test7'
+ assert not pickle_persistence.user_data == user_data
+ pickle_persistence.update_user_data(12345, user_data[12345])
assert pickle_persistence.user_data == user_data
with open('pickletest', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)['user_data'])
assert user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
- chat_data[54321]['test9'] = 'test 10'
+ chat_data[-12345]['test3']['test4'] = 'test6'
assert not pickle_persistence.chat_data == chat_data
- pickle_persistence.update_chat_data(54321, chat_data[54321])
+ pickle_persistence.update_chat_data(-12345, chat_data[-12345])
+ chat_data[-12345]['test3']['test4'] = 'test7'
+ assert not pickle_persistence.chat_data == chat_data
+ pickle_persistence.update_chat_data(-12345, chat_data[-12345])
assert pickle_persistence.chat_data == chat_data
with open('pickletest', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)['chat_data'])
assert chat_data_test == chat_data
bot_data = pickle_persistence.get_bot_data()
- bot_data['test6'] = 'test 7'
+ bot_data['test3']['test4'] = 'test6'
+ assert not pickle_persistence.bot_data == bot_data
+ pickle_persistence.update_bot_data(bot_data)
+ bot_data['test3']['test4'] = 'test7'
assert not pickle_persistence.bot_data == bot_data
pickle_persistence.update_bot_data(bot_data)
assert pickle_persistence.bot_data == bot_data
@@ -1418,6 +1461,39 @@ class TestPicklePersistence:
user_data = pickle_persistence.get_user_data()
assert user_data[789] == {'test3': '123'}
+ @pytest.mark.parametrize('singlefile', [True, False])
+ @pytest.mark.parametrize('ud', [int, float, complex])
+ @pytest.mark.parametrize('cd', [int, float, complex])
+ @pytest.mark.parametrize('bd', [int, float, complex])
+ def test_with_context_types(self, ud, cd, bd, singlefile):
+ cc = ContextTypes(user_data=ud, chat_data=cd, bot_data=bd)
+ persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc)
+
+ assert isinstance(persistence.get_user_data()[1], ud)
+ assert persistence.get_user_data()[1] == 0
+ assert isinstance(persistence.get_chat_data()[1], cd)
+ assert persistence.get_chat_data()[1] == 0
+ assert isinstance(persistence.get_bot_data(), bd)
+ assert persistence.get_bot_data() == 0
+
+ persistence.user_data = None
+ persistence.chat_data = None
+ persistence.update_user_data(1, ud(1))
+ persistence.update_chat_data(1, cd(1))
+ persistence.update_bot_data(bd(1))
+ assert persistence.get_user_data()[1] == 1
+ assert persistence.get_chat_data()[1] == 1
+ assert persistence.get_bot_data() == 1
+
+ persistence.flush()
+ persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc)
+ assert isinstance(persistence.get_user_data()[1], ud)
+ assert persistence.get_user_data()[1] == 1
+ assert isinstance(persistence.get_chat_data()[1], cd)
+ assert persistence.get_chat_data()[1] == 1
+ assert isinstance(persistence.get_bot_data(), bd)
+ assert persistence.get_bot_data() == 1
+
@pytest.fixture(scope='function')
def user_data_json(user_data):
@@ -1560,7 +1636,7 @@ class TestDictPersistence:
assert dict_persistence.bot_data_json == bot_data_json
assert dict_persistence.conversations_json == conversations_json
- def test_json_changes(
+ def test_updating(
self,
user_data,
user_data_json,
@@ -1577,35 +1653,59 @@ class TestDictPersistence:
bot_data_json=bot_data_json,
conversations_json=conversations_json,
)
- user_data_two = user_data.copy()
- user_data_two.update({4: {5: 6}})
- dict_persistence.update_user_data(4, {5: 6})
- assert dict_persistence.user_data == user_data_two
- assert dict_persistence.user_data_json != user_data_json
- assert dict_persistence.user_data_json == json.dumps(user_data_two)
- chat_data_two = chat_data.copy()
- chat_data_two.update({7: {8: 9}})
- dict_persistence.update_chat_data(7, {8: 9})
- assert dict_persistence.chat_data == chat_data_two
- assert dict_persistence.chat_data_json != chat_data_json
- assert dict_persistence.chat_data_json == json.dumps(chat_data_two)
+ user_data = dict_persistence.get_user_data()
+ user_data[12345]['test3']['test4'] = 'test6'
+ assert not dict_persistence.user_data == user_data
+ assert not dict_persistence.user_data_json == json.dumps(user_data)
+ dict_persistence.update_user_data(12345, user_data[12345])
+ user_data[12345]['test3']['test4'] = 'test7'
+ assert not dict_persistence.user_data == user_data
+ assert not dict_persistence.user_data_json == json.dumps(user_data)
+ dict_persistence.update_user_data(12345, user_data[12345])
+ assert dict_persistence.user_data == user_data
+ assert dict_persistence.user_data_json == json.dumps(user_data)
- bot_data_two = bot_data.copy()
- bot_data_two.update({'7': {'8': '9'}})
- bot_data['7'] = {'8': '9'}
+ chat_data = dict_persistence.get_chat_data()
+ chat_data[-12345]['test3']['test4'] = 'test6'
+ assert not dict_persistence.chat_data == chat_data
+ assert not dict_persistence.chat_data_json == json.dumps(chat_data)
+ dict_persistence.update_chat_data(-12345, chat_data[-12345])
+ chat_data[-12345]['test3']['test4'] = 'test7'
+ assert not dict_persistence.chat_data == chat_data
+ assert not dict_persistence.chat_data_json == json.dumps(chat_data)
+ dict_persistence.update_chat_data(-12345, chat_data[-12345])
+ assert dict_persistence.chat_data == chat_data
+ assert dict_persistence.chat_data_json == json.dumps(chat_data)
+
+ bot_data = dict_persistence.get_bot_data()
+ bot_data['test3']['test4'] = 'test6'
+ assert not dict_persistence.bot_data == bot_data
+ assert not dict_persistence.bot_data_json == json.dumps(bot_data)
dict_persistence.update_bot_data(bot_data)
- assert dict_persistence.bot_data == bot_data_two
- assert dict_persistence.bot_data_json != bot_data_json
- assert dict_persistence.bot_data_json == json.dumps(bot_data_two)
+ bot_data['test3']['test4'] = 'test7'
+ assert not dict_persistence.bot_data == bot_data
+ assert not dict_persistence.bot_data_json == json.dumps(bot_data)
+ dict_persistence.update_bot_data(bot_data)
+ assert dict_persistence.bot_data == bot_data
+ assert dict_persistence.bot_data_json == json.dumps(bot_data)
- conversations_two = conversations.copy()
- conversations_two.update({'name4': {(1, 2): 3}})
- dict_persistence.update_conversation('name4', (1, 2), 3)
- assert dict_persistence.conversations == conversations_two
- assert dict_persistence.conversations_json != conversations_json
+ conversation1 = dict_persistence.get_conversations('name1')
+ conversation1[(123, 123)] = 5
+ assert not dict_persistence.conversations['name1'] == conversation1
+ dict_persistence.update_conversation('name1', (123, 123), 5)
+ assert dict_persistence.conversations['name1'] == conversation1
+ print(dict_persistence.conversations_json)
+ conversations['name1'][(123, 123)] = 5
+ assert dict_persistence.conversations_json == encode_conversations_to_json(conversations)
+ assert dict_persistence.get_conversations('name1') == conversation1
+
+ dict_persistence._conversations = None
+ dict_persistence.update_conversation('name1', (123, 123), 5)
+ assert dict_persistence.conversations['name1'] == {(123, 123): 5}
+ assert dict_persistence.get_conversations('name1') == {(123, 123): 5}
assert dict_persistence.conversations_json == encode_conversations_to_json(
- conversations_two
+ {"name1": {(123, 123): 5}}
)
def test_with_handler(self, bot, update):
diff --git a/tests/test_slots.py b/tests/test_slots.py
index e97a4e178..9d5169eb3 100644
--- a/tests/test_slots.py
+++ b/tests/test_slots.py
@@ -31,6 +31,7 @@ excluded = {
'Days',
'telegram.deprecate',
'TelegramDecryptionError',
+ 'ContextTypes',
} # These modules/classes intentionally don't have __dict__.
diff --git a/tests/test_updater.py b/tests/test_updater.py
index 16c3c611a..9eda467a6 100644
--- a/tests/test_updater.py
+++ b/tests/test_updater.py
@@ -613,6 +613,11 @@ class TestUpdater:
with pytest.raises(ValueError):
Updater(dispatcher=dispatcher, use_context=use_context)
+ def test_mutual_exclude_custom_context_dispatcher(self):
+ dispatcher = Dispatcher(None, None)
+ with pytest.raises(ValueError):
+ Updater(dispatcher=dispatcher, context_types=True)
+
def test_defaults_warning(self, bot):
with pytest.warns(TelegramDeprecationWarning, match='no effect when a Bot is passed'):
Updater(bot=bot, defaults=Defaults())