diff --git a/telegram/_bot.py b/telegram/_bot.py index 392dd480f..74c9e96f3 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -22,6 +22,7 @@ import functools import logging +import pickle from datetime import datetime from typing import ( @@ -37,6 +38,7 @@ from typing import ( cast, Sequence, Any, + NoReturn, ) try: @@ -123,10 +125,13 @@ class Bot(TelegramObject): considered equal, if their :attr:`bot` is equal. Note: - Most bot methods have the argument ``api_kwargs`` which allows to pass arbitrary keywords - to the Telegram API. This can be used to access new features of the API before they were - incorporated into PTB. However, this is not guaranteed to work, i.e. it will fail for - passing files. + * Most bot methods have the argument ``api_kwargs`` which allows passing arbitrary keywords + to the Telegram API. This can be used to access new features of the API before they are + incorporated into PTB. However, this is not guaranteed to work, i.e. it will fail for + passing files. + * Bots should not be serialized since if you for e.g. change the bots token, then your + serialized instance will not reflect that change. Trying to pickle a bot instance will + raise :exc:`pickle.PicklingError`. .. versionchanged:: 14.0 @@ -136,6 +141,7 @@ class Bot(TelegramObject): * Removed the deprecated ``defaults`` parameter. If you want to use :class:`telegram.ext.Defaults`, please use the subclass :class:`telegram.ext.ExtBot` instead. + * Attempting to pickle a bot instance will now raise :exc:`pickle.PicklingError`. Args: token (:obj:`str`): Bot's unique authentication. @@ -157,7 +163,7 @@ class Bot(TelegramObject): 'private_key', '_bot_user', '_request', - 'logger', + '_logger', ) def __init__( @@ -176,7 +182,7 @@ class Bot(TelegramObject): self._bot_user: Optional[User] = None self._request = request or Request() self.private_key = None - self.logger = logging.getLogger(__name__) + self._logger = logging.getLogger(__name__) if private_key: if not CRYPTO_INSTALLED: @@ -188,6 +194,10 @@ class Bot(TelegramObject): private_key, password=private_key_password, backend=default_backend() ) + def __reduce__(self) -> NoReturn: + """Called by pickle.dumps(). Serializing bots is unadvisable, so we forbid pickling.""" + raise pickle.PicklingError('Bot objects cannot be pickled!') + # TODO: After https://youtrack.jetbrains.com/issue/PY-50952 is fixed, we can revisit this and # consider adding Paramspec from typing_extensions to properly fix this. Currently a workaround def _log(func: Any): # type: ignore[no-untyped-def] # skipcq: PY-D0003 @@ -2999,9 +3009,9 @@ class Bot(TelegramObject): ) if result: - self.logger.debug('Getting updates: %s', [u['update_id'] for u in result]) + self._logger.debug('Getting updates: %s', [u['update_id'] for u in result]) else: - self.logger.debug('No new updates found.') + self._logger.debug('No new updates found.') return Update.de_list(result, self) # type: ignore[return-value] diff --git a/telegram/_files/_basemedium.py b/telegram/_files/_basemedium.py index 06e9a485c..c1421074e 100644 --- a/telegram/_files/_basemedium.py +++ b/telegram/_files/_basemedium.py @@ -51,7 +51,7 @@ class _BaseMedium(TelegramObject): """ - __slots__ = ('bot', 'file_id', 'file_size', 'file_unique_id') + __slots__ = ('file_id', 'file_size', 'file_unique_id') def __init__( self, file_id: str, file_unique_id: str, file_size: int = None, bot: 'Bot' = None diff --git a/telegram/_passport/credentials.py b/telegram/_passport/credentials.py index 478f3f9a5..b0879fa92 100644 --- a/telegram/_passport/credentials.py +++ b/telegram/_passport/credentials.py @@ -210,14 +210,14 @@ class Credentials(TelegramObject): nonce (:obj:`str`): Bot-specified nonce """ - __slots__ = ('bot', 'nonce', 'secure_data') + __slots__ = ('nonce', 'secure_data') def __init__(self, secure_data: 'SecureData', nonce: str, bot: 'Bot' = None, **_kwargs: Any): # Required self.secure_data = secure_data self.nonce = nonce - self.bot = bot + self.set_bot(bot) @classmethod def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['Credentials']: @@ -261,7 +261,6 @@ class SecureData(TelegramObject): """ __slots__ = ( - 'bot', 'utility_bill', 'personal_details', 'temporary_registration', @@ -304,7 +303,7 @@ class SecureData(TelegramObject): self.passport = passport self.personal_details = personal_details - self.bot = bot + self.set_bot(bot) @classmethod def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['SecureData']: @@ -360,7 +359,7 @@ class SecureValue(TelegramObject): """ - __slots__ = ('data', 'front_side', 'reverse_side', 'selfie', 'files', 'translation', 'bot') + __slots__ = ('data', 'front_side', 'reverse_side', 'selfie', 'files', 'translation') def __init__( self, @@ -380,7 +379,7 @@ class SecureValue(TelegramObject): self.files = files self.translation = translation - self.bot = bot + self.set_bot(bot) @classmethod def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['SecureValue']: @@ -412,17 +411,17 @@ class SecureValue(TelegramObject): class _CredentialsBase(TelegramObject): """Base class for DataCredentials and FileCredentials.""" - __slots__ = ('hash', 'secret', 'file_hash', 'data_hash', 'bot') + __slots__ = ('hash', 'secret', 'file_hash', 'data_hash') def __init__(self, hash: str, secret: str, bot: 'Bot' = None, **_kwargs: Any): self.hash = hash self.secret = secret - # Aliases just be be sure + # Aliases just to be sure self.file_hash = self.hash self.data_hash = self.hash - self.bot = bot + self.set_bot(bot) class DataCredentials(_CredentialsBase): diff --git a/telegram/_telegramobject.py b/telegram/_telegramobject.py index e70bb9c2d..e92434295 100644 --- a/telegram/_telegramobject.py +++ b/telegram/_telegramobject.py @@ -17,12 +17,14 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """Base class for Telegram Objects.""" +from copy import deepcopy + try: import ujson as json except ImportError: import json # type: ignore[no-redef] -from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Tuple +from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Tuple, Dict, Union from telegram._utils.types import JSONDict from telegram._utils.warnings import warn @@ -40,6 +42,12 @@ class TelegramObject: is equivalent to ``telegram_object.attribute_name``. If the object does not have an attribute with the appropriate name, a :exc:`KeyError` will be raised. + When objects of this type are pickled, the :class:`~telegram.Bot` attribute associated with the + object will be removed. However, when copying the object via :func:`copy.deepcopy`, the copy + will have the *same* bot instance associated with it, i.e:: + + assert telegram_object.get_bot() is copy.deepcopy(telegram_object).get_bot() + .. versionchanged:: 14.0 ``telegram_object['from']`` will look up the key ``from_user``. This is to account for special cases like :attr:`Message.from_user` that deviate from the official Bot API. @@ -53,15 +61,12 @@ class TelegramObject: _bot: Optional['Bot'] # Adding slots reduces memory usage & allows for faster attribute access. # Only instance variables should be added to __slots__. - __slots__ = ( - '_id_attrs', - '_bot', - ) + __slots__ = ('_id_attrs', '_bot') # pylint: disable=unused-argument def __new__(cls, *args: object, **kwargs: object) -> 'TelegramObject': # We add _id_attrs in __new__ instead of __init__ since we want to add this to the slots - # w/o calling __init__ in all of the subclasses. This is what we also do in BaseFilter. + # w/o calling __init__ in all of the subclasses. instance = super().__new__(cls) instance._id_attrs = () instance._bot = None @@ -81,6 +86,86 @@ class TelegramObject: f"`{item}`." ) from exc + def __getstate__(self) -> Dict[str, Union[str, object]]: + """ + This method is used for pickling. We remove the bot attribute of the object since those + are not pickable. + """ + return self._get_attrs(include_private=True, recursive=False, remove_bot=True) + + def __setstate__(self, state: dict) -> None: + """ + This method is used for unpickling. The data, which is in the form a dictionary, is + converted back into a class. Should be modified in place. + """ + for key, val in state.items(): + setattr(self, key, val) + + def __deepcopy__(self: TO, memodict: dict) -> TO: + """This method deepcopies the object and sets the bot on the newly created copy.""" + bot = self._bot # Save bot so we can set it after copying + self.set_bot(None) # set to None so it is not deepcopied + cls = self.__class__ + result = cls.__new__(cls) # create a new instance + memodict[id(self)] = result # save the id of the object in the dict + + attrs = self._get_attrs(include_private=True) # get all its attributes + + for k in attrs: # now we set the attributes in the deepcopied object + setattr(result, k, deepcopy(getattr(self, k), memodict)) + + result.set_bot(bot) # Assign the bots back + self.set_bot(bot) + return result # type: ignore[return-value] + + def _get_attrs( + self, + include_private: bool = False, + recursive: bool = False, + remove_bot: bool = False, + ) -> Dict[str, Union[str, object]]: + """This method is used for obtaining the attributes of the object. + + Args: + include_private (:obj:`bool`): Whether the result should include private variables. + recursive (:obj:`bool`): If :obj:`True`, will convert any TelegramObjects (if found) in + the attributes to a dictionary. Else, preserves it as an object itself. + remove_bot (:obj:`bool`): Whether the bot should be included in the result. + + Returns: + :obj:`dict`: A dict where the keys are attribute names and values are their values. + """ + data = {} + + if not recursive: + try: + # __dict__ has attrs from superclasses, so no need to put in the for loop below + data.update(self.__dict__) + except AttributeError: + pass + # We want to get all attributes for the class, using self.__slots__ only includes the + # attributes used by that class itself, and not its superclass(es). Hence, we get its MRO + # and then get their attributes. The `[:-1]` slice excludes the `object` class + for cls in self.__class__.__mro__[:-1]: + for key in cls.__slots__: + if not include_private and key.startswith('_'): + continue + + value = getattr(self, key, None) + if value is not None: + if recursive and hasattr(value, 'to_dict'): + data[key] = value.to_dict() + else: + data[key] = value + elif not recursive: + data[key] = value + + if recursive and data.get('from_user'): + data['from'] = data.pop('from_user', None) + if remove_bot: + data.pop('_bot', None) + return data + @staticmethod def _parse_data(data: Optional[JSONDict]) -> Optional[JSONDict]: return None if data is None else data.copy() @@ -137,27 +222,7 @@ class TelegramObject: Returns: :obj:`dict` """ - data = {} - - # We want to get all attributes for the class, using self.__slots__ only includes the - # attributes used by that class itself, and not its superclass(es). Hence we get its MRO - # and then get their attributes. The `[:-2]` slice excludes the `object` class & the - # TelegramObject class itself. - attrs = {attr for cls in self.__class__.__mro__[:-2] for attr in cls.__slots__} - for key in attrs: - if key == 'bot' or key.startswith('_'): - continue - - value = getattr(self, key, None) - if value is not None: - if hasattr(value, 'to_dict'): - data[key] = value.to_dict() - else: - data[key] = value - - if data.get('from_user'): - data['from'] = data.pop('from_user', None) - return data + return self._get_attrs(recursive=True) def get_bot(self) -> 'Bot': """Returns the :class:`telegram.Bot` instance associated with this object. @@ -171,8 +236,7 @@ class TelegramObject: """ if self._bot is None: raise RuntimeError( - 'This object has no bot associated with it. \ - Shortcuts cannot be used.' + 'This object has no bot associated with it. Shortcuts cannot be used.' ) return self._bot diff --git a/telegram/ext/_basepersistence.py b/telegram/ext/_basepersistence.py index 22160c322..8cf0ca977 100644 --- a/telegram/ext/_basepersistence.py +++ b/telegram/ext/_basepersistence.py @@ -18,14 +18,11 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the BasePersistence class.""" from abc import ABC, abstractmethod -from copy import copy -from typing import Dict, Optional, Tuple, cast, ClassVar, Generic, NamedTuple +from typing import Dict, Optional, Tuple, Generic, NamedTuple from telegram import Bot from telegram.ext import ExtBot -from telegram.warnings import PTBRuntimeWarning -from telegram._utils.warnings import warn from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData @@ -73,33 +70,39 @@ class BasePersistence(Generic[UD, CD, BD], ABC): * :meth:`get_chat_data` * :meth:`update_chat_data` * :meth:`refresh_chat_data` + * :meth:`drop_chat_data` * :meth:`get_user_data` * :meth:`update_user_data` * :meth:`refresh_user_data` + * :meth:`drop_user_data` * :meth:`get_callback_data` * :meth:`update_callback_data` * :meth:`get_conversations` * :meth:`update_conversation` * :meth:`flush` - If you don't actually need one of those methods, a simple ``pass`` is enough. For example, if - you don't store ``bot_data``, you don't need :meth:`get_bot_data`, :meth:`update_bot_data` or - :meth:`refresh_bot_data`. - - Warning: - Persistence will try to replace :class:`telegram.Bot` instances by :attr:`REPLACED_BOT` and - insert the bot set with :meth:`set_bot` upon loading of the data. This is to ensure that - changes to the bot apply to the saved objects, too. If you change the bots token, this may - lead to e.g. ``Chat not found`` errors. For the limitations on replacing bots see - :meth:`replace_bot` and :meth:`insert_bot`. + If you don't actually need one of those methods, a simple :keyword:`pass` is enough. + For example, if you don't store ``bot_data``, you don't need :meth:`get_bot_data`, + :meth:`update_bot_data` or :meth:`refresh_bot_data`. Note: - :meth:`replace_bot` and :meth:`insert_bot` are used *independently* of the implementation - of the ``update/get_*`` methods, i.e. you don't need to worry about it while - implementing a custom persistence subclass. + You should avoid saving :class:`telegram.Bot` instances. This is because if you change e.g. + the bots token, this won't propagate to the serialized instances and may lead to exceptions. + + To prevent this, the implementation may use :attr:`bot` to replace bot instances with a + placeholder before serialization and insert :attr:`bot` back when loading the data. + Since :attr:`bot` will be set when the process starts, this will be the up-to-date bot + instance. + + If the persistence implementation does not take care of this, you should make sure not to + store any bot instances in the data that will be persisted. E.g. in case of + :class:`telegram.TelegramObject`, one may call :meth:`set_bot` to ensure that shortcuts like + :meth:`telegram.Message.reply_text` are available. .. versionchanged:: 14.0 - The parameters and attributes ``store_*_data`` were replaced by :attr:`store_data`. + * The parameters and attributes ``store_*_data`` were replaced by :attr:`store_data`. + * ``insert/replace_bot`` was dropped. Serialization of bot instances now needs to be + handled by the specific implementation - see above note. Args: store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be @@ -109,72 +112,10 @@ class BasePersistence(Generic[UD, CD, BD], ABC): Attributes: store_data (:class:`PersistenceInput`): Specifies which kinds of data will be saved by this persistence instance. + bot (:class:`telegram.Bot`): The bot associated with the persistence. """ - __slots__ = ( - 'bot', - 'store_data', - '__dict__', # __dict__ is included because we replace methods in the __new__ - ) - - def __new__( - cls, *args: object, **kwargs: object # pylint: disable=unused-argument - ) -> 'BasePersistence': - """This overrides the get_* and update_* methods to use insert/replace_bot. - That has the side effect that we always pass deepcopied data to those methods, so in - Pickle/DictPersistence we don't have to worry about copying the data again. - - Note: This doesn't hold for second tuple-entry of callback_data. That's a Dict[str, str], - so no bots to replace anyway. - """ - instance = super().__new__(cls) - get_user_data = instance.get_user_data - get_chat_data = instance.get_chat_data - get_bot_data = instance.get_bot_data - get_callback_data = instance.get_callback_data - update_user_data = instance.update_user_data - update_chat_data = instance.update_chat_data - update_bot_data = instance.update_bot_data - update_callback_data = instance.update_callback_data - - def get_user_data_insert_bot() -> Dict[int, UD]: - return instance.insert_bot(get_user_data()) - - def get_chat_data_insert_bot() -> Dict[int, CD]: - return instance.insert_bot(get_chat_data()) - - def get_bot_data_insert_bot() -> BD: - return instance.insert_bot(get_bot_data()) - - def get_callback_data_insert_bot() -> Optional[CDCData]: - cdc_data = get_callback_data() - if cdc_data is None: - return None - return instance.insert_bot(cdc_data[0]), cdc_data[1] - - def update_user_data_replace_bot(user_id: int, data: UD) -> None: - return update_user_data(user_id, instance.replace_bot(data)) - - def update_chat_data_replace_bot(chat_id: int, data: CD) -> None: - return update_chat_data(chat_id, instance.replace_bot(data)) - - def update_bot_data_replace_bot(data: BD) -> None: - return update_bot_data(instance.replace_bot(data)) - - def update_callback_data_replace_bot(data: CDCData) -> None: - obj_data, queue = data - return update_callback_data((instance.replace_bot(obj_data), queue)) - - # Adds to __dict__ - setattr(instance, 'get_user_data', get_user_data_insert_bot) - setattr(instance, 'get_chat_data', get_chat_data_insert_bot) - setattr(instance, 'get_bot_data', get_bot_data_insert_bot) - setattr(instance, 'get_callback_data', get_callback_data_insert_bot) - setattr(instance, 'update_user_data', update_user_data_replace_bot) - setattr(instance, 'update_chat_data', update_chat_data_replace_bot) - setattr(instance, 'update_bot_data', update_bot_data_replace_bot) - setattr(instance, 'update_callback_data', update_callback_data_replace_bot) - return instance + __slots__ = ('bot', 'store_data') def __init__(self, store_data: PersistenceInput = None): self.store_data = store_data or PersistenceInput() @@ -192,215 +133,6 @@ class BasePersistence(Generic[UD, CD, BD], ABC): self.bot = bot - @classmethod - def replace_bot(cls, obj: object) -> object: - """ - Replaces all instances of :class:`telegram.Bot` that occur within the passed object with - :attr:`REPLACED_BOT`. Currently, this handles objects of type :class:`list`, - :class:`tuple`, :class:`set`, :class:`frozenset`, :class:`dict`, - :class:`collections.defaultdict` and objects that have a :attr:`~object.__dict__` or - :data:`~object.__slots__` attribute, excluding classes and objects that can't be copied - with :func:`copy.copy`. If the parsing of an object fails, the object will be returned - unchanged and the error will be logged. - - Args: - obj (:obj:`object`): The object - - Returns: - :class:`object`: Copy of the object with Bot instances replaced. - """ - return cls._replace_bot(obj, {}) - - @classmethod - def _replace_bot( # pylint: disable=too-many-return-statements - cls, obj: object, memo: Dict[int, object] - ) -> object: - obj_id = id(obj) - if obj_id in memo: - return memo[obj_id] - - if isinstance(obj, Bot): - memo[obj_id] = cls.REPLACED_BOT - return cls.REPLACED_BOT - if isinstance(obj, (list, set)): - # We copy the iterable here for thread safety, i.e. make sure the object we iterate - # over doesn't change its length during the iteration - temp_iterable = obj.copy() - new_iterable = obj.__class__(cls._replace_bot(item, memo) for item in temp_iterable) - memo[obj_id] = new_iterable - return new_iterable - if isinstance(obj, (tuple, frozenset)): - # tuples and frozensets are immutable so we don't need to worry about thread safety - new_immutable = obj.__class__(cls._replace_bot(item, memo) for item in obj) - memo[obj_id] = new_immutable - return new_immutable - if isinstance(obj, type): - # classes usually do have a __dict__, but it's not writable - warn( - f'BasePersistence.replace_bot does not handle classes such as {obj.__name__!r}. ' - 'See the docs of BasePersistence.replace_bot for more information.', - PTBRuntimeWarning, - ) - return obj - - try: - new_obj = copy(obj) - memo[obj_id] = new_obj - except Exception: - warn( - 'BasePersistence.replace_bot does not handle objects that can not be copied. See ' - 'the docs of BasePersistence.replace_bot for more information.', - PTBRuntimeWarning, - ) - memo[obj_id] = obj - return obj - - if isinstance(obj, dict): - # We handle dicts via copy(obj) so we don't have to make a - # difference between dict and defaultdict - new_obj = cast(dict, new_obj) - # We can't iterate over obj.items() due to thread safety, i.e. the dicts length may - # change during the iteration - temp_dict = new_obj.copy() - new_obj.clear() - for k, val in temp_dict.items(): - new_obj[cls._replace_bot(k, memo)] = cls._replace_bot(val, memo) - memo[obj_id] = new_obj - return new_obj - try: - if hasattr(obj, '__slots__'): - for attr_name in new_obj.__slots__: - setattr( - new_obj, - attr_name, - cls._replace_bot( - cls._replace_bot(getattr(new_obj, attr_name), memo), memo - ), - ) - if '__dict__' in obj.__slots__: - # In this case, we have already covered the case that obj has __dict__ - # Note that obj may have a __dict__ even if it's not in __slots__! - memo[obj_id] = new_obj - return new_obj - if hasattr(obj, '__dict__'): - for attr_name, attr in new_obj.__dict__.items(): - setattr(new_obj, attr_name, cls._replace_bot(attr, memo)) - memo[obj_id] = new_obj - return new_obj - except Exception as exception: - warn( - f'Parsing of an object failed with the following exception: {exception}. ' - f'See the docs of BasePersistence.replace_bot for more information.', - PTBRuntimeWarning, - ) - - memo[obj_id] = obj - return obj - - def insert_bot(self, obj: object) -> object: - """ - Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with - :paramref:`bot`. Currently, this handles objects of type :class:`list`, - :class:`tuple`, :class:`set`, :class:`frozenset`, :class:`dict`, - :class:`collections.defaultdict` and objects that have a :attr:`~object.__dict__` or - :data:`~object.__slots__` attribute, excluding classes and objects that can't be copied - with :func:`copy.copy`. If the parsing of an object fails, the object will be returned - unchanged and the error will be logged. - - Args: - obj (:obj:`object`): The object - - Returns: - :class:`object`: Copy of the object with Bot instances inserted. - """ - return self._insert_bot(obj, {}) - - # pylint: disable=too-many-return-statements - def _insert_bot(self, obj: object, memo: Dict[int, object]) -> object: - obj_id = id(obj) - if obj_id in memo: - return memo[obj_id] - - if isinstance(obj, Bot): - memo[obj_id] = self.bot - return self.bot - if isinstance(obj, str) and obj == self.REPLACED_BOT: - memo[obj_id] = self.bot - return self.bot - if isinstance(obj, (list, set)): - # We copy the iterable here for thread safety, i.e. make sure the object we iterate - # over doesn't change its length during the iteration - temp_iterable = obj.copy() - new_iterable = obj.__class__(self._insert_bot(item, memo) for item in temp_iterable) - memo[obj_id] = new_iterable - return new_iterable - if isinstance(obj, (tuple, frozenset)): - # tuples and frozensets are immutable so we don't need to worry about thread safety - new_immutable = obj.__class__(self._insert_bot(item, memo) for item in obj) - memo[obj_id] = new_immutable - return new_immutable - if isinstance(obj, type): - # classes usually do have a __dict__, but it's not writable - warn( - f'BasePersistence.insert_bot does not handle classes such as {obj.__name__!r}. ' - 'See the docs of BasePersistence.insert_bot for more information.', - PTBRuntimeWarning, - ) - return obj - - try: - new_obj = copy(obj) - except Exception: - warn( - 'BasePersistence.insert_bot does not handle objects that can not be copied. See ' - 'the docs of BasePersistence.insert_bot for more information.', - PTBRuntimeWarning, - ) - memo[obj_id] = obj - return obj - - if isinstance(obj, dict): - # We handle dicts via copy(obj) so we don't have to make a - # difference between dict and defaultdict - new_obj = cast(dict, new_obj) - # We can't iterate over obj.items() due to thread safety, i.e. the dicts length may - # change during the iteration - temp_dict = new_obj.copy() - new_obj.clear() - for k, val in temp_dict.items(): - new_obj[self._insert_bot(k, memo)] = self._insert_bot(val, memo) - memo[obj_id] = new_obj - return new_obj - try: - if hasattr(obj, '__slots__'): - for attr_name in obj.__slots__: - setattr( - new_obj, - attr_name, - self._insert_bot( - self._insert_bot(getattr(new_obj, attr_name), memo), memo - ), - ) - if '__dict__' in obj.__slots__: - # In this case, we have already covered the case that obj has __dict__ - # Note that obj may have a __dict__ even if it's not in __slots__! - memo[obj_id] = new_obj - return new_obj - if hasattr(obj, '__dict__'): - for attr_name, attr in new_obj.__dict__.items(): - setattr(new_obj, attr_name, self._insert_bot(attr, memo)) - memo[obj_id] = new_obj - return new_obj - except Exception as exception: - warn( - f'Parsing of an object failed with the following exception: {exception}. ' - f'See the docs of BasePersistence.insert_bot for more information.', - PTBRuntimeWarning, - ) - - memo[obj_id] = obj - return obj - @abstractmethod def get_user_data(self) -> Dict[int, UD]: """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a @@ -628,6 +360,3 @@ class BasePersistence(Generic[UD, CD, BD], ABC): .. versionchanged:: 14.0 Changed this method into an ``@abstractmethod``. """ - - REPLACED_BOT: ClassVar[str] = 'bot_instance_replaced_by_ptb_persistence' - """:obj:`str`: Placeholder for :class:`telegram.Bot` instances replaced in saved data.""" diff --git a/telegram/ext/_dictpersistence.py b/telegram/ext/_dictpersistence.py index bf5e60495..4bbfca59a 100644 --- a/telegram/ext/_dictpersistence.py +++ b/telegram/ext/_dictpersistence.py @@ -20,6 +20,8 @@ from typing import Dict, Optional, Tuple, cast +from copy import deepcopy + from telegram.ext import BasePersistence, PersistenceInput from telegram._utils.types import JSONDict from telegram.ext._utils.types import ConversationDict, CDCData @@ -31,7 +33,7 @@ except ImportError: class DictPersistence(BasePersistence): - """Using Python's :obj:`dict` and ``json`` for making your bot persistent. + """Using Python's :obj:`dict` and :mod:`json` for making your bot persistent. Attention: The interface provided by this class is intended to be accessed exclusively by @@ -39,19 +41,13 @@ class DictPersistence(BasePersistence): interfere with the integration of persistence into :class:`~telegram.ext.Dispatcher`. Note: - This class does *not* implement a :meth:`flush` method, meaning that data managed by - ``DictPersistence`` is in-memory only and will be lost when the bot shuts down. This is, - because ``DictPersistence`` is mainly intended as starting point for custom persistence - classes that need to JSON-serialize the stored data before writing them to file/database. + * Data managed by :class:`DictPersistence` is in-memory only and will be lost when the bot + shuts down. This is, because :class:`DictPersistence` is mainly intended as starting + point for custom persistence classes that need to JSON-serialize the stored data before + writing them to file/database. - Warning: - :class:`DictPersistence` will try to replace :class:`telegram.Bot` instances by - :attr:`~telegram.ext.BasePersistence.REPLACED_BOT` and insert the bot set with - :meth:`telegram.ext.BasePersistence.set_bot` upon loading of the data. This is to ensure - that changes to the bot apply to the saved objects, too. If you change the bots token, this - may lead to e.g. ``Chat not found`` errors. For the limitations on replacing bots see - :meth:`telegram.ext.BasePersistence.replace_bot` and - :meth:`telegram.ext.BasePersistence.insert_bot`. + * This implementation of :class:`BasePersistence` does not handle data that cannot be + serialized by :func:`json.dumps`. .. versionchanged:: 14.0 The parameters and attributes ``store_*_data`` were replaced by :attr:`store_data`. @@ -244,7 +240,7 @@ class DictPersistence(BasePersistence): """ if self.user_data is None: self._user_data = {} - return self.user_data # type: ignore[return-value] + return deepcopy(self.user_data) # type: ignore[arg-type] def get_chat_data(self) -> Dict[int, Dict[object, object]]: """Returns the chat_data created from the ``chat_data_json`` or an empty :obj:`dict`. @@ -254,7 +250,7 @@ class DictPersistence(BasePersistence): """ if self.chat_data is None: self._chat_data = {} - return self.chat_data # type: ignore[return-value] + return deepcopy(self.chat_data) # type: ignore[arg-type] def get_bot_data(self) -> Dict[object, object]: """Returns the bot_data created from the ``bot_data_json`` or an empty :obj:`dict`. @@ -264,7 +260,7 @@ class DictPersistence(BasePersistence): """ if self.bot_data is None: self._bot_data = {} - return self.bot_data # type: ignore[return-value] + return deepcopy(self.bot_data) # type: ignore[arg-type] def get_callback_data(self) -> Optional[CDCData]: """Returns the callback_data created from the ``callback_data_json`` or :obj:`None`. @@ -279,7 +275,7 @@ class DictPersistence(BasePersistence): if self.callback_data is None: self._callback_data = None return None - return self.callback_data[0], self.callback_data[1].copy() + return deepcopy((self.callback_data[0], self.callback_data[1].copy())) def get_conversations(self, name: str) -> ConversationDict: """Returns the conversations created from the ``conversations_json`` or an empty @@ -320,7 +316,7 @@ class DictPersistence(BasePersistence): self._user_data = {} if self._user_data.get(user_id) == data: return - self._user_data[user_id] = data + self._user_data[user_id] = deepcopy(data) self._user_data_json = None def update_chat_data(self, chat_id: int, data: Dict) -> None: @@ -334,7 +330,7 @@ class DictPersistence(BasePersistence): self._chat_data = {} if self._chat_data.get(chat_id) == data: return - self._chat_data[chat_id] = data + self._chat_data[chat_id] = deepcopy(data) self._chat_data_json = None def update_bot_data(self, data: Dict) -> None: @@ -345,7 +341,7 @@ class DictPersistence(BasePersistence): """ if self._bot_data == data: return - self._bot_data = data + self._bot_data = deepcopy(data) self._bot_data_json = None def update_callback_data(self, data: CDCData) -> None: @@ -360,7 +356,7 @@ class DictPersistence(BasePersistence): """ if self._callback_data == data: return - self._callback_data = (data[0], data[1].copy()) + self._callback_data = deepcopy((data[0], data[1].copy())) self._callback_data_json = None def drop_chat_data(self, chat_id: int) -> None: diff --git a/telegram/ext/_picklepersistence.py b/telegram/ext/_picklepersistence.py index 487c3ae31..7622a82af 100644 --- a/telegram/ext/_picklepersistence.py +++ b/telegram/ext/_picklepersistence.py @@ -17,8 +17,11 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the PicklePersistence class.""" +import copyreg import pickle +from copy import deepcopy from pathlib import Path +from sys import version_info as py_ver from typing import ( Any, Dict, @@ -26,14 +29,105 @@ from typing import ( Tuple, overload, cast, + Type, + Set, + Callable, + TypeVar, ) +from telegram import Bot, TelegramObject from telegram._utils.types import FilePathInput +from telegram._utils.warnings import warn from telegram.ext import BasePersistence, PersistenceInput from telegram.ext._contexttypes import ContextTypes from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData +_REPLACED_KNOWN_BOT = "a known bot replaced by PTB's PicklePersistence" +_REPLACED_UNKNOWN_BOT = "an unknown bot replaced by PTB's PicklePersistence" + +TO = TypeVar('TO', bound=TelegramObject) + + +def _all_subclasses(cls: Type[TO]) -> Set[Type[TO]]: + """Gets all subclasses of the specified object, recursively. from + https://stackoverflow.com/a/3862957/9706202 + """ + subclasses = cls.__subclasses__() + return set(subclasses).union([s for c in subclasses for s in _all_subclasses(c)]) + + +def _reconstruct_to(cls: Type[TO], kwargs: dict) -> TO: + """ + This method is used for unpickling. The data, which is in the form a dictionary, is + converted back into a class. Works mostly the same as :meth:`TelegramObject.__setstate__`. + This function should be kept in place for backwards compatibility even if the pickling logic + is changed, since `_custom_reduction` places references to this function into the pickled data. + """ + obj = cls.__new__(cls) + obj.__setstate__(kwargs) + return obj # type: ignore[return-value] + + +def _custom_reduction(cls: TO) -> Tuple[Callable, Tuple[Type[TO], dict]]: + """ + This method is used for pickling. The bot attribute is preserved so _BotPickler().persistent_id + works as intended. + """ + data = cls._get_attrs(include_private=True) # pylint: disable=protected-access + return _reconstruct_to, (cls.__class__, data) + + +class _BotPickler(pickle.Pickler): + __slots__ = ('_bot',) + + def __init__(self, bot: Bot, *args: Any, **kwargs: Any): + self._bot = bot + if py_ver < (3, 8): # self.reducer_override is used above this version + # Here we define a private dispatch_table, because we want to preserve the bot + # attribute of objects so persistent_id works as intended. Otherwise, the bot attribute + # is deleted in __getstate__, which is used during regular pickling (via pickle.dumps) + self.dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined] + for obj in _all_subclasses(TelegramObject): + self.dispatch_table[obj] = _custom_reduction # type: ignore[index] + super().__init__(*args, **kwargs) + + def reducer_override( # pylint: disable=no-self-use + self, obj: TO + ) -> Tuple[Callable, Tuple[Type[TO], dict]]: + if not isinstance(obj, TelegramObject): + return NotImplemented + + return _custom_reduction(obj) + + def persistent_id(self, obj: object) -> Optional[str]: + """Used to 'mark' the Bot, so it can be replaced later. See + https://docs.python.org/3/library/pickle.html#pickle.Pickler.persistent_id for more info + """ + if obj is self._bot: + return _REPLACED_KNOWN_BOT + if isinstance(obj, Bot): + warn('Unknown bot instance found. Will be replaced by `None` during unpickling') + return _REPLACED_UNKNOWN_BOT + return None # pickles as usual + + +class _BotUnpickler(pickle.Unpickler): + __slots__ = ('_bot',) + + def __init__(self, bot: Bot, *args: Any, **kwargs: Any): + self._bot = bot + super().__init__(*args, **kwargs) + + def persistent_load(self, pid: str) -> Optional[Bot]: + """Replaces the bot with the current bot if known, else it is replaced by :obj:`None`.""" + if pid == _REPLACED_KNOWN_BOT: + return self._bot + if pid == _REPLACED_UNKNOWN_BOT: + return None + raise pickle.UnpicklingError("Found unknown persistent id when unpickling!") + + class PicklePersistence(BasePersistence[UD, CD, BD]): """Using python's builtin pickle for making your bot persistent. @@ -42,14 +136,11 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): :class:`~telegram.ext.Dispatcher`. Calling any of the methods below manually might interfere with the integration of persistence into :class:`~telegram.ext.Dispatcher`. - Warning: - :class:`PicklePersistence` will try to replace :class:`telegram.Bot` instances by - :attr:`~telegram.ext.BasePersistence.REPLACED_BOT` and insert the bot set with - :meth:`telegram.ext.BasePersistence.set_bot` upon loading of the data. This is to ensure - that changes to the bot apply to the saved objects, too. If you change the bots token, this - may lead to e.g. ``Chat not found`` errors. For the limitations on replacing bots see - :meth:`telegram.ext.BasePersistence.replace_bot` and - :meth:`telegram.ext.BasePersistence.insert_bot`. + Note: + This implementation of :class:`BasePersistence` uses the functionality of the pickle module + to support serialization of bot instances. Specifically any reference to + :attr:`~BasePersistence.bot` will be replaced by a placeholder before pickling and + :attr:`~BasePersistence.bot` will be inserted back when loading the data. .. versionchanged:: 14.0 @@ -57,7 +148,6 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): * The parameter and attribute ``filename`` were replaced by :attr:`filepath`. * :attr:`filepath` now also accepts :obj:`pathlib.Path` as argument. - Args: filepath (:obj:`str` | :obj:`pathlib.Path`): The filepath for storing the pickle files. When :attr:`single_file` is :obj:`False` this will be used as a prefix. @@ -151,7 +241,8 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): def _load_singlefile(self) -> None: try: with self.filepath.open("rb") as file: - data = pickle.load(file) + data = _BotUnpickler(self.bot, file).load() + self.user_data = data['user_data'] self.chat_data = data['chat_data'] # For backwards compatibility with files not containing bot data @@ -170,11 +261,11 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): except Exception as exc: raise TypeError(f"Something went wrong unpickling {self.filepath.name}") from exc - @staticmethod - def _load_file(filepath: Path) -> Any: + def _load_file(self, filepath: Path) -> Any: try: with filepath.open("rb") as file: - return pickle.load(file) + return _BotUnpickler(self.bot, file).load() + except OSError: return None except pickle.UnpicklingError as exc: @@ -191,12 +282,11 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): 'callback_data': self.callback_data, } with self.filepath.open("wb") as file: - pickle.dump(data, file) + _BotPickler(self.bot, file, protocol=pickle.HIGHEST_PROTOCOL).dump(data) - @staticmethod - def _dump_file(filepath: Path, data: object) -> None: + def _dump_file(self, filepath: Path, data: object) -> None: with filepath.open("wb") as file: - pickle.dump(data, file) + _BotPickler(self.bot, file, protocol=pickle.HIGHEST_PROTOCOL).dump(data) def get_user_data(self) -> Dict[int, UD]: """Returns the user_data from the pickle file if it exists or an empty :obj:`dict`. @@ -213,7 +303,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): self.user_data = data else: self._load_singlefile() - return self.user_data # type: ignore[return-value] + return deepcopy(self.user_data) # type: ignore[arg-type] def get_chat_data(self) -> Dict[int, CD]: """Returns the chat_data from the pickle file if it exists or an empty :obj:`dict`. @@ -230,7 +320,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): self.chat_data = data else: self._load_singlefile() - return self.chat_data # type: ignore[return-value] + return deepcopy(self.chat_data) # type: ignore[arg-type] def get_bot_data(self) -> BD: """Returns the bot_data from the pickle file if it exists or an empty object of type @@ -248,7 +338,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): self.bot_data = data else: self._load_singlefile() - return self.bot_data # type: ignore[return-value] + return deepcopy(self.bot_data) # type: ignore[return-value] def get_callback_data(self) -> Optional[CDCData]: """Returns the callback data from the pickle file if it exists or :obj:`None`. @@ -271,7 +361,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): self._load_singlefile() if self.callback_data is None: return None - return self.callback_data[0], self.callback_data[1].copy() + return deepcopy((self.callback_data[0], self.callback_data[1].copy())) def get_conversations(self, name: str) -> ConversationDict: """Returns the conversations from the pickle file if it exists or an empty dict. @@ -326,7 +416,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): self.user_data = {} if self.user_data.get(user_id) == data: return - self.user_data[user_id] = data + self.user_data[user_id] = deepcopy(data) if not self.on_flush: if not self.single_file: self._dump_file(Path(f"{self.filepath}_user_data"), self.user_data) @@ -344,7 +434,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): self.chat_data = {} if self.chat_data.get(chat_id) == data: return - self.chat_data[chat_id] = data + self.chat_data[chat_id] = deepcopy(data) if not self.on_flush: if not self.single_file: self._dump_file(Path(f"{self.filepath}_chat_data"), self.chat_data) @@ -360,7 +450,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): """ if self.bot_data == data: return - self.bot_data = data + self.bot_data = deepcopy(data) if not self.on_flush: if not self.single_file: self._dump_file(Path(f"{self.filepath}_bot_data"), self.bot_data) @@ -380,7 +470,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): """ if self.callback_data == data: return - self.callback_data = (data[0], data[1].copy()) + self.callback_data = deepcopy((data[0], data[1].copy())) if not self.on_flush: if not self.single_file: self._dump_file(Path(f"{self.filepath}_callback_data"), self.callback_data) diff --git a/tests/test_bot.py b/tests/test_bot.py index 72e100c75..b1132d27a 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -19,6 +19,7 @@ import datetime import inspect import logging +import pickle import time import datetime as dtm from collections import defaultdict @@ -241,6 +242,10 @@ class TestBot: if bot.last_name: assert to_dict_bot["last_name"] == bot.last_name + def test_bot_pickling_error(self, bot): + with pytest.raises(pickle.PicklingError, match="Bot objects cannot be pickled"): + pickle.dumps(bot) + @pytest.mark.parametrize( 'bot_method_name', argvalues=[ diff --git a/tests/test_callbackdatacache.py b/tests/test_callbackdatacache.py index 1d97022d2..c21694832 100644 --- a/tests/test_callbackdatacache.py +++ b/tests/test_callbackdatacache.py @@ -24,7 +24,7 @@ from uuid import uuid4 import pytest import pytz -from telegram import InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery, Message, User +from telegram import InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery, Message, User, Chat from telegram.ext._callbackdatacache import ( CallbackDataCache, _KeyboardData, @@ -159,7 +159,8 @@ class TestCallbackDataCache: if invalid: callback_data_cache.clear_callback_data() - effective_message = Message(message_id=1, date=None, chat=None, reply_markup=out) + chat = Chat(1, 'private') + effective_message = Message(message_id=1, date=datetime.now(), chat=chat, reply_markup=out) effective_message.reply_to_message = deepcopy(effective_message) effective_message.pinned_message = deepcopy(effective_message) cq_id = uuid4().hex diff --git a/tests/test_persistence.py b/tests/test_persistence.py index fc7f32fcf..9f6f3d5f3 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -16,26 +16,25 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +import datetime import logging import os import pickle import gzip import signal -import uuid -from collections.abc import Container -from collections import defaultdict from pathlib import Path from time import sleep -from threading import Lock import pytest +from telegram.warnings import PTBUserWarning + try: import ujson as json except ImportError: import json -from telegram import Update, Message, User, Chat, MessageEntity, Bot +from telegram import Update, Message, User, Chat, MessageEntity, Bot, TelegramObject from telegram.ext import ( BasePersistence, ConversationHandler, @@ -130,7 +129,7 @@ def base_persistence(): @pytest.fixture(scope="function") def bot_persistence(): class BotPersistence(BasePersistence): - __slots__ = () + __slots__ = ('bot_data', 'chat_data', 'user_data', 'callback_data') def __init__(self): super().__init__() @@ -252,7 +251,6 @@ class TestBasePersistence: inst = bot_persistence for attr in inst.__slots__: assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - # assert not inst.__dict__, f"got missing slot(s): {inst.__dict__}" # The below test fails if the child class doesn't define __slots__ (not a cause of concern) assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" @@ -590,292 +588,6 @@ class TestBasePersistence: dp.process_update(MyUpdate()) assert 'An uncaught error was raised while processing the update' not in caplog.text - def test_bot_replace_insert_bot(self, bot, bot_persistence): - class CustomSlottedClass: - __slots__ = ('bot', '__dict__') - - def __init__(self): - self.bot = bot - self.not_in_slots = bot - - def __eq__(self, other): - if isinstance(other, CustomSlottedClass): - return self.bot is other.bot and self.not_in_slots is other.not_in_slots - return False - - class DictNotInSlots(Container): - """This classes parent has slots, but __dict__ is not in those slots.""" - - def __init__(self): - self.bot = bot - - def __contains__(self, item): - return True - - def __eq__(self, other): - if isinstance(other, DictNotInSlots): - return self.bot is other.bot - return False - - class CustomClass: - def __init__(self): - self.bot = bot - self.slotted_object = CustomSlottedClass() - self.dict_not_in_slots_object = DictNotInSlots() - self.list_ = [1, 2, bot] - self.tuple_ = tuple(self.list_) - self.set_ = set(self.list_) - self.frozenset_ = frozenset(self.list_) - self.dict_ = {item: item for item in self.list_} - self.defaultdict_ = defaultdict(dict, self.dict_) - - @staticmethod - def replace_bot(): - cc = CustomClass() - cc.bot = BasePersistence.REPLACED_BOT - cc.slotted_object.bot = BasePersistence.REPLACED_BOT - cc.slotted_object.not_in_slots = BasePersistence.REPLACED_BOT - cc.dict_not_in_slots_object.bot = BasePersistence.REPLACED_BOT - cc.list_ = [1, 2, BasePersistence.REPLACED_BOT] - cc.tuple_ = tuple(cc.list_) - cc.set_ = set(cc.list_) - cc.frozenset_ = frozenset(cc.list_) - cc.dict_ = {item: item for item in cc.list_} - cc.defaultdict_ = defaultdict(dict, cc.dict_) - return cc - - def __eq__(self, other): - if isinstance(other, CustomClass): - return ( - self.bot is other.bot - and self.slotted_object == other.slotted_object - and self.dict_not_in_slots_object == other.dict_not_in_slots_object - and self.list_ == other.list_ - and self.tuple_ == other.tuple_ - and self.set_ == other.set_ - and self.frozenset_ == other.frozenset_ - and self.dict_ == other.dict_ - and self.defaultdict_ == other.defaultdict_ - ) - return False - - persistence = bot_persistence - persistence.set_bot(bot) - cc = CustomClass() - - persistence.update_bot_data({1: cc}) - assert persistence.bot_data[1].bot == BasePersistence.REPLACED_BOT - assert persistence.bot_data[1] == cc.replace_bot() - - persistence.update_chat_data(123, {1: cc}) - assert persistence.chat_data[123][1].bot == BasePersistence.REPLACED_BOT - assert persistence.chat_data[123][1] == cc.replace_bot() - - persistence.update_user_data(123, {1: cc}) - assert persistence.user_data[123][1].bot == BasePersistence.REPLACED_BOT - assert persistence.user_data[123][1] == cc.replace_bot() - - persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0].bot == BasePersistence.REPLACED_BOT - assert persistence.callback_data[0][0][2][0] == cc.replace_bot() - - assert persistence.get_bot_data()[1] == cc - assert persistence.get_bot_data()[1].bot is bot - assert persistence.get_chat_data()[123][1] == cc - assert persistence.get_chat_data()[123][1].bot is bot - assert persistence.get_user_data()[123][1] == cc - assert persistence.get_user_data()[123][1].bot is bot - assert persistence.get_callback_data()[0][0][2][0].bot is bot - assert persistence.get_callback_data()[0][0][2][0] == cc - - def test_bot_replace_insert_bot_unpickable_objects(self, bot, bot_persistence, recwarn): - """Here check that unpickable objects are just returned verbatim.""" - persistence = bot_persistence - persistence.set_bot(bot) - - class CustomClass: - def __copy__(self): - raise TypeError('UnhandledException') - - lock = Lock() - - persistence.update_bot_data({1: lock}) - assert persistence.bot_data[1] is lock - persistence.update_chat_data(123, {1: lock}) - assert persistence.chat_data[123][1] is lock - persistence.update_user_data(123, {1: lock}) - assert persistence.user_data[123][1] is lock - persistence.update_callback_data(([('1', 2, {0: lock})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0] is lock - - assert persistence.get_bot_data()[1] is lock - assert persistence.get_chat_data()[123][1] is lock - assert persistence.get_user_data()[123][1] is lock - assert persistence.get_callback_data()[0][0][2][0] is lock - - cc = CustomClass() - - persistence.update_bot_data({1: cc}) - assert persistence.bot_data[1] is cc - persistence.update_chat_data(123, {1: cc}) - assert persistence.chat_data[123][1] is cc - persistence.update_user_data(123, {1: cc}) - assert persistence.user_data[123][1] is cc - persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0] is cc - - assert persistence.get_bot_data()[1] is cc - assert persistence.get_chat_data()[123][1] is cc - assert persistence.get_user_data()[123][1] is cc - assert persistence.get_callback_data()[0][0][2][0] is cc - - assert len(recwarn) == 2 - assert str(recwarn[0].message).startswith( - "BasePersistence.replace_bot does not handle objects that can not be copied." - ) - assert str(recwarn[1].message).startswith( - "BasePersistence.insert_bot does not handle objects that can not be copied." - ) - - def test_bot_replace_insert_bot_unparsable_objects(self, bot, bot_persistence, recwarn): - """Here check that objects in __dict__ or __slots__ that can't - be parsed are just returned verbatim.""" - persistence = bot_persistence - persistence.set_bot(bot) - - uuid_obj = uuid.uuid4() - - persistence.update_bot_data({1: uuid_obj}) - assert persistence.bot_data[1] is uuid_obj - persistence.update_chat_data(123, {1: uuid_obj}) - assert persistence.chat_data[123][1] is uuid_obj - persistence.update_user_data(123, {1: uuid_obj}) - assert persistence.user_data[123][1] is uuid_obj - persistence.update_callback_data(([('1', 2, {0: uuid_obj})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0] is uuid_obj - - assert persistence.get_bot_data()[1] is uuid_obj - assert persistence.get_chat_data()[123][1] is uuid_obj - assert persistence.get_user_data()[123][1] is uuid_obj - assert persistence.get_callback_data()[0][0][2][0] is uuid_obj - - assert len(recwarn) == 2 - assert str(recwarn[0].message).startswith( - "Parsing of an object failed with the following exception: " - ) - assert str(recwarn[1].message).startswith( - "Parsing of an object failed with the following exception: " - ) - - def test_bot_replace_insert_bot_classes(self, bot, bot_persistence, recwarn): - """Here check that classes are just returned verbatim.""" - persistence = bot_persistence - persistence.set_bot(bot) - - class CustomClass: - pass - - persistence.update_bot_data({1: CustomClass}) - assert persistence.bot_data[1] is CustomClass - persistence.update_chat_data(123, {1: CustomClass}) - assert persistence.chat_data[123][1] is CustomClass - persistence.update_user_data(123, {1: CustomClass}) - assert persistence.user_data[123][1] is CustomClass - - assert persistence.get_bot_data()[1] is CustomClass - assert persistence.get_chat_data()[123][1] is CustomClass - assert persistence.get_user_data()[123][1] is CustomClass - - assert len(recwarn) == 2 - assert str(recwarn[0].message).startswith( - "BasePersistence.replace_bot does not handle classes such as 'CustomClass'" - ) - assert str(recwarn[1].message).startswith( - "BasePersistence.insert_bot does not handle classes such as 'CustomClass'" - ) - - def test_bot_replace_insert_bot_objects_with_faulty_equality(self, bot, bot_persistence): - """Here check that trying to compare obj == self.REPLACED_BOT doesn't lead to problems.""" - persistence = bot_persistence - persistence.set_bot(bot) - - class CustomClass: - def __init__(self, data): - self.data = data - - def __eq__(self, other): - raise RuntimeError("Can't be compared") - - cc = CustomClass({1: bot, 2: 'foo'}) - expected = {1: BasePersistence.REPLACED_BOT, 2: 'foo'} - - persistence.update_bot_data({1: cc}) - assert persistence.bot_data[1].data == expected - persistence.update_chat_data(123, {1: cc}) - assert persistence.chat_data[123][1].data == expected - persistence.update_user_data(123, {1: cc}) - assert persistence.user_data[123][1].data == expected - persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0].data == expected - - expected = {1: bot, 2: 'foo'} - - assert persistence.get_bot_data()[1].data == expected - assert persistence.get_chat_data()[123][1].data == expected - assert persistence.get_user_data()[123][1].data == expected - assert persistence.get_callback_data()[0][0][2][0].data == expected - - @pytest.mark.filterwarnings('ignore:BasePersistence') - def test_replace_insert_bot_item_identity(self, bot, bot_persistence): - persistence = bot_persistence - persistence.set_bot(bot) - - class CustomSlottedClass: - __slots__ = ('value',) - - def __init__(self): - self.value = 5 - - class CustomClass: - pass - - slot_object = CustomSlottedClass() - dict_object = CustomClass() - lock = Lock() - list_ = [slot_object, dict_object, lock] - tuple_ = (1, 2, 3) - dict_ = {1: slot_object, 2: dict_object} - - data = { - 'bot_1': bot, - 'bot_2': bot, - 'list_1': list_, - 'list_2': list_, - 'tuple_1': tuple_, - 'tuple_2': tuple_, - 'dict_1': dict_, - 'dict_2': dict_, - } - - def make_assertion(data_): - return ( - data_['bot_1'] is data_['bot_2'] - and data_['list_1'] is data_['list_2'] - and data_['list_1'][0] is data_['list_2'][0] - and data_['list_1'][1] is data_['list_2'][1] - and data_['list_1'][2] is data_['list_2'][2] - and data_['tuple_1'] is data_['tuple_2'] - and data_['dict_1'] is data_['dict_2'] - and data_['dict_1'][1] is data_['dict_2'][1] - and data_['dict_1'][1] is data_['list_1'][0] - and data_['dict_1'][2] is data_['list_1'][1] - and data_['dict_1'][2] is data_['dict_2'][2] - ) - - persistence.update_bot_data(data) - assert make_assertion(persistence.bot_data) - assert make_assertion(persistence.get_bot_data()) - def test_set_bot_exception(self, bot): non_ext_bot = Bot(bot.token) persistence = OwnPersistence() @@ -1033,11 +745,28 @@ def pickle_files_wo_callback_data(user_data, chat_data, bot_data, conversations) def update(bot): user = User(id=321, first_name='test_user', is_bot=False) chat = Chat(id=123, type='group') - message = Message(1, None, chat, from_user=user, text="Hi there", bot=bot) + message = Message(1, datetime.datetime.now(), chat, from_user=user, text="Hi there", bot=bot) return Update(0, message=message) class TestPicklePersistence: + class DictSub(TelegramObject): # Used for testing our custom (Un)Pickler. + def __init__(self, private, normal, b): + self._private = private + self.normal = normal + self._bot = b + + class SlotsSub(TelegramObject): + __slots__ = ('new_var', '_private') + + def __init__(self, new_var, private): + self.new_var = new_var + self._private = private + + class NormalClass: + def __init__(self, my_var): + self.my_var = my_var + def test_slot_behaviour(self, mro_slots, pickle_persistence): inst = pickle_persistence for attr in inst.__slots__: @@ -1046,7 +775,7 @@ class TestPicklePersistence: def test_pickle_behaviour_with_slots(self, pickle_persistence): bot_data = pickle_persistence.get_bot_data() - bot_data['message'] = Message(3, None, Chat(2, type='supergroup')) + bot_data['message'] = Message(3, datetime.datetime.now(), Chat(2, type='supergroup')) pickle_persistence.update_bot_data(bot_data) retrieved = pickle_persistence.get_bot_data() assert retrieved == bot_data @@ -1633,13 +1362,113 @@ class TestPicklePersistence: conversations_test = dict(pickle.load(f))['conversations'] assert conversations_test['name1'] == conversation1 + def test_custom_pickler_unpickler_simple( + self, pickle_persistence, update, good_pickle_files, bot, recwarn + ): + pickle_persistence.bot = bot # assign the current bot to the persistence + data_with_bot = {'current_bot': update.message} + pickle_persistence.update_chat_data(12345, data_with_bot) # also calls BotPickler.dumps() + + # Test that regular pickle load fails - + err_msg = ( + "A load persistent id instruction was encountered,\nbut no persistent_load " + "function was specified." + ) + with pytest.raises(pickle.UnpicklingError, match=err_msg): + with open('pickletest_chat_data', 'rb') as f: + pickle.load(f) + + # Test that our custom unpickler works as intended -- inserts the current bot + # We have to create a new instance otherwise unpickling is skipped + pp = PicklePersistence("pickletest", single_file=False, on_flush=False) + pp.bot = bot # Set the bot + assert pp.get_chat_data()[12345]['current_bot'].get_bot() is bot + + # Now test that pickling of unknown bots in TelegramObjects will be replaced by None- + assert not len(recwarn) + data_with_bot['unknown_bot_in_user'] = User(1, 'Dev', False, bot=Bot('1234:abcd')) + pickle_persistence.update_chat_data(12345, data_with_bot) + assert len(recwarn) == 1 + assert recwarn[-1].category is PTBUserWarning + assert str(recwarn[-1].message).startswith("Unknown bot instance found.") + pp = PicklePersistence("pickletest", single_file=False, on_flush=False) + pp.bot = bot + assert pp.get_chat_data()[12345]['unknown_bot_in_user']._bot is None + + def test_custom_pickler_unpickler_with_custom_objects( + self, bot, pickle_persistence, good_pickle_files + ): + dict_s = self.DictSub("private", 'normal', bot) + slot_s = self.SlotsSub("new_var", 'private_var') + regular = self.NormalClass(12) + + pickle_persistence.bot = bot + pickle_persistence.update_user_data( + 1232, {'sub_dict': dict_s, 'sub_slots': slot_s, 'r': regular} + ) + pp = PicklePersistence("pickletest", single_file=False, on_flush=False) + pp.bot = bot # Set the bot + data = pp.get_user_data()[1232] + sub_dict = data['sub_dict'] + sub_slots = data['sub_slots'] + sub_regular = data['r'] + assert sub_dict._bot is bot + assert sub_dict.normal == dict_s.normal + assert sub_dict._private == dict_s._private + assert sub_slots.new_var == slot_s.new_var + assert sub_slots._private == slot_s._private + assert sub_slots._bot is None # We didn't set the bot, so it shouldn't have it here. + assert sub_regular.my_var == regular.my_var + + def test_custom_pickler_unpickler_with_handler_integration( + self, bot, update, pickle_persistence, good_pickle_files, recwarn + ): + u = UpdaterBuilder().bot(bot).persistence(pickle_persistence).build() + dp = u.dispatcher + bot_id = None + + def first(update, context): + nonlocal bot_id + bot_id = update.message.get_bot() + # Test pickling a message object, which has the current bot + context.user_data['msg'] = update.message + # Test pickling a bot, which is not known. Directly serializing bots will fail. + new_chat = Chat(1, 'private', bot=Bot('1234:abcd')) + context.chat_data['unknown_bot_in_chat'] = new_chat + + def second(_, context): + msg = context.user_data['msg'] + assert bot_id is msg.get_bot() # Tests if the same bot is inserted by the unpickler + new_none_bot = context.chat_data['unknown_bot_in_chat']._bot + assert new_none_bot is None + + h1 = MessageHandler(None, first) + h2 = MessageHandler(None, second) + dp.add_handler(h1) + + assert not len(recwarn) + dp.process_update(update) + assert len(recwarn) == 1 + assert recwarn[-1].category is PTBUserWarning + assert str(recwarn[-1].message).startswith("Unknown bot instance found.") + + pickle_persistence_2 = PicklePersistence( # initialize a new persistence for unpickling + filepath='pickletest', + single_file=False, + on_flush=False, + ) + u = UpdaterBuilder().bot(bot).persistence(pickle_persistence_2).build() + dp = u.dispatcher + dp.add_handler(h2) + dp.process_update(update) + def test_with_handler(self, bot, update, bot_data, pickle_persistence, good_pickle_files): u = UpdaterBuilder().bot(bot).persistence(pickle_persistence).build() dp = u.dispatcher bot.callback_data_cache.clear_callback_data() bot.callback_data_cache.clear_callback_queries() - def first(update, context): + def first(_, context): if not context.user_data == {}: pytest.fail() if not context.chat_data == {}: @@ -1653,7 +1482,7 @@ class TestPicklePersistence: context.bot_data['test1'] = 'test0' context.bot.callback_data_cache._callback_queries['test1'] = 'test0' - def second(update, context): + def second(_, context): if not context.user_data['test1'] == 'test2': pytest.fail() if not context.chat_data['test3'] == 'test4': @@ -1813,7 +1642,7 @@ class TestPicklePersistence: assert ch.conversations[ch._get_key(update)] == 0 assert ch.conversations == pickle_persistence.conversations['name2'] - def test_with_nested_conversationHandler( + def test_with_nested_conversation_handler( self, dp, update, good_pickle_files, pickle_persistence ): dp.persistence = pickle_persistence diff --git a/tests/test_slots.py b/tests/test_slots.py index f1168c34c..8ab532f2c 100644 --- a/tests/test_slots.py +++ b/tests/test_slots.py @@ -25,7 +25,6 @@ import inspect included = { # These modules/classes intentionally have __dict__. 'CallbackContext', - 'BasePersistence', } diff --git a/tests/test_telegramobject.py b/tests/test_telegramobject.py index 0d0df8acb..06319d944 100644 --- a/tests/test_telegramobject.py +++ b/tests/test_telegramobject.py @@ -16,20 +16,28 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. - +import datetime import json as json_lib +import pickle import pytest +from copy import deepcopy try: import ujson except ImportError: ujson = None -from telegram import TelegramObject, Message, Chat, User +from telegram import TelegramObject, Message, Chat, User, PhotoSize class TestTelegramObject: + class Sub(TelegramObject): + def __init__(self, private, normal, b): + self._private = private + self.normal = normal + self._bot = b + def test_to_json_native(self, monkeypatch): if ujson: monkeypatch.setattr('ujson.dumps', json_lib.dumps) @@ -145,3 +153,47 @@ class TestTelegramObject: assert message['from_user'] is user with pytest.raises(KeyError, match="Message don't have an attribute called `no_key`"): message['no_key'] + + def test_pickle(self, bot): + chat = Chat(2, Chat.PRIVATE) + user = User(3, 'first_name', False) + date = datetime.datetime.now() + photo = PhotoSize('file_id', 'unique', 21, 21, bot=bot) + msg = Message(1, date, chat, from_user=user, text='foobar', bot=bot, photo=[photo]) + + # Test pickling of TGObjects, we choose Message since it's contains the most subclasses. + assert msg.get_bot() + unpickled = pickle.loads(pickle.dumps(msg)) + + with pytest.raises(RuntimeError): + unpickled.get_bot() # There should be no bot when we pickle TGObjects + + assert unpickled.chat == chat + assert unpickled.from_user == user + assert unpickled.date == date + assert unpickled.photo[0] == photo + + def test_deepcopy_telegram_obj(self, bot): + chat = Chat(2, Chat.PRIVATE) + user = User(3, 'first_name', False) + date = datetime.datetime.now() + photo = PhotoSize('file_id', 'unique', 21, 21, bot=bot) + msg = Message(1, date, chat, from_user=user, text='foobar', bot=bot, photo=[photo]) + + new_msg = deepcopy(msg) + + # The same bot should be present when deepcopying. + assert new_msg.get_bot() == bot and new_msg.get_bot() is bot + + assert new_msg.date == date and new_msg.date is not date + assert new_msg.chat == chat and new_msg.chat is not chat + assert new_msg.from_user == user and new_msg.from_user is not user + assert new_msg.photo[0] == photo and new_msg.photo[0] is not photo + + def test_deepcopy_subclass_telegram_obj(self, bot): + s = self.Sub("private", 'normal', bot) + d = deepcopy(s) + assert d is not s + assert d._private == s._private # Can't test for identity since two equal strings is True + assert d._bot == s._bot and d._bot is s._bot + assert d.normal == s.normal