Persistence of Bots: Refactor Automatic Replacement and Integration with TelegramObject (#2893)

This commit is contained in:
Harshil 2022-03-12 15:27:18 +04:00 committed by Hinrich Mahler
parent 7b37f9a6fa
commit a743726b08
12 changed files with 464 additions and 690 deletions

View file

@ -22,6 +22,7 @@
import functools
import logging
import pickle
from datetime import datetime
from typing import (
@ -37,6 +38,7 @@ from typing import (
cast,
Sequence,
Any,
NoReturn,
)
try:
@ -123,10 +125,13 @@ class Bot(TelegramObject):
considered equal, if their :attr:`bot` is equal.
Note:
Most bot methods have the argument ``api_kwargs`` which allows to pass arbitrary keywords
to the Telegram API. This can be used to access new features of the API before they were
incorporated into PTB. However, this is not guaranteed to work, i.e. it will fail for
passing files.
* Most bot methods have the argument ``api_kwargs`` which allows passing arbitrary keywords
to the Telegram API. This can be used to access new features of the API before they are
incorporated into PTB. However, this is not guaranteed to work, i.e. it will fail for
passing files.
* Bots should not be serialized since if you for e.g. change the bots token, then your
serialized instance will not reflect that change. Trying to pickle a bot instance will
raise :exc:`pickle.PicklingError`.
.. versionchanged:: 14.0
@ -136,6 +141,7 @@ class Bot(TelegramObject):
* Removed the deprecated ``defaults`` parameter. If you want to use
:class:`telegram.ext.Defaults`, please use the subclass :class:`telegram.ext.ExtBot`
instead.
* Attempting to pickle a bot instance will now raise :exc:`pickle.PicklingError`.
Args:
token (:obj:`str`): Bot's unique authentication.
@ -157,7 +163,7 @@ class Bot(TelegramObject):
'private_key',
'_bot_user',
'_request',
'logger',
'_logger',
)
def __init__(
@ -176,7 +182,7 @@ class Bot(TelegramObject):
self._bot_user: Optional[User] = None
self._request = request or Request()
self.private_key = None
self.logger = logging.getLogger(__name__)
self._logger = logging.getLogger(__name__)
if private_key:
if not CRYPTO_INSTALLED:
@ -188,6 +194,10 @@ class Bot(TelegramObject):
private_key, password=private_key_password, backend=default_backend()
)
def __reduce__(self) -> NoReturn:
"""Called by pickle.dumps(). Serializing bots is unadvisable, so we forbid pickling."""
raise pickle.PicklingError('Bot objects cannot be pickled!')
# TODO: After https://youtrack.jetbrains.com/issue/PY-50952 is fixed, we can revisit this and
# consider adding Paramspec from typing_extensions to properly fix this. Currently a workaround
def _log(func: Any): # type: ignore[no-untyped-def] # skipcq: PY-D0003
@ -2999,9 +3009,9 @@ class Bot(TelegramObject):
)
if result:
self.logger.debug('Getting updates: %s', [u['update_id'] for u in result])
self._logger.debug('Getting updates: %s', [u['update_id'] for u in result])
else:
self.logger.debug('No new updates found.')
self._logger.debug('No new updates found.')
return Update.de_list(result, self) # type: ignore[return-value]

View file

@ -51,7 +51,7 @@ class _BaseMedium(TelegramObject):
"""
__slots__ = ('bot', 'file_id', 'file_size', 'file_unique_id')
__slots__ = ('file_id', 'file_size', 'file_unique_id')
def __init__(
self, file_id: str, file_unique_id: str, file_size: int = None, bot: 'Bot' = None

View file

@ -210,14 +210,14 @@ class Credentials(TelegramObject):
nonce (:obj:`str`): Bot-specified nonce
"""
__slots__ = ('bot', 'nonce', 'secure_data')
__slots__ = ('nonce', 'secure_data')
def __init__(self, secure_data: 'SecureData', nonce: str, bot: 'Bot' = None, **_kwargs: Any):
# Required
self.secure_data = secure_data
self.nonce = nonce
self.bot = bot
self.set_bot(bot)
@classmethod
def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['Credentials']:
@ -261,7 +261,6 @@ class SecureData(TelegramObject):
"""
__slots__ = (
'bot',
'utility_bill',
'personal_details',
'temporary_registration',
@ -304,7 +303,7 @@ class SecureData(TelegramObject):
self.passport = passport
self.personal_details = personal_details
self.bot = bot
self.set_bot(bot)
@classmethod
def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['SecureData']:
@ -360,7 +359,7 @@ class SecureValue(TelegramObject):
"""
__slots__ = ('data', 'front_side', 'reverse_side', 'selfie', 'files', 'translation', 'bot')
__slots__ = ('data', 'front_side', 'reverse_side', 'selfie', 'files', 'translation')
def __init__(
self,
@ -380,7 +379,7 @@ class SecureValue(TelegramObject):
self.files = files
self.translation = translation
self.bot = bot
self.set_bot(bot)
@classmethod
def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['SecureValue']:
@ -412,17 +411,17 @@ class SecureValue(TelegramObject):
class _CredentialsBase(TelegramObject):
"""Base class for DataCredentials and FileCredentials."""
__slots__ = ('hash', 'secret', 'file_hash', 'data_hash', 'bot')
__slots__ = ('hash', 'secret', 'file_hash', 'data_hash')
def __init__(self, hash: str, secret: str, bot: 'Bot' = None, **_kwargs: Any):
self.hash = hash
self.secret = secret
# Aliases just be be sure
# Aliases just to be sure
self.file_hash = self.hash
self.data_hash = self.hash
self.bot = bot
self.set_bot(bot)
class DataCredentials(_CredentialsBase):

View file

@ -17,12 +17,14 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""Base class for Telegram Objects."""
from copy import deepcopy
try:
import ujson as json
except ImportError:
import json # type: ignore[no-redef]
from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Tuple
from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Tuple, Dict, Union
from telegram._utils.types import JSONDict
from telegram._utils.warnings import warn
@ -40,6 +42,12 @@ class TelegramObject:
is equivalent to ``telegram_object.attribute_name``. If the object does not have an attribute
with the appropriate name, a :exc:`KeyError` will be raised.
When objects of this type are pickled, the :class:`~telegram.Bot` attribute associated with the
object will be removed. However, when copying the object via :func:`copy.deepcopy`, the copy
will have the *same* bot instance associated with it, i.e::
assert telegram_object.get_bot() is copy.deepcopy(telegram_object).get_bot()
.. versionchanged:: 14.0
``telegram_object['from']`` will look up the key ``from_user``. This is to account for
special cases like :attr:`Message.from_user` that deviate from the official Bot API.
@ -53,15 +61,12 @@ class TelegramObject:
_bot: Optional['Bot']
# Adding slots reduces memory usage & allows for faster attribute access.
# Only instance variables should be added to __slots__.
__slots__ = (
'_id_attrs',
'_bot',
)
__slots__ = ('_id_attrs', '_bot')
# pylint: disable=unused-argument
def __new__(cls, *args: object, **kwargs: object) -> 'TelegramObject':
# We add _id_attrs in __new__ instead of __init__ since we want to add this to the slots
# w/o calling __init__ in all of the subclasses. This is what we also do in BaseFilter.
# w/o calling __init__ in all of the subclasses.
instance = super().__new__(cls)
instance._id_attrs = ()
instance._bot = None
@ -81,6 +86,86 @@ class TelegramObject:
f"`{item}`."
) from exc
def __getstate__(self) -> Dict[str, Union[str, object]]:
"""
This method is used for pickling. We remove the bot attribute of the object since those
are not pickable.
"""
return self._get_attrs(include_private=True, recursive=False, remove_bot=True)
def __setstate__(self, state: dict) -> None:
"""
This method is used for unpickling. The data, which is in the form a dictionary, is
converted back into a class. Should be modified in place.
"""
for key, val in state.items():
setattr(self, key, val)
def __deepcopy__(self: TO, memodict: dict) -> TO:
"""This method deepcopies the object and sets the bot on the newly created copy."""
bot = self._bot # Save bot so we can set it after copying
self.set_bot(None) # set to None so it is not deepcopied
cls = self.__class__
result = cls.__new__(cls) # create a new instance
memodict[id(self)] = result # save the id of the object in the dict
attrs = self._get_attrs(include_private=True) # get all its attributes
for k in attrs: # now we set the attributes in the deepcopied object
setattr(result, k, deepcopy(getattr(self, k), memodict))
result.set_bot(bot) # Assign the bots back
self.set_bot(bot)
return result # type: ignore[return-value]
def _get_attrs(
self,
include_private: bool = False,
recursive: bool = False,
remove_bot: bool = False,
) -> Dict[str, Union[str, object]]:
"""This method is used for obtaining the attributes of the object.
Args:
include_private (:obj:`bool`): Whether the result should include private variables.
recursive (:obj:`bool`): If :obj:`True`, will convert any TelegramObjects (if found) in
the attributes to a dictionary. Else, preserves it as an object itself.
remove_bot (:obj:`bool`): Whether the bot should be included in the result.
Returns:
:obj:`dict`: A dict where the keys are attribute names and values are their values.
"""
data = {}
if not recursive:
try:
# __dict__ has attrs from superclasses, so no need to put in the for loop below
data.update(self.__dict__)
except AttributeError:
pass
# We want to get all attributes for the class, using self.__slots__ only includes the
# attributes used by that class itself, and not its superclass(es). Hence, we get its MRO
# and then get their attributes. The `[:-1]` slice excludes the `object` class
for cls in self.__class__.__mro__[:-1]:
for key in cls.__slots__:
if not include_private and key.startswith('_'):
continue
value = getattr(self, key, None)
if value is not None:
if recursive and hasattr(value, 'to_dict'):
data[key] = value.to_dict()
else:
data[key] = value
elif not recursive:
data[key] = value
if recursive and data.get('from_user'):
data['from'] = data.pop('from_user', None)
if remove_bot:
data.pop('_bot', None)
return data
@staticmethod
def _parse_data(data: Optional[JSONDict]) -> Optional[JSONDict]:
return None if data is None else data.copy()
@ -137,27 +222,7 @@ class TelegramObject:
Returns:
:obj:`dict`
"""
data = {}
# We want to get all attributes for the class, using self.__slots__ only includes the
# attributes used by that class itself, and not its superclass(es). Hence we get its MRO
# and then get their attributes. The `[:-2]` slice excludes the `object` class & the
# TelegramObject class itself.
attrs = {attr for cls in self.__class__.__mro__[:-2] for attr in cls.__slots__}
for key in attrs:
if key == 'bot' or key.startswith('_'):
continue
value = getattr(self, key, None)
if value is not None:
if hasattr(value, 'to_dict'):
data[key] = value.to_dict()
else:
data[key] = value
if data.get('from_user'):
data['from'] = data.pop('from_user', None)
return data
return self._get_attrs(recursive=True)
def get_bot(self) -> 'Bot':
"""Returns the :class:`telegram.Bot` instance associated with this object.
@ -171,8 +236,7 @@ class TelegramObject:
"""
if self._bot is None:
raise RuntimeError(
'This object has no bot associated with it. \
Shortcuts cannot be used.'
'This object has no bot associated with it. Shortcuts cannot be used.'
)
return self._bot

View file

@ -18,14 +18,11 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the BasePersistence class."""
from abc import ABC, abstractmethod
from copy import copy
from typing import Dict, Optional, Tuple, cast, ClassVar, Generic, NamedTuple
from typing import Dict, Optional, Tuple, Generic, NamedTuple
from telegram import Bot
from telegram.ext import ExtBot
from telegram.warnings import PTBRuntimeWarning
from telegram._utils.warnings import warn
from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData
@ -73,33 +70,39 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
* :meth:`get_chat_data`
* :meth:`update_chat_data`
* :meth:`refresh_chat_data`
* :meth:`drop_chat_data`
* :meth:`get_user_data`
* :meth:`update_user_data`
* :meth:`refresh_user_data`
* :meth:`drop_user_data`
* :meth:`get_callback_data`
* :meth:`update_callback_data`
* :meth:`get_conversations`
* :meth:`update_conversation`
* :meth:`flush`
If you don't actually need one of those methods, a simple ``pass`` is enough. For example, if
you don't store ``bot_data``, you don't need :meth:`get_bot_data`, :meth:`update_bot_data` or
:meth:`refresh_bot_data`.
Warning:
Persistence will try to replace :class:`telegram.Bot` instances by :attr:`REPLACED_BOT` and
insert the bot set with :meth:`set_bot` upon loading of the data. This is to ensure that
changes to the bot apply to the saved objects, too. If you change the bots token, this may
lead to e.g. ``Chat not found`` errors. For the limitations on replacing bots see
:meth:`replace_bot` and :meth:`insert_bot`.
If you don't actually need one of those methods, a simple :keyword:`pass` is enough.
For example, if you don't store ``bot_data``, you don't need :meth:`get_bot_data`,
:meth:`update_bot_data` or :meth:`refresh_bot_data`.
Note:
:meth:`replace_bot` and :meth:`insert_bot` are used *independently* of the implementation
of the ``update/get_*`` methods, i.e. you don't need to worry about it while
implementing a custom persistence subclass.
You should avoid saving :class:`telegram.Bot` instances. This is because if you change e.g.
the bots token, this won't propagate to the serialized instances and may lead to exceptions.
To prevent this, the implementation may use :attr:`bot` to replace bot instances with a
placeholder before serialization and insert :attr:`bot` back when loading the data.
Since :attr:`bot` will be set when the process starts, this will be the up-to-date bot
instance.
If the persistence implementation does not take care of this, you should make sure not to
store any bot instances in the data that will be persisted. E.g. in case of
:class:`telegram.TelegramObject`, one may call :meth:`set_bot` to ensure that shortcuts like
:meth:`telegram.Message.reply_text` are available.
.. versionchanged:: 14.0
The parameters and attributes ``store_*_data`` were replaced by :attr:`store_data`.
* The parameters and attributes ``store_*_data`` were replaced by :attr:`store_data`.
* ``insert/replace_bot`` was dropped. Serialization of bot instances now needs to be
handled by the specific implementation - see above note.
Args:
store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be
@ -109,72 +112,10 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
Attributes:
store_data (:class:`PersistenceInput`): Specifies which kinds of data will be saved by this
persistence instance.
bot (:class:`telegram.Bot`): The bot associated with the persistence.
"""
__slots__ = (
'bot',
'store_data',
'__dict__', # __dict__ is included because we replace methods in the __new__
)
def __new__(
cls, *args: object, **kwargs: object # pylint: disable=unused-argument
) -> 'BasePersistence':
"""This overrides the get_* and update_* methods to use insert/replace_bot.
That has the side effect that we always pass deepcopied data to those methods, so in
Pickle/DictPersistence we don't have to worry about copying the data again.
Note: This doesn't hold for second tuple-entry of callback_data. That's a Dict[str, str],
so no bots to replace anyway.
"""
instance = super().__new__(cls)
get_user_data = instance.get_user_data
get_chat_data = instance.get_chat_data
get_bot_data = instance.get_bot_data
get_callback_data = instance.get_callback_data
update_user_data = instance.update_user_data
update_chat_data = instance.update_chat_data
update_bot_data = instance.update_bot_data
update_callback_data = instance.update_callback_data
def get_user_data_insert_bot() -> Dict[int, UD]:
return instance.insert_bot(get_user_data())
def get_chat_data_insert_bot() -> Dict[int, CD]:
return instance.insert_bot(get_chat_data())
def get_bot_data_insert_bot() -> BD:
return instance.insert_bot(get_bot_data())
def get_callback_data_insert_bot() -> Optional[CDCData]:
cdc_data = get_callback_data()
if cdc_data is None:
return None
return instance.insert_bot(cdc_data[0]), cdc_data[1]
def update_user_data_replace_bot(user_id: int, data: UD) -> None:
return update_user_data(user_id, instance.replace_bot(data))
def update_chat_data_replace_bot(chat_id: int, data: CD) -> None:
return update_chat_data(chat_id, instance.replace_bot(data))
def update_bot_data_replace_bot(data: BD) -> None:
return update_bot_data(instance.replace_bot(data))
def update_callback_data_replace_bot(data: CDCData) -> None:
obj_data, queue = data
return update_callback_data((instance.replace_bot(obj_data), queue))
# Adds to __dict__
setattr(instance, 'get_user_data', get_user_data_insert_bot)
setattr(instance, 'get_chat_data', get_chat_data_insert_bot)
setattr(instance, 'get_bot_data', get_bot_data_insert_bot)
setattr(instance, 'get_callback_data', get_callback_data_insert_bot)
setattr(instance, 'update_user_data', update_user_data_replace_bot)
setattr(instance, 'update_chat_data', update_chat_data_replace_bot)
setattr(instance, 'update_bot_data', update_bot_data_replace_bot)
setattr(instance, 'update_callback_data', update_callback_data_replace_bot)
return instance
__slots__ = ('bot', 'store_data')
def __init__(self, store_data: PersistenceInput = None):
self.store_data = store_data or PersistenceInput()
@ -192,215 +133,6 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
self.bot = bot
@classmethod
def replace_bot(cls, obj: object) -> object:
"""
Replaces all instances of :class:`telegram.Bot` that occur within the passed object with
:attr:`REPLACED_BOT`. Currently, this handles objects of type :class:`list`,
:class:`tuple`, :class:`set`, :class:`frozenset`, :class:`dict`,
:class:`collections.defaultdict` and objects that have a :attr:`~object.__dict__` or
:data:`~object.__slots__` attribute, excluding classes and objects that can't be copied
with :func:`copy.copy`. If the parsing of an object fails, the object will be returned
unchanged and the error will be logged.
Args:
obj (:obj:`object`): The object
Returns:
:class:`object`: Copy of the object with Bot instances replaced.
"""
return cls._replace_bot(obj, {})
@classmethod
def _replace_bot( # pylint: disable=too-many-return-statements
cls, obj: object, memo: Dict[int, object]
) -> object:
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if isinstance(obj, Bot):
memo[obj_id] = cls.REPLACED_BOT
return cls.REPLACED_BOT
if isinstance(obj, (list, set)):
# We copy the iterable here for thread safety, i.e. make sure the object we iterate
# over doesn't change its length during the iteration
temp_iterable = obj.copy()
new_iterable = obj.__class__(cls._replace_bot(item, memo) for item in temp_iterable)
memo[obj_id] = new_iterable
return new_iterable
if isinstance(obj, (tuple, frozenset)):
# tuples and frozensets are immutable so we don't need to worry about thread safety
new_immutable = obj.__class__(cls._replace_bot(item, memo) for item in obj)
memo[obj_id] = new_immutable
return new_immutable
if isinstance(obj, type):
# classes usually do have a __dict__, but it's not writable
warn(
f'BasePersistence.replace_bot does not handle classes such as {obj.__name__!r}. '
'See the docs of BasePersistence.replace_bot for more information.',
PTBRuntimeWarning,
)
return obj
try:
new_obj = copy(obj)
memo[obj_id] = new_obj
except Exception:
warn(
'BasePersistence.replace_bot does not handle objects that can not be copied. See '
'the docs of BasePersistence.replace_bot for more information.',
PTBRuntimeWarning,
)
memo[obj_id] = obj
return obj
if isinstance(obj, dict):
# We handle dicts via copy(obj) so we don't have to make a
# difference between dict and defaultdict
new_obj = cast(dict, new_obj)
# We can't iterate over obj.items() due to thread safety, i.e. the dicts length may
# change during the iteration
temp_dict = new_obj.copy()
new_obj.clear()
for k, val in temp_dict.items():
new_obj[cls._replace_bot(k, memo)] = cls._replace_bot(val, memo)
memo[obj_id] = new_obj
return new_obj
try:
if hasattr(obj, '__slots__'):
for attr_name in new_obj.__slots__:
setattr(
new_obj,
attr_name,
cls._replace_bot(
cls._replace_bot(getattr(new_obj, attr_name), memo), memo
),
)
if '__dict__' in obj.__slots__:
# In this case, we have already covered the case that obj has __dict__
# Note that obj may have a __dict__ even if it's not in __slots__!
memo[obj_id] = new_obj
return new_obj
if hasattr(obj, '__dict__'):
for attr_name, attr in new_obj.__dict__.items():
setattr(new_obj, attr_name, cls._replace_bot(attr, memo))
memo[obj_id] = new_obj
return new_obj
except Exception as exception:
warn(
f'Parsing of an object failed with the following exception: {exception}. '
f'See the docs of BasePersistence.replace_bot for more information.',
PTBRuntimeWarning,
)
memo[obj_id] = obj
return obj
def insert_bot(self, obj: object) -> object:
"""
Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with
:paramref:`bot`. Currently, this handles objects of type :class:`list`,
:class:`tuple`, :class:`set`, :class:`frozenset`, :class:`dict`,
:class:`collections.defaultdict` and objects that have a :attr:`~object.__dict__` or
:data:`~object.__slots__` attribute, excluding classes and objects that can't be copied
with :func:`copy.copy`. If the parsing of an object fails, the object will be returned
unchanged and the error will be logged.
Args:
obj (:obj:`object`): The object
Returns:
:class:`object`: Copy of the object with Bot instances inserted.
"""
return self._insert_bot(obj, {})
# pylint: disable=too-many-return-statements
def _insert_bot(self, obj: object, memo: Dict[int, object]) -> object:
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if isinstance(obj, Bot):
memo[obj_id] = self.bot
return self.bot
if isinstance(obj, str) and obj == self.REPLACED_BOT:
memo[obj_id] = self.bot
return self.bot
if isinstance(obj, (list, set)):
# We copy the iterable here for thread safety, i.e. make sure the object we iterate
# over doesn't change its length during the iteration
temp_iterable = obj.copy()
new_iterable = obj.__class__(self._insert_bot(item, memo) for item in temp_iterable)
memo[obj_id] = new_iterable
return new_iterable
if isinstance(obj, (tuple, frozenset)):
# tuples and frozensets are immutable so we don't need to worry about thread safety
new_immutable = obj.__class__(self._insert_bot(item, memo) for item in obj)
memo[obj_id] = new_immutable
return new_immutable
if isinstance(obj, type):
# classes usually do have a __dict__, but it's not writable
warn(
f'BasePersistence.insert_bot does not handle classes such as {obj.__name__!r}. '
'See the docs of BasePersistence.insert_bot for more information.',
PTBRuntimeWarning,
)
return obj
try:
new_obj = copy(obj)
except Exception:
warn(
'BasePersistence.insert_bot does not handle objects that can not be copied. See '
'the docs of BasePersistence.insert_bot for more information.',
PTBRuntimeWarning,
)
memo[obj_id] = obj
return obj
if isinstance(obj, dict):
# We handle dicts via copy(obj) so we don't have to make a
# difference between dict and defaultdict
new_obj = cast(dict, new_obj)
# We can't iterate over obj.items() due to thread safety, i.e. the dicts length may
# change during the iteration
temp_dict = new_obj.copy()
new_obj.clear()
for k, val in temp_dict.items():
new_obj[self._insert_bot(k, memo)] = self._insert_bot(val, memo)
memo[obj_id] = new_obj
return new_obj
try:
if hasattr(obj, '__slots__'):
for attr_name in obj.__slots__:
setattr(
new_obj,
attr_name,
self._insert_bot(
self._insert_bot(getattr(new_obj, attr_name), memo), memo
),
)
if '__dict__' in obj.__slots__:
# In this case, we have already covered the case that obj has __dict__
# Note that obj may have a __dict__ even if it's not in __slots__!
memo[obj_id] = new_obj
return new_obj
if hasattr(obj, '__dict__'):
for attr_name, attr in new_obj.__dict__.items():
setattr(new_obj, attr_name, self._insert_bot(attr, memo))
memo[obj_id] = new_obj
return new_obj
except Exception as exception:
warn(
f'Parsing of an object failed with the following exception: {exception}. '
f'See the docs of BasePersistence.insert_bot for more information.',
PTBRuntimeWarning,
)
memo[obj_id] = obj
return obj
@abstractmethod
def get_user_data(self) -> Dict[int, UD]:
"""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
@ -628,6 +360,3 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
.. versionchanged:: 14.0
Changed this method into an ``@abstractmethod``.
"""
REPLACED_BOT: ClassVar[str] = 'bot_instance_replaced_by_ptb_persistence'
""":obj:`str`: Placeholder for :class:`telegram.Bot` instances replaced in saved data."""

View file

@ -20,6 +20,8 @@
from typing import Dict, Optional, Tuple, cast
from copy import deepcopy
from telegram.ext import BasePersistence, PersistenceInput
from telegram._utils.types import JSONDict
from telegram.ext._utils.types import ConversationDict, CDCData
@ -31,7 +33,7 @@ except ImportError:
class DictPersistence(BasePersistence):
"""Using Python's :obj:`dict` and ``json`` for making your bot persistent.
"""Using Python's :obj:`dict` and :mod:`json` for making your bot persistent.
Attention:
The interface provided by this class is intended to be accessed exclusively by
@ -39,19 +41,13 @@ class DictPersistence(BasePersistence):
interfere with the integration of persistence into :class:`~telegram.ext.Dispatcher`.
Note:
This class does *not* implement a :meth:`flush` method, meaning that data managed by
``DictPersistence`` is in-memory only and will be lost when the bot shuts down. This is,
because ``DictPersistence`` is mainly intended as starting point for custom persistence
classes that need to JSON-serialize the stored data before writing them to file/database.
* Data managed by :class:`DictPersistence` is in-memory only and will be lost when the bot
shuts down. This is, because :class:`DictPersistence` is mainly intended as starting
point for custom persistence classes that need to JSON-serialize the stored data before
writing them to file/database.
Warning:
:class:`DictPersistence` will try to replace :class:`telegram.Bot` instances by
:attr:`~telegram.ext.BasePersistence.REPLACED_BOT` and insert the bot set with
:meth:`telegram.ext.BasePersistence.set_bot` upon loading of the data. This is to ensure
that changes to the bot apply to the saved objects, too. If you change the bots token, this
may lead to e.g. ``Chat not found`` errors. For the limitations on replacing bots see
:meth:`telegram.ext.BasePersistence.replace_bot` and
:meth:`telegram.ext.BasePersistence.insert_bot`.
* This implementation of :class:`BasePersistence` does not handle data that cannot be
serialized by :func:`json.dumps`.
.. versionchanged:: 14.0
The parameters and attributes ``store_*_data`` were replaced by :attr:`store_data`.
@ -244,7 +240,7 @@ class DictPersistence(BasePersistence):
"""
if self.user_data is None:
self._user_data = {}
return self.user_data # type: ignore[return-value]
return deepcopy(self.user_data) # type: ignore[arg-type]
def get_chat_data(self) -> Dict[int, Dict[object, object]]:
"""Returns the chat_data created from the ``chat_data_json`` or an empty :obj:`dict`.
@ -254,7 +250,7 @@ class DictPersistence(BasePersistence):
"""
if self.chat_data is None:
self._chat_data = {}
return self.chat_data # type: ignore[return-value]
return deepcopy(self.chat_data) # type: ignore[arg-type]
def get_bot_data(self) -> Dict[object, object]:
"""Returns the bot_data created from the ``bot_data_json`` or an empty :obj:`dict`.
@ -264,7 +260,7 @@ class DictPersistence(BasePersistence):
"""
if self.bot_data is None:
self._bot_data = {}
return self.bot_data # type: ignore[return-value]
return deepcopy(self.bot_data) # type: ignore[arg-type]
def get_callback_data(self) -> Optional[CDCData]:
"""Returns the callback_data created from the ``callback_data_json`` or :obj:`None`.
@ -279,7 +275,7 @@ class DictPersistence(BasePersistence):
if self.callback_data is None:
self._callback_data = None
return None
return self.callback_data[0], self.callback_data[1].copy()
return deepcopy((self.callback_data[0], self.callback_data[1].copy()))
def get_conversations(self, name: str) -> ConversationDict:
"""Returns the conversations created from the ``conversations_json`` or an empty
@ -320,7 +316,7 @@ class DictPersistence(BasePersistence):
self._user_data = {}
if self._user_data.get(user_id) == data:
return
self._user_data[user_id] = data
self._user_data[user_id] = deepcopy(data)
self._user_data_json = None
def update_chat_data(self, chat_id: int, data: Dict) -> None:
@ -334,7 +330,7 @@ class DictPersistence(BasePersistence):
self._chat_data = {}
if self._chat_data.get(chat_id) == data:
return
self._chat_data[chat_id] = data
self._chat_data[chat_id] = deepcopy(data)
self._chat_data_json = None
def update_bot_data(self, data: Dict) -> None:
@ -345,7 +341,7 @@ class DictPersistence(BasePersistence):
"""
if self._bot_data == data:
return
self._bot_data = data
self._bot_data = deepcopy(data)
self._bot_data_json = None
def update_callback_data(self, data: CDCData) -> None:
@ -360,7 +356,7 @@ class DictPersistence(BasePersistence):
"""
if self._callback_data == data:
return
self._callback_data = (data[0], data[1].copy())
self._callback_data = deepcopy((data[0], data[1].copy()))
self._callback_data_json = None
def drop_chat_data(self, chat_id: int) -> None:

View file

@ -17,8 +17,11 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the PicklePersistence class."""
import copyreg
import pickle
from copy import deepcopy
from pathlib import Path
from sys import version_info as py_ver
from typing import (
Any,
Dict,
@ -26,14 +29,105 @@ from typing import (
Tuple,
overload,
cast,
Type,
Set,
Callable,
TypeVar,
)
from telegram import Bot, TelegramObject
from telegram._utils.types import FilePathInput
from telegram._utils.warnings import warn
from telegram.ext import BasePersistence, PersistenceInput
from telegram.ext._contexttypes import ContextTypes
from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData
_REPLACED_KNOWN_BOT = "a known bot replaced by PTB's PicklePersistence"
_REPLACED_UNKNOWN_BOT = "an unknown bot replaced by PTB's PicklePersistence"
TO = TypeVar('TO', bound=TelegramObject)
def _all_subclasses(cls: Type[TO]) -> Set[Type[TO]]:
"""Gets all subclasses of the specified object, recursively. from
https://stackoverflow.com/a/3862957/9706202
"""
subclasses = cls.__subclasses__()
return set(subclasses).union([s for c in subclasses for s in _all_subclasses(c)])
def _reconstruct_to(cls: Type[TO], kwargs: dict) -> TO:
"""
This method is used for unpickling. The data, which is in the form a dictionary, is
converted back into a class. Works mostly the same as :meth:`TelegramObject.__setstate__`.
This function should be kept in place for backwards compatibility even if the pickling logic
is changed, since `_custom_reduction` places references to this function into the pickled data.
"""
obj = cls.__new__(cls)
obj.__setstate__(kwargs)
return obj # type: ignore[return-value]
def _custom_reduction(cls: TO) -> Tuple[Callable, Tuple[Type[TO], dict]]:
"""
This method is used for pickling. The bot attribute is preserved so _BotPickler().persistent_id
works as intended.
"""
data = cls._get_attrs(include_private=True) # pylint: disable=protected-access
return _reconstruct_to, (cls.__class__, data)
class _BotPickler(pickle.Pickler):
__slots__ = ('_bot',)
def __init__(self, bot: Bot, *args: Any, **kwargs: Any):
self._bot = bot
if py_ver < (3, 8): # self.reducer_override is used above this version
# Here we define a private dispatch_table, because we want to preserve the bot
# attribute of objects so persistent_id works as intended. Otherwise, the bot attribute
# is deleted in __getstate__, which is used during regular pickling (via pickle.dumps)
self.dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined]
for obj in _all_subclasses(TelegramObject):
self.dispatch_table[obj] = _custom_reduction # type: ignore[index]
super().__init__(*args, **kwargs)
def reducer_override( # pylint: disable=no-self-use
self, obj: TO
) -> Tuple[Callable, Tuple[Type[TO], dict]]:
if not isinstance(obj, TelegramObject):
return NotImplemented
return _custom_reduction(obj)
def persistent_id(self, obj: object) -> Optional[str]:
"""Used to 'mark' the Bot, so it can be replaced later. See
https://docs.python.org/3/library/pickle.html#pickle.Pickler.persistent_id for more info
"""
if obj is self._bot:
return _REPLACED_KNOWN_BOT
if isinstance(obj, Bot):
warn('Unknown bot instance found. Will be replaced by `None` during unpickling')
return _REPLACED_UNKNOWN_BOT
return None # pickles as usual
class _BotUnpickler(pickle.Unpickler):
__slots__ = ('_bot',)
def __init__(self, bot: Bot, *args: Any, **kwargs: Any):
self._bot = bot
super().__init__(*args, **kwargs)
def persistent_load(self, pid: str) -> Optional[Bot]:
"""Replaces the bot with the current bot if known, else it is replaced by :obj:`None`."""
if pid == _REPLACED_KNOWN_BOT:
return self._bot
if pid == _REPLACED_UNKNOWN_BOT:
return None
raise pickle.UnpicklingError("Found unknown persistent id when unpickling!")
class PicklePersistence(BasePersistence[UD, CD, BD]):
"""Using python's builtin pickle for making your bot persistent.
@ -42,14 +136,11 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
:class:`~telegram.ext.Dispatcher`. Calling any of the methods below manually might
interfere with the integration of persistence into :class:`~telegram.ext.Dispatcher`.
Warning:
:class:`PicklePersistence` will try to replace :class:`telegram.Bot` instances by
:attr:`~telegram.ext.BasePersistence.REPLACED_BOT` and insert the bot set with
:meth:`telegram.ext.BasePersistence.set_bot` upon loading of the data. This is to ensure
that changes to the bot apply to the saved objects, too. If you change the bots token, this
may lead to e.g. ``Chat not found`` errors. For the limitations on replacing bots see
:meth:`telegram.ext.BasePersistence.replace_bot` and
:meth:`telegram.ext.BasePersistence.insert_bot`.
Note:
This implementation of :class:`BasePersistence` uses the functionality of the pickle module
to support serialization of bot instances. Specifically any reference to
:attr:`~BasePersistence.bot` will be replaced by a placeholder before pickling and
:attr:`~BasePersistence.bot` will be inserted back when loading the data.
.. versionchanged:: 14.0
@ -57,7 +148,6 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
* The parameter and attribute ``filename`` were replaced by :attr:`filepath`.
* :attr:`filepath` now also accepts :obj:`pathlib.Path` as argument.
Args:
filepath (:obj:`str` | :obj:`pathlib.Path`): The filepath for storing the pickle files.
When :attr:`single_file` is :obj:`False` this will be used as a prefix.
@ -151,7 +241,8 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
def _load_singlefile(self) -> None:
try:
with self.filepath.open("rb") as file:
data = pickle.load(file)
data = _BotUnpickler(self.bot, file).load()
self.user_data = data['user_data']
self.chat_data = data['chat_data']
# For backwards compatibility with files not containing bot data
@ -170,11 +261,11 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
except Exception as exc:
raise TypeError(f"Something went wrong unpickling {self.filepath.name}") from exc
@staticmethod
def _load_file(filepath: Path) -> Any:
def _load_file(self, filepath: Path) -> Any:
try:
with filepath.open("rb") as file:
return pickle.load(file)
return _BotUnpickler(self.bot, file).load()
except OSError:
return None
except pickle.UnpicklingError as exc:
@ -191,12 +282,11 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
'callback_data': self.callback_data,
}
with self.filepath.open("wb") as file:
pickle.dump(data, file)
_BotPickler(self.bot, file, protocol=pickle.HIGHEST_PROTOCOL).dump(data)
@staticmethod
def _dump_file(filepath: Path, data: object) -> None:
def _dump_file(self, filepath: Path, data: object) -> None:
with filepath.open("wb") as file:
pickle.dump(data, file)
_BotPickler(self.bot, file, protocol=pickle.HIGHEST_PROTOCOL).dump(data)
def get_user_data(self) -> Dict[int, UD]:
"""Returns the user_data from the pickle file if it exists or an empty :obj:`dict`.
@ -213,7 +303,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
self.user_data = data
else:
self._load_singlefile()
return self.user_data # type: ignore[return-value]
return deepcopy(self.user_data) # type: ignore[arg-type]
def get_chat_data(self) -> Dict[int, CD]:
"""Returns the chat_data from the pickle file if it exists or an empty :obj:`dict`.
@ -230,7 +320,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
self.chat_data = data
else:
self._load_singlefile()
return self.chat_data # type: ignore[return-value]
return deepcopy(self.chat_data) # type: ignore[arg-type]
def get_bot_data(self) -> BD:
"""Returns the bot_data from the pickle file if it exists or an empty object of type
@ -248,7 +338,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
self.bot_data = data
else:
self._load_singlefile()
return self.bot_data # type: ignore[return-value]
return deepcopy(self.bot_data) # type: ignore[return-value]
def get_callback_data(self) -> Optional[CDCData]:
"""Returns the callback data from the pickle file if it exists or :obj:`None`.
@ -271,7 +361,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
self._load_singlefile()
if self.callback_data is None:
return None
return self.callback_data[0], self.callback_data[1].copy()
return deepcopy((self.callback_data[0], self.callback_data[1].copy()))
def get_conversations(self, name: str) -> ConversationDict:
"""Returns the conversations from the pickle file if it exists or an empty dict.
@ -326,7 +416,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
self.user_data = {}
if self.user_data.get(user_id) == data:
return
self.user_data[user_id] = data
self.user_data[user_id] = deepcopy(data)
if not self.on_flush:
if not self.single_file:
self._dump_file(Path(f"{self.filepath}_user_data"), self.user_data)
@ -344,7 +434,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
self.chat_data = {}
if self.chat_data.get(chat_id) == data:
return
self.chat_data[chat_id] = data
self.chat_data[chat_id] = deepcopy(data)
if not self.on_flush:
if not self.single_file:
self._dump_file(Path(f"{self.filepath}_chat_data"), self.chat_data)
@ -360,7 +450,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
"""
if self.bot_data == data:
return
self.bot_data = data
self.bot_data = deepcopy(data)
if not self.on_flush:
if not self.single_file:
self._dump_file(Path(f"{self.filepath}_bot_data"), self.bot_data)
@ -380,7 +470,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
"""
if self.callback_data == data:
return
self.callback_data = (data[0], data[1].copy())
self.callback_data = deepcopy((data[0], data[1].copy()))
if not self.on_flush:
if not self.single_file:
self._dump_file(Path(f"{self.filepath}_callback_data"), self.callback_data)

View file

@ -19,6 +19,7 @@
import datetime
import inspect
import logging
import pickle
import time
import datetime as dtm
from collections import defaultdict
@ -241,6 +242,10 @@ class TestBot:
if bot.last_name:
assert to_dict_bot["last_name"] == bot.last_name
def test_bot_pickling_error(self, bot):
with pytest.raises(pickle.PicklingError, match="Bot objects cannot be pickled"):
pickle.dumps(bot)
@pytest.mark.parametrize(
'bot_method_name',
argvalues=[

View file

@ -24,7 +24,7 @@ from uuid import uuid4
import pytest
import pytz
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery, Message, User
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery, Message, User, Chat
from telegram.ext._callbackdatacache import (
CallbackDataCache,
_KeyboardData,
@ -159,7 +159,8 @@ class TestCallbackDataCache:
if invalid:
callback_data_cache.clear_callback_data()
effective_message = Message(message_id=1, date=None, chat=None, reply_markup=out)
chat = Chat(1, 'private')
effective_message = Message(message_id=1, date=datetime.now(), chat=chat, reply_markup=out)
effective_message.reply_to_message = deepcopy(effective_message)
effective_message.pinned_message = deepcopy(effective_message)
cq_id = uuid4().hex

View file

@ -16,26 +16,25 @@
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import datetime
import logging
import os
import pickle
import gzip
import signal
import uuid
from collections.abc import Container
from collections import defaultdict
from pathlib import Path
from time import sleep
from threading import Lock
import pytest
from telegram.warnings import PTBUserWarning
try:
import ujson as json
except ImportError:
import json
from telegram import Update, Message, User, Chat, MessageEntity, Bot
from telegram import Update, Message, User, Chat, MessageEntity, Bot, TelegramObject
from telegram.ext import (
BasePersistence,
ConversationHandler,
@ -130,7 +129,7 @@ def base_persistence():
@pytest.fixture(scope="function")
def bot_persistence():
class BotPersistence(BasePersistence):
__slots__ = ()
__slots__ = ('bot_data', 'chat_data', 'user_data', 'callback_data')
def __init__(self):
super().__init__()
@ -252,7 +251,6 @@ class TestBasePersistence:
inst = bot_persistence
for attr in inst.__slots__:
assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'"
# assert not inst.__dict__, f"got missing slot(s): {inst.__dict__}"
# The below test fails if the child class doesn't define __slots__ (not a cause of concern)
assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot"
@ -590,292 +588,6 @@ class TestBasePersistence:
dp.process_update(MyUpdate())
assert 'An uncaught error was raised while processing the update' not in caplog.text
def test_bot_replace_insert_bot(self, bot, bot_persistence):
class CustomSlottedClass:
__slots__ = ('bot', '__dict__')
def __init__(self):
self.bot = bot
self.not_in_slots = bot
def __eq__(self, other):
if isinstance(other, CustomSlottedClass):
return self.bot is other.bot and self.not_in_slots is other.not_in_slots
return False
class DictNotInSlots(Container):
"""This classes parent has slots, but __dict__ is not in those slots."""
def __init__(self):
self.bot = bot
def __contains__(self, item):
return True
def __eq__(self, other):
if isinstance(other, DictNotInSlots):
return self.bot is other.bot
return False
class CustomClass:
def __init__(self):
self.bot = bot
self.slotted_object = CustomSlottedClass()
self.dict_not_in_slots_object = DictNotInSlots()
self.list_ = [1, 2, bot]
self.tuple_ = tuple(self.list_)
self.set_ = set(self.list_)
self.frozenset_ = frozenset(self.list_)
self.dict_ = {item: item for item in self.list_}
self.defaultdict_ = defaultdict(dict, self.dict_)
@staticmethod
def replace_bot():
cc = CustomClass()
cc.bot = BasePersistence.REPLACED_BOT
cc.slotted_object.bot = BasePersistence.REPLACED_BOT
cc.slotted_object.not_in_slots = BasePersistence.REPLACED_BOT
cc.dict_not_in_slots_object.bot = BasePersistence.REPLACED_BOT
cc.list_ = [1, 2, BasePersistence.REPLACED_BOT]
cc.tuple_ = tuple(cc.list_)
cc.set_ = set(cc.list_)
cc.frozenset_ = frozenset(cc.list_)
cc.dict_ = {item: item for item in cc.list_}
cc.defaultdict_ = defaultdict(dict, cc.dict_)
return cc
def __eq__(self, other):
if isinstance(other, CustomClass):
return (
self.bot is other.bot
and self.slotted_object == other.slotted_object
and self.dict_not_in_slots_object == other.dict_not_in_slots_object
and self.list_ == other.list_
and self.tuple_ == other.tuple_
and self.set_ == other.set_
and self.frozenset_ == other.frozenset_
and self.dict_ == other.dict_
and self.defaultdict_ == other.defaultdict_
)
return False
persistence = bot_persistence
persistence.set_bot(bot)
cc = CustomClass()
persistence.update_bot_data({1: cc})
assert persistence.bot_data[1].bot == BasePersistence.REPLACED_BOT
assert persistence.bot_data[1] == cc.replace_bot()
persistence.update_chat_data(123, {1: cc})
assert persistence.chat_data[123][1].bot == BasePersistence.REPLACED_BOT
assert persistence.chat_data[123][1] == cc.replace_bot()
persistence.update_user_data(123, {1: cc})
assert persistence.user_data[123][1].bot == BasePersistence.REPLACED_BOT
assert persistence.user_data[123][1] == cc.replace_bot()
persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'}))
assert persistence.callback_data[0][0][2][0].bot == BasePersistence.REPLACED_BOT
assert persistence.callback_data[0][0][2][0] == cc.replace_bot()
assert persistence.get_bot_data()[1] == cc
assert persistence.get_bot_data()[1].bot is bot
assert persistence.get_chat_data()[123][1] == cc
assert persistence.get_chat_data()[123][1].bot is bot
assert persistence.get_user_data()[123][1] == cc
assert persistence.get_user_data()[123][1].bot is bot
assert persistence.get_callback_data()[0][0][2][0].bot is bot
assert persistence.get_callback_data()[0][0][2][0] == cc
def test_bot_replace_insert_bot_unpickable_objects(self, bot, bot_persistence, recwarn):
"""Here check that unpickable objects are just returned verbatim."""
persistence = bot_persistence
persistence.set_bot(bot)
class CustomClass:
def __copy__(self):
raise TypeError('UnhandledException')
lock = Lock()
persistence.update_bot_data({1: lock})
assert persistence.bot_data[1] is lock
persistence.update_chat_data(123, {1: lock})
assert persistence.chat_data[123][1] is lock
persistence.update_user_data(123, {1: lock})
assert persistence.user_data[123][1] is lock
persistence.update_callback_data(([('1', 2, {0: lock})], {'1': '2'}))
assert persistence.callback_data[0][0][2][0] is lock
assert persistence.get_bot_data()[1] is lock
assert persistence.get_chat_data()[123][1] is lock
assert persistence.get_user_data()[123][1] is lock
assert persistence.get_callback_data()[0][0][2][0] is lock
cc = CustomClass()
persistence.update_bot_data({1: cc})
assert persistence.bot_data[1] is cc
persistence.update_chat_data(123, {1: cc})
assert persistence.chat_data[123][1] is cc
persistence.update_user_data(123, {1: cc})
assert persistence.user_data[123][1] is cc
persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'}))
assert persistence.callback_data[0][0][2][0] is cc
assert persistence.get_bot_data()[1] is cc
assert persistence.get_chat_data()[123][1] is cc
assert persistence.get_user_data()[123][1] is cc
assert persistence.get_callback_data()[0][0][2][0] is cc
assert len(recwarn) == 2
assert str(recwarn[0].message).startswith(
"BasePersistence.replace_bot does not handle objects that can not be copied."
)
assert str(recwarn[1].message).startswith(
"BasePersistence.insert_bot does not handle objects that can not be copied."
)
def test_bot_replace_insert_bot_unparsable_objects(self, bot, bot_persistence, recwarn):
"""Here check that objects in __dict__ or __slots__ that can't
be parsed are just returned verbatim."""
persistence = bot_persistence
persistence.set_bot(bot)
uuid_obj = uuid.uuid4()
persistence.update_bot_data({1: uuid_obj})
assert persistence.bot_data[1] is uuid_obj
persistence.update_chat_data(123, {1: uuid_obj})
assert persistence.chat_data[123][1] is uuid_obj
persistence.update_user_data(123, {1: uuid_obj})
assert persistence.user_data[123][1] is uuid_obj
persistence.update_callback_data(([('1', 2, {0: uuid_obj})], {'1': '2'}))
assert persistence.callback_data[0][0][2][0] is uuid_obj
assert persistence.get_bot_data()[1] is uuid_obj
assert persistence.get_chat_data()[123][1] is uuid_obj
assert persistence.get_user_data()[123][1] is uuid_obj
assert persistence.get_callback_data()[0][0][2][0] is uuid_obj
assert len(recwarn) == 2
assert str(recwarn[0].message).startswith(
"Parsing of an object failed with the following exception: "
)
assert str(recwarn[1].message).startswith(
"Parsing of an object failed with the following exception: "
)
def test_bot_replace_insert_bot_classes(self, bot, bot_persistence, recwarn):
"""Here check that classes are just returned verbatim."""
persistence = bot_persistence
persistence.set_bot(bot)
class CustomClass:
pass
persistence.update_bot_data({1: CustomClass})
assert persistence.bot_data[1] is CustomClass
persistence.update_chat_data(123, {1: CustomClass})
assert persistence.chat_data[123][1] is CustomClass
persistence.update_user_data(123, {1: CustomClass})
assert persistence.user_data[123][1] is CustomClass
assert persistence.get_bot_data()[1] is CustomClass
assert persistence.get_chat_data()[123][1] is CustomClass
assert persistence.get_user_data()[123][1] is CustomClass
assert len(recwarn) == 2
assert str(recwarn[0].message).startswith(
"BasePersistence.replace_bot does not handle classes such as 'CustomClass'"
)
assert str(recwarn[1].message).startswith(
"BasePersistence.insert_bot does not handle classes such as 'CustomClass'"
)
def test_bot_replace_insert_bot_objects_with_faulty_equality(self, bot, bot_persistence):
"""Here check that trying to compare obj == self.REPLACED_BOT doesn't lead to problems."""
persistence = bot_persistence
persistence.set_bot(bot)
class CustomClass:
def __init__(self, data):
self.data = data
def __eq__(self, other):
raise RuntimeError("Can't be compared")
cc = CustomClass({1: bot, 2: 'foo'})
expected = {1: BasePersistence.REPLACED_BOT, 2: 'foo'}
persistence.update_bot_data({1: cc})
assert persistence.bot_data[1].data == expected
persistence.update_chat_data(123, {1: cc})
assert persistence.chat_data[123][1].data == expected
persistence.update_user_data(123, {1: cc})
assert persistence.user_data[123][1].data == expected
persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'}))
assert persistence.callback_data[0][0][2][0].data == expected
expected = {1: bot, 2: 'foo'}
assert persistence.get_bot_data()[1].data == expected
assert persistence.get_chat_data()[123][1].data == expected
assert persistence.get_user_data()[123][1].data == expected
assert persistence.get_callback_data()[0][0][2][0].data == expected
@pytest.mark.filterwarnings('ignore:BasePersistence')
def test_replace_insert_bot_item_identity(self, bot, bot_persistence):
persistence = bot_persistence
persistence.set_bot(bot)
class CustomSlottedClass:
__slots__ = ('value',)
def __init__(self):
self.value = 5
class CustomClass:
pass
slot_object = CustomSlottedClass()
dict_object = CustomClass()
lock = Lock()
list_ = [slot_object, dict_object, lock]
tuple_ = (1, 2, 3)
dict_ = {1: slot_object, 2: dict_object}
data = {
'bot_1': bot,
'bot_2': bot,
'list_1': list_,
'list_2': list_,
'tuple_1': tuple_,
'tuple_2': tuple_,
'dict_1': dict_,
'dict_2': dict_,
}
def make_assertion(data_):
return (
data_['bot_1'] is data_['bot_2']
and data_['list_1'] is data_['list_2']
and data_['list_1'][0] is data_['list_2'][0]
and data_['list_1'][1] is data_['list_2'][1]
and data_['list_1'][2] is data_['list_2'][2]
and data_['tuple_1'] is data_['tuple_2']
and data_['dict_1'] is data_['dict_2']
and data_['dict_1'][1] is data_['dict_2'][1]
and data_['dict_1'][1] is data_['list_1'][0]
and data_['dict_1'][2] is data_['list_1'][1]
and data_['dict_1'][2] is data_['dict_2'][2]
)
persistence.update_bot_data(data)
assert make_assertion(persistence.bot_data)
assert make_assertion(persistence.get_bot_data())
def test_set_bot_exception(self, bot):
non_ext_bot = Bot(bot.token)
persistence = OwnPersistence()
@ -1033,11 +745,28 @@ def pickle_files_wo_callback_data(user_data, chat_data, bot_data, conversations)
def update(bot):
user = User(id=321, first_name='test_user', is_bot=False)
chat = Chat(id=123, type='group')
message = Message(1, None, chat, from_user=user, text="Hi there", bot=bot)
message = Message(1, datetime.datetime.now(), chat, from_user=user, text="Hi there", bot=bot)
return Update(0, message=message)
class TestPicklePersistence:
class DictSub(TelegramObject): # Used for testing our custom (Un)Pickler.
def __init__(self, private, normal, b):
self._private = private
self.normal = normal
self._bot = b
class SlotsSub(TelegramObject):
__slots__ = ('new_var', '_private')
def __init__(self, new_var, private):
self.new_var = new_var
self._private = private
class NormalClass:
def __init__(self, my_var):
self.my_var = my_var
def test_slot_behaviour(self, mro_slots, pickle_persistence):
inst = pickle_persistence
for attr in inst.__slots__:
@ -1046,7 +775,7 @@ class TestPicklePersistence:
def test_pickle_behaviour_with_slots(self, pickle_persistence):
bot_data = pickle_persistence.get_bot_data()
bot_data['message'] = Message(3, None, Chat(2, type='supergroup'))
bot_data['message'] = Message(3, datetime.datetime.now(), Chat(2, type='supergroup'))
pickle_persistence.update_bot_data(bot_data)
retrieved = pickle_persistence.get_bot_data()
assert retrieved == bot_data
@ -1633,13 +1362,113 @@ class TestPicklePersistence:
conversations_test = dict(pickle.load(f))['conversations']
assert conversations_test['name1'] == conversation1
def test_custom_pickler_unpickler_simple(
self, pickle_persistence, update, good_pickle_files, bot, recwarn
):
pickle_persistence.bot = bot # assign the current bot to the persistence
data_with_bot = {'current_bot': update.message}
pickle_persistence.update_chat_data(12345, data_with_bot) # also calls BotPickler.dumps()
# Test that regular pickle load fails -
err_msg = (
"A load persistent id instruction was encountered,\nbut no persistent_load "
"function was specified."
)
with pytest.raises(pickle.UnpicklingError, match=err_msg):
with open('pickletest_chat_data', 'rb') as f:
pickle.load(f)
# Test that our custom unpickler works as intended -- inserts the current bot
# We have to create a new instance otherwise unpickling is skipped
pp = PicklePersistence("pickletest", single_file=False, on_flush=False)
pp.bot = bot # Set the bot
assert pp.get_chat_data()[12345]['current_bot'].get_bot() is bot
# Now test that pickling of unknown bots in TelegramObjects will be replaced by None-
assert not len(recwarn)
data_with_bot['unknown_bot_in_user'] = User(1, 'Dev', False, bot=Bot('1234:abcd'))
pickle_persistence.update_chat_data(12345, data_with_bot)
assert len(recwarn) == 1
assert recwarn[-1].category is PTBUserWarning
assert str(recwarn[-1].message).startswith("Unknown bot instance found.")
pp = PicklePersistence("pickletest", single_file=False, on_flush=False)
pp.bot = bot
assert pp.get_chat_data()[12345]['unknown_bot_in_user']._bot is None
def test_custom_pickler_unpickler_with_custom_objects(
self, bot, pickle_persistence, good_pickle_files
):
dict_s = self.DictSub("private", 'normal', bot)
slot_s = self.SlotsSub("new_var", 'private_var')
regular = self.NormalClass(12)
pickle_persistence.bot = bot
pickle_persistence.update_user_data(
1232, {'sub_dict': dict_s, 'sub_slots': slot_s, 'r': regular}
)
pp = PicklePersistence("pickletest", single_file=False, on_flush=False)
pp.bot = bot # Set the bot
data = pp.get_user_data()[1232]
sub_dict = data['sub_dict']
sub_slots = data['sub_slots']
sub_regular = data['r']
assert sub_dict._bot is bot
assert sub_dict.normal == dict_s.normal
assert sub_dict._private == dict_s._private
assert sub_slots.new_var == slot_s.new_var
assert sub_slots._private == slot_s._private
assert sub_slots._bot is None # We didn't set the bot, so it shouldn't have it here.
assert sub_regular.my_var == regular.my_var
def test_custom_pickler_unpickler_with_handler_integration(
self, bot, update, pickle_persistence, good_pickle_files, recwarn
):
u = UpdaterBuilder().bot(bot).persistence(pickle_persistence).build()
dp = u.dispatcher
bot_id = None
def first(update, context):
nonlocal bot_id
bot_id = update.message.get_bot()
# Test pickling a message object, which has the current bot
context.user_data['msg'] = update.message
# Test pickling a bot, which is not known. Directly serializing bots will fail.
new_chat = Chat(1, 'private', bot=Bot('1234:abcd'))
context.chat_data['unknown_bot_in_chat'] = new_chat
def second(_, context):
msg = context.user_data['msg']
assert bot_id is msg.get_bot() # Tests if the same bot is inserted by the unpickler
new_none_bot = context.chat_data['unknown_bot_in_chat']._bot
assert new_none_bot is None
h1 = MessageHandler(None, first)
h2 = MessageHandler(None, second)
dp.add_handler(h1)
assert not len(recwarn)
dp.process_update(update)
assert len(recwarn) == 1
assert recwarn[-1].category is PTBUserWarning
assert str(recwarn[-1].message).startswith("Unknown bot instance found.")
pickle_persistence_2 = PicklePersistence( # initialize a new persistence for unpickling
filepath='pickletest',
single_file=False,
on_flush=False,
)
u = UpdaterBuilder().bot(bot).persistence(pickle_persistence_2).build()
dp = u.dispatcher
dp.add_handler(h2)
dp.process_update(update)
def test_with_handler(self, bot, update, bot_data, pickle_persistence, good_pickle_files):
u = UpdaterBuilder().bot(bot).persistence(pickle_persistence).build()
dp = u.dispatcher
bot.callback_data_cache.clear_callback_data()
bot.callback_data_cache.clear_callback_queries()
def first(update, context):
def first(_, context):
if not context.user_data == {}:
pytest.fail()
if not context.chat_data == {}:
@ -1653,7 +1482,7 @@ class TestPicklePersistence:
context.bot_data['test1'] = 'test0'
context.bot.callback_data_cache._callback_queries['test1'] = 'test0'
def second(update, context):
def second(_, context):
if not context.user_data['test1'] == 'test2':
pytest.fail()
if not context.chat_data['test3'] == 'test4':
@ -1813,7 +1642,7 @@ class TestPicklePersistence:
assert ch.conversations[ch._get_key(update)] == 0
assert ch.conversations == pickle_persistence.conversations['name2']
def test_with_nested_conversationHandler(
def test_with_nested_conversation_handler(
self, dp, update, good_pickle_files, pickle_persistence
):
dp.persistence = pickle_persistence

View file

@ -25,7 +25,6 @@ import inspect
included = { # These modules/classes intentionally have __dict__.
'CallbackContext',
'BasePersistence',
}

View file

@ -16,20 +16,28 @@
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import datetime
import json as json_lib
import pickle
import pytest
from copy import deepcopy
try:
import ujson
except ImportError:
ujson = None
from telegram import TelegramObject, Message, Chat, User
from telegram import TelegramObject, Message, Chat, User, PhotoSize
class TestTelegramObject:
class Sub(TelegramObject):
def __init__(self, private, normal, b):
self._private = private
self.normal = normal
self._bot = b
def test_to_json_native(self, monkeypatch):
if ujson:
monkeypatch.setattr('ujson.dumps', json_lib.dumps)
@ -145,3 +153,47 @@ class TestTelegramObject:
assert message['from_user'] is user
with pytest.raises(KeyError, match="Message don't have an attribute called `no_key`"):
message['no_key']
def test_pickle(self, bot):
chat = Chat(2, Chat.PRIVATE)
user = User(3, 'first_name', False)
date = datetime.datetime.now()
photo = PhotoSize('file_id', 'unique', 21, 21, bot=bot)
msg = Message(1, date, chat, from_user=user, text='foobar', bot=bot, photo=[photo])
# Test pickling of TGObjects, we choose Message since it's contains the most subclasses.
assert msg.get_bot()
unpickled = pickle.loads(pickle.dumps(msg))
with pytest.raises(RuntimeError):
unpickled.get_bot() # There should be no bot when we pickle TGObjects
assert unpickled.chat == chat
assert unpickled.from_user == user
assert unpickled.date == date
assert unpickled.photo[0] == photo
def test_deepcopy_telegram_obj(self, bot):
chat = Chat(2, Chat.PRIVATE)
user = User(3, 'first_name', False)
date = datetime.datetime.now()
photo = PhotoSize('file_id', 'unique', 21, 21, bot=bot)
msg = Message(1, date, chat, from_user=user, text='foobar', bot=bot, photo=[photo])
new_msg = deepcopy(msg)
# The same bot should be present when deepcopying.
assert new_msg.get_bot() == bot and new_msg.get_bot() is bot
assert new_msg.date == date and new_msg.date is not date
assert new_msg.chat == chat and new_msg.chat is not chat
assert new_msg.from_user == user and new_msg.from_user is not user
assert new_msg.photo[0] == photo and new_msg.photo[0] is not photo
def test_deepcopy_subclass_telegram_obj(self, bot):
s = self.Sub("private", 'normal', bot)
d = deepcopy(s)
assert d is not s
assert d._private == s._private # Can't test for identity since two equal strings is True
assert d._bot == s._bot and d._bot is s._bot
assert d.normal == s.normal