mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-28 07:20:17 +01:00
Persistence of Bots
: Refactor Automatic Replacement and Integration with TelegramObject
(#2893)
This commit is contained in:
parent
7b37f9a6fa
commit
a743726b08
12 changed files with 464 additions and 690 deletions
|
@ -22,6 +22,7 @@
|
|||
|
||||
import functools
|
||||
import logging
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
|
||||
from typing import (
|
||||
|
@ -37,6 +38,7 @@ from typing import (
|
|||
cast,
|
||||
Sequence,
|
||||
Any,
|
||||
NoReturn,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -123,10 +125,13 @@ class Bot(TelegramObject):
|
|||
considered equal, if their :attr:`bot` is equal.
|
||||
|
||||
Note:
|
||||
Most bot methods have the argument ``api_kwargs`` which allows to pass arbitrary keywords
|
||||
to the Telegram API. This can be used to access new features of the API before they were
|
||||
incorporated into PTB. However, this is not guaranteed to work, i.e. it will fail for
|
||||
passing files.
|
||||
* Most bot methods have the argument ``api_kwargs`` which allows passing arbitrary keywords
|
||||
to the Telegram API. This can be used to access new features of the API before they are
|
||||
incorporated into PTB. However, this is not guaranteed to work, i.e. it will fail for
|
||||
passing files.
|
||||
* Bots should not be serialized since if you for e.g. change the bots token, then your
|
||||
serialized instance will not reflect that change. Trying to pickle a bot instance will
|
||||
raise :exc:`pickle.PicklingError`.
|
||||
|
||||
.. versionchanged:: 14.0
|
||||
|
||||
|
@ -136,6 +141,7 @@ class Bot(TelegramObject):
|
|||
* Removed the deprecated ``defaults`` parameter. If you want to use
|
||||
:class:`telegram.ext.Defaults`, please use the subclass :class:`telegram.ext.ExtBot`
|
||||
instead.
|
||||
* Attempting to pickle a bot instance will now raise :exc:`pickle.PicklingError`.
|
||||
|
||||
Args:
|
||||
token (:obj:`str`): Bot's unique authentication.
|
||||
|
@ -157,7 +163,7 @@ class Bot(TelegramObject):
|
|||
'private_key',
|
||||
'_bot_user',
|
||||
'_request',
|
||||
'logger',
|
||||
'_logger',
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -176,7 +182,7 @@ class Bot(TelegramObject):
|
|||
self._bot_user: Optional[User] = None
|
||||
self._request = request or Request()
|
||||
self.private_key = None
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
if private_key:
|
||||
if not CRYPTO_INSTALLED:
|
||||
|
@ -188,6 +194,10 @@ class Bot(TelegramObject):
|
|||
private_key, password=private_key_password, backend=default_backend()
|
||||
)
|
||||
|
||||
def __reduce__(self) -> NoReturn:
|
||||
"""Called by pickle.dumps(). Serializing bots is unadvisable, so we forbid pickling."""
|
||||
raise pickle.PicklingError('Bot objects cannot be pickled!')
|
||||
|
||||
# TODO: After https://youtrack.jetbrains.com/issue/PY-50952 is fixed, we can revisit this and
|
||||
# consider adding Paramspec from typing_extensions to properly fix this. Currently a workaround
|
||||
def _log(func: Any): # type: ignore[no-untyped-def] # skipcq: PY-D0003
|
||||
|
@ -2999,9 +3009,9 @@ class Bot(TelegramObject):
|
|||
)
|
||||
|
||||
if result:
|
||||
self.logger.debug('Getting updates: %s', [u['update_id'] for u in result])
|
||||
self._logger.debug('Getting updates: %s', [u['update_id'] for u in result])
|
||||
else:
|
||||
self.logger.debug('No new updates found.')
|
||||
self._logger.debug('No new updates found.')
|
||||
|
||||
return Update.de_list(result, self) # type: ignore[return-value]
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ class _BaseMedium(TelegramObject):
|
|||
|
||||
"""
|
||||
|
||||
__slots__ = ('bot', 'file_id', 'file_size', 'file_unique_id')
|
||||
__slots__ = ('file_id', 'file_size', 'file_unique_id')
|
||||
|
||||
def __init__(
|
||||
self, file_id: str, file_unique_id: str, file_size: int = None, bot: 'Bot' = None
|
||||
|
|
|
@ -210,14 +210,14 @@ class Credentials(TelegramObject):
|
|||
nonce (:obj:`str`): Bot-specified nonce
|
||||
"""
|
||||
|
||||
__slots__ = ('bot', 'nonce', 'secure_data')
|
||||
__slots__ = ('nonce', 'secure_data')
|
||||
|
||||
def __init__(self, secure_data: 'SecureData', nonce: str, bot: 'Bot' = None, **_kwargs: Any):
|
||||
# Required
|
||||
self.secure_data = secure_data
|
||||
self.nonce = nonce
|
||||
|
||||
self.bot = bot
|
||||
self.set_bot(bot)
|
||||
|
||||
@classmethod
|
||||
def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['Credentials']:
|
||||
|
@ -261,7 +261,6 @@ class SecureData(TelegramObject):
|
|||
"""
|
||||
|
||||
__slots__ = (
|
||||
'bot',
|
||||
'utility_bill',
|
||||
'personal_details',
|
||||
'temporary_registration',
|
||||
|
@ -304,7 +303,7 @@ class SecureData(TelegramObject):
|
|||
self.passport = passport
|
||||
self.personal_details = personal_details
|
||||
|
||||
self.bot = bot
|
||||
self.set_bot(bot)
|
||||
|
||||
@classmethod
|
||||
def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['SecureData']:
|
||||
|
@ -360,7 +359,7 @@ class SecureValue(TelegramObject):
|
|||
|
||||
"""
|
||||
|
||||
__slots__ = ('data', 'front_side', 'reverse_side', 'selfie', 'files', 'translation', 'bot')
|
||||
__slots__ = ('data', 'front_side', 'reverse_side', 'selfie', 'files', 'translation')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -380,7 +379,7 @@ class SecureValue(TelegramObject):
|
|||
self.files = files
|
||||
self.translation = translation
|
||||
|
||||
self.bot = bot
|
||||
self.set_bot(bot)
|
||||
|
||||
@classmethod
|
||||
def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['SecureValue']:
|
||||
|
@ -412,17 +411,17 @@ class SecureValue(TelegramObject):
|
|||
class _CredentialsBase(TelegramObject):
|
||||
"""Base class for DataCredentials and FileCredentials."""
|
||||
|
||||
__slots__ = ('hash', 'secret', 'file_hash', 'data_hash', 'bot')
|
||||
__slots__ = ('hash', 'secret', 'file_hash', 'data_hash')
|
||||
|
||||
def __init__(self, hash: str, secret: str, bot: 'Bot' = None, **_kwargs: Any):
|
||||
self.hash = hash
|
||||
self.secret = secret
|
||||
|
||||
# Aliases just be be sure
|
||||
# Aliases just to be sure
|
||||
self.file_hash = self.hash
|
||||
self.data_hash = self.hash
|
||||
|
||||
self.bot = bot
|
||||
self.set_bot(bot)
|
||||
|
||||
|
||||
class DataCredentials(_CredentialsBase):
|
||||
|
|
|
@ -17,12 +17,14 @@
|
|||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
"""Base class for Telegram Objects."""
|
||||
from copy import deepcopy
|
||||
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
import json # type: ignore[no-redef]
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Tuple, Dict, Union
|
||||
|
||||
from telegram._utils.types import JSONDict
|
||||
from telegram._utils.warnings import warn
|
||||
|
@ -40,6 +42,12 @@ class TelegramObject:
|
|||
is equivalent to ``telegram_object.attribute_name``. If the object does not have an attribute
|
||||
with the appropriate name, a :exc:`KeyError` will be raised.
|
||||
|
||||
When objects of this type are pickled, the :class:`~telegram.Bot` attribute associated with the
|
||||
object will be removed. However, when copying the object via :func:`copy.deepcopy`, the copy
|
||||
will have the *same* bot instance associated with it, i.e::
|
||||
|
||||
assert telegram_object.get_bot() is copy.deepcopy(telegram_object).get_bot()
|
||||
|
||||
.. versionchanged:: 14.0
|
||||
``telegram_object['from']`` will look up the key ``from_user``. This is to account for
|
||||
special cases like :attr:`Message.from_user` that deviate from the official Bot API.
|
||||
|
@ -53,15 +61,12 @@ class TelegramObject:
|
|||
_bot: Optional['Bot']
|
||||
# Adding slots reduces memory usage & allows for faster attribute access.
|
||||
# Only instance variables should be added to __slots__.
|
||||
__slots__ = (
|
||||
'_id_attrs',
|
||||
'_bot',
|
||||
)
|
||||
__slots__ = ('_id_attrs', '_bot')
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def __new__(cls, *args: object, **kwargs: object) -> 'TelegramObject':
|
||||
# We add _id_attrs in __new__ instead of __init__ since we want to add this to the slots
|
||||
# w/o calling __init__ in all of the subclasses. This is what we also do in BaseFilter.
|
||||
# w/o calling __init__ in all of the subclasses.
|
||||
instance = super().__new__(cls)
|
||||
instance._id_attrs = ()
|
||||
instance._bot = None
|
||||
|
@ -81,6 +86,86 @@ class TelegramObject:
|
|||
f"`{item}`."
|
||||
) from exc
|
||||
|
||||
def __getstate__(self) -> Dict[str, Union[str, object]]:
|
||||
"""
|
||||
This method is used for pickling. We remove the bot attribute of the object since those
|
||||
are not pickable.
|
||||
"""
|
||||
return self._get_attrs(include_private=True, recursive=False, remove_bot=True)
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
"""
|
||||
This method is used for unpickling. The data, which is in the form a dictionary, is
|
||||
converted back into a class. Should be modified in place.
|
||||
"""
|
||||
for key, val in state.items():
|
||||
setattr(self, key, val)
|
||||
|
||||
def __deepcopy__(self: TO, memodict: dict) -> TO:
|
||||
"""This method deepcopies the object and sets the bot on the newly created copy."""
|
||||
bot = self._bot # Save bot so we can set it after copying
|
||||
self.set_bot(None) # set to None so it is not deepcopied
|
||||
cls = self.__class__
|
||||
result = cls.__new__(cls) # create a new instance
|
||||
memodict[id(self)] = result # save the id of the object in the dict
|
||||
|
||||
attrs = self._get_attrs(include_private=True) # get all its attributes
|
||||
|
||||
for k in attrs: # now we set the attributes in the deepcopied object
|
||||
setattr(result, k, deepcopy(getattr(self, k), memodict))
|
||||
|
||||
result.set_bot(bot) # Assign the bots back
|
||||
self.set_bot(bot)
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
def _get_attrs(
|
||||
self,
|
||||
include_private: bool = False,
|
||||
recursive: bool = False,
|
||||
remove_bot: bool = False,
|
||||
) -> Dict[str, Union[str, object]]:
|
||||
"""This method is used for obtaining the attributes of the object.
|
||||
|
||||
Args:
|
||||
include_private (:obj:`bool`): Whether the result should include private variables.
|
||||
recursive (:obj:`bool`): If :obj:`True`, will convert any TelegramObjects (if found) in
|
||||
the attributes to a dictionary. Else, preserves it as an object itself.
|
||||
remove_bot (:obj:`bool`): Whether the bot should be included in the result.
|
||||
|
||||
Returns:
|
||||
:obj:`dict`: A dict where the keys are attribute names and values are their values.
|
||||
"""
|
||||
data = {}
|
||||
|
||||
if not recursive:
|
||||
try:
|
||||
# __dict__ has attrs from superclasses, so no need to put in the for loop below
|
||||
data.update(self.__dict__)
|
||||
except AttributeError:
|
||||
pass
|
||||
# We want to get all attributes for the class, using self.__slots__ only includes the
|
||||
# attributes used by that class itself, and not its superclass(es). Hence, we get its MRO
|
||||
# and then get their attributes. The `[:-1]` slice excludes the `object` class
|
||||
for cls in self.__class__.__mro__[:-1]:
|
||||
for key in cls.__slots__:
|
||||
if not include_private and key.startswith('_'):
|
||||
continue
|
||||
|
||||
value = getattr(self, key, None)
|
||||
if value is not None:
|
||||
if recursive and hasattr(value, 'to_dict'):
|
||||
data[key] = value.to_dict()
|
||||
else:
|
||||
data[key] = value
|
||||
elif not recursive:
|
||||
data[key] = value
|
||||
|
||||
if recursive and data.get('from_user'):
|
||||
data['from'] = data.pop('from_user', None)
|
||||
if remove_bot:
|
||||
data.pop('_bot', None)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _parse_data(data: Optional[JSONDict]) -> Optional[JSONDict]:
|
||||
return None if data is None else data.copy()
|
||||
|
@ -137,27 +222,7 @@ class TelegramObject:
|
|||
Returns:
|
||||
:obj:`dict`
|
||||
"""
|
||||
data = {}
|
||||
|
||||
# We want to get all attributes for the class, using self.__slots__ only includes the
|
||||
# attributes used by that class itself, and not its superclass(es). Hence we get its MRO
|
||||
# and then get their attributes. The `[:-2]` slice excludes the `object` class & the
|
||||
# TelegramObject class itself.
|
||||
attrs = {attr for cls in self.__class__.__mro__[:-2] for attr in cls.__slots__}
|
||||
for key in attrs:
|
||||
if key == 'bot' or key.startswith('_'):
|
||||
continue
|
||||
|
||||
value = getattr(self, key, None)
|
||||
if value is not None:
|
||||
if hasattr(value, 'to_dict'):
|
||||
data[key] = value.to_dict()
|
||||
else:
|
||||
data[key] = value
|
||||
|
||||
if data.get('from_user'):
|
||||
data['from'] = data.pop('from_user', None)
|
||||
return data
|
||||
return self._get_attrs(recursive=True)
|
||||
|
||||
def get_bot(self) -> 'Bot':
|
||||
"""Returns the :class:`telegram.Bot` instance associated with this object.
|
||||
|
@ -171,8 +236,7 @@ class TelegramObject:
|
|||
"""
|
||||
if self._bot is None:
|
||||
raise RuntimeError(
|
||||
'This object has no bot associated with it. \
|
||||
Shortcuts cannot be used.'
|
||||
'This object has no bot associated with it. Shortcuts cannot be used.'
|
||||
)
|
||||
return self._bot
|
||||
|
||||
|
|
|
@ -18,14 +18,11 @@
|
|||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
"""This module contains the BasePersistence class."""
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import copy
|
||||
from typing import Dict, Optional, Tuple, cast, ClassVar, Generic, NamedTuple
|
||||
from typing import Dict, Optional, Tuple, Generic, NamedTuple
|
||||
|
||||
from telegram import Bot
|
||||
from telegram.ext import ExtBot
|
||||
|
||||
from telegram.warnings import PTBRuntimeWarning
|
||||
from telegram._utils.warnings import warn
|
||||
from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData
|
||||
|
||||
|
||||
|
@ -73,33 +70,39 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
|
|||
* :meth:`get_chat_data`
|
||||
* :meth:`update_chat_data`
|
||||
* :meth:`refresh_chat_data`
|
||||
* :meth:`drop_chat_data`
|
||||
* :meth:`get_user_data`
|
||||
* :meth:`update_user_data`
|
||||
* :meth:`refresh_user_data`
|
||||
* :meth:`drop_user_data`
|
||||
* :meth:`get_callback_data`
|
||||
* :meth:`update_callback_data`
|
||||
* :meth:`get_conversations`
|
||||
* :meth:`update_conversation`
|
||||
* :meth:`flush`
|
||||
|
||||
If you don't actually need one of those methods, a simple ``pass`` is enough. For example, if
|
||||
you don't store ``bot_data``, you don't need :meth:`get_bot_data`, :meth:`update_bot_data` or
|
||||
:meth:`refresh_bot_data`.
|
||||
|
||||
Warning:
|
||||
Persistence will try to replace :class:`telegram.Bot` instances by :attr:`REPLACED_BOT` and
|
||||
insert the bot set with :meth:`set_bot` upon loading of the data. This is to ensure that
|
||||
changes to the bot apply to the saved objects, too. If you change the bots token, this may
|
||||
lead to e.g. ``Chat not found`` errors. For the limitations on replacing bots see
|
||||
:meth:`replace_bot` and :meth:`insert_bot`.
|
||||
If you don't actually need one of those methods, a simple :keyword:`pass` is enough.
|
||||
For example, if you don't store ``bot_data``, you don't need :meth:`get_bot_data`,
|
||||
:meth:`update_bot_data` or :meth:`refresh_bot_data`.
|
||||
|
||||
Note:
|
||||
:meth:`replace_bot` and :meth:`insert_bot` are used *independently* of the implementation
|
||||
of the ``update/get_*`` methods, i.e. you don't need to worry about it while
|
||||
implementing a custom persistence subclass.
|
||||
You should avoid saving :class:`telegram.Bot` instances. This is because if you change e.g.
|
||||
the bots token, this won't propagate to the serialized instances and may lead to exceptions.
|
||||
|
||||
To prevent this, the implementation may use :attr:`bot` to replace bot instances with a
|
||||
placeholder before serialization and insert :attr:`bot` back when loading the data.
|
||||
Since :attr:`bot` will be set when the process starts, this will be the up-to-date bot
|
||||
instance.
|
||||
|
||||
If the persistence implementation does not take care of this, you should make sure not to
|
||||
store any bot instances in the data that will be persisted. E.g. in case of
|
||||
:class:`telegram.TelegramObject`, one may call :meth:`set_bot` to ensure that shortcuts like
|
||||
:meth:`telegram.Message.reply_text` are available.
|
||||
|
||||
.. versionchanged:: 14.0
|
||||
The parameters and attributes ``store_*_data`` were replaced by :attr:`store_data`.
|
||||
* The parameters and attributes ``store_*_data`` were replaced by :attr:`store_data`.
|
||||
* ``insert/replace_bot`` was dropped. Serialization of bot instances now needs to be
|
||||
handled by the specific implementation - see above note.
|
||||
|
||||
Args:
|
||||
store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be
|
||||
|
@ -109,72 +112,10 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
|
|||
Attributes:
|
||||
store_data (:class:`PersistenceInput`): Specifies which kinds of data will be saved by this
|
||||
persistence instance.
|
||||
bot (:class:`telegram.Bot`): The bot associated with the persistence.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'bot',
|
||||
'store_data',
|
||||
'__dict__', # __dict__ is included because we replace methods in the __new__
|
||||
)
|
||||
|
||||
def __new__(
|
||||
cls, *args: object, **kwargs: object # pylint: disable=unused-argument
|
||||
) -> 'BasePersistence':
|
||||
"""This overrides the get_* and update_* methods to use insert/replace_bot.
|
||||
That has the side effect that we always pass deepcopied data to those methods, so in
|
||||
Pickle/DictPersistence we don't have to worry about copying the data again.
|
||||
|
||||
Note: This doesn't hold for second tuple-entry of callback_data. That's a Dict[str, str],
|
||||
so no bots to replace anyway.
|
||||
"""
|
||||
instance = super().__new__(cls)
|
||||
get_user_data = instance.get_user_data
|
||||
get_chat_data = instance.get_chat_data
|
||||
get_bot_data = instance.get_bot_data
|
||||
get_callback_data = instance.get_callback_data
|
||||
update_user_data = instance.update_user_data
|
||||
update_chat_data = instance.update_chat_data
|
||||
update_bot_data = instance.update_bot_data
|
||||
update_callback_data = instance.update_callback_data
|
||||
|
||||
def get_user_data_insert_bot() -> Dict[int, UD]:
|
||||
return instance.insert_bot(get_user_data())
|
||||
|
||||
def get_chat_data_insert_bot() -> Dict[int, CD]:
|
||||
return instance.insert_bot(get_chat_data())
|
||||
|
||||
def get_bot_data_insert_bot() -> BD:
|
||||
return instance.insert_bot(get_bot_data())
|
||||
|
||||
def get_callback_data_insert_bot() -> Optional[CDCData]:
|
||||
cdc_data = get_callback_data()
|
||||
if cdc_data is None:
|
||||
return None
|
||||
return instance.insert_bot(cdc_data[0]), cdc_data[1]
|
||||
|
||||
def update_user_data_replace_bot(user_id: int, data: UD) -> None:
|
||||
return update_user_data(user_id, instance.replace_bot(data))
|
||||
|
||||
def update_chat_data_replace_bot(chat_id: int, data: CD) -> None:
|
||||
return update_chat_data(chat_id, instance.replace_bot(data))
|
||||
|
||||
def update_bot_data_replace_bot(data: BD) -> None:
|
||||
return update_bot_data(instance.replace_bot(data))
|
||||
|
||||
def update_callback_data_replace_bot(data: CDCData) -> None:
|
||||
obj_data, queue = data
|
||||
return update_callback_data((instance.replace_bot(obj_data), queue))
|
||||
|
||||
# Adds to __dict__
|
||||
setattr(instance, 'get_user_data', get_user_data_insert_bot)
|
||||
setattr(instance, 'get_chat_data', get_chat_data_insert_bot)
|
||||
setattr(instance, 'get_bot_data', get_bot_data_insert_bot)
|
||||
setattr(instance, 'get_callback_data', get_callback_data_insert_bot)
|
||||
setattr(instance, 'update_user_data', update_user_data_replace_bot)
|
||||
setattr(instance, 'update_chat_data', update_chat_data_replace_bot)
|
||||
setattr(instance, 'update_bot_data', update_bot_data_replace_bot)
|
||||
setattr(instance, 'update_callback_data', update_callback_data_replace_bot)
|
||||
return instance
|
||||
__slots__ = ('bot', 'store_data')
|
||||
|
||||
def __init__(self, store_data: PersistenceInput = None):
|
||||
self.store_data = store_data or PersistenceInput()
|
||||
|
@ -192,215 +133,6 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
|
|||
|
||||
self.bot = bot
|
||||
|
||||
@classmethod
|
||||
def replace_bot(cls, obj: object) -> object:
|
||||
"""
|
||||
Replaces all instances of :class:`telegram.Bot` that occur within the passed object with
|
||||
:attr:`REPLACED_BOT`. Currently, this handles objects of type :class:`list`,
|
||||
:class:`tuple`, :class:`set`, :class:`frozenset`, :class:`dict`,
|
||||
:class:`collections.defaultdict` and objects that have a :attr:`~object.__dict__` or
|
||||
:data:`~object.__slots__` attribute, excluding classes and objects that can't be copied
|
||||
with :func:`copy.copy`. If the parsing of an object fails, the object will be returned
|
||||
unchanged and the error will be logged.
|
||||
|
||||
Args:
|
||||
obj (:obj:`object`): The object
|
||||
|
||||
Returns:
|
||||
:class:`object`: Copy of the object with Bot instances replaced.
|
||||
"""
|
||||
return cls._replace_bot(obj, {})
|
||||
|
||||
@classmethod
|
||||
def _replace_bot( # pylint: disable=too-many-return-statements
|
||||
cls, obj: object, memo: Dict[int, object]
|
||||
) -> object:
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
if isinstance(obj, Bot):
|
||||
memo[obj_id] = cls.REPLACED_BOT
|
||||
return cls.REPLACED_BOT
|
||||
if isinstance(obj, (list, set)):
|
||||
# We copy the iterable here for thread safety, i.e. make sure the object we iterate
|
||||
# over doesn't change its length during the iteration
|
||||
temp_iterable = obj.copy()
|
||||
new_iterable = obj.__class__(cls._replace_bot(item, memo) for item in temp_iterable)
|
||||
memo[obj_id] = new_iterable
|
||||
return new_iterable
|
||||
if isinstance(obj, (tuple, frozenset)):
|
||||
# tuples and frozensets are immutable so we don't need to worry about thread safety
|
||||
new_immutable = obj.__class__(cls._replace_bot(item, memo) for item in obj)
|
||||
memo[obj_id] = new_immutable
|
||||
return new_immutable
|
||||
if isinstance(obj, type):
|
||||
# classes usually do have a __dict__, but it's not writable
|
||||
warn(
|
||||
f'BasePersistence.replace_bot does not handle classes such as {obj.__name__!r}. '
|
||||
'See the docs of BasePersistence.replace_bot for more information.',
|
||||
PTBRuntimeWarning,
|
||||
)
|
||||
return obj
|
||||
|
||||
try:
|
||||
new_obj = copy(obj)
|
||||
memo[obj_id] = new_obj
|
||||
except Exception:
|
||||
warn(
|
||||
'BasePersistence.replace_bot does not handle objects that can not be copied. See '
|
||||
'the docs of BasePersistence.replace_bot for more information.',
|
||||
PTBRuntimeWarning,
|
||||
)
|
||||
memo[obj_id] = obj
|
||||
return obj
|
||||
|
||||
if isinstance(obj, dict):
|
||||
# We handle dicts via copy(obj) so we don't have to make a
|
||||
# difference between dict and defaultdict
|
||||
new_obj = cast(dict, new_obj)
|
||||
# We can't iterate over obj.items() due to thread safety, i.e. the dicts length may
|
||||
# change during the iteration
|
||||
temp_dict = new_obj.copy()
|
||||
new_obj.clear()
|
||||
for k, val in temp_dict.items():
|
||||
new_obj[cls._replace_bot(k, memo)] = cls._replace_bot(val, memo)
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
try:
|
||||
if hasattr(obj, '__slots__'):
|
||||
for attr_name in new_obj.__slots__:
|
||||
setattr(
|
||||
new_obj,
|
||||
attr_name,
|
||||
cls._replace_bot(
|
||||
cls._replace_bot(getattr(new_obj, attr_name), memo), memo
|
||||
),
|
||||
)
|
||||
if '__dict__' in obj.__slots__:
|
||||
# In this case, we have already covered the case that obj has __dict__
|
||||
# Note that obj may have a __dict__ even if it's not in __slots__!
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
if hasattr(obj, '__dict__'):
|
||||
for attr_name, attr in new_obj.__dict__.items():
|
||||
setattr(new_obj, attr_name, cls._replace_bot(attr, memo))
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
except Exception as exception:
|
||||
warn(
|
||||
f'Parsing of an object failed with the following exception: {exception}. '
|
||||
f'See the docs of BasePersistence.replace_bot for more information.',
|
||||
PTBRuntimeWarning,
|
||||
)
|
||||
|
||||
memo[obj_id] = obj
|
||||
return obj
|
||||
|
||||
def insert_bot(self, obj: object) -> object:
|
||||
"""
|
||||
Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with
|
||||
:paramref:`bot`. Currently, this handles objects of type :class:`list`,
|
||||
:class:`tuple`, :class:`set`, :class:`frozenset`, :class:`dict`,
|
||||
:class:`collections.defaultdict` and objects that have a :attr:`~object.__dict__` or
|
||||
:data:`~object.__slots__` attribute, excluding classes and objects that can't be copied
|
||||
with :func:`copy.copy`. If the parsing of an object fails, the object will be returned
|
||||
unchanged and the error will be logged.
|
||||
|
||||
Args:
|
||||
obj (:obj:`object`): The object
|
||||
|
||||
Returns:
|
||||
:class:`object`: Copy of the object with Bot instances inserted.
|
||||
"""
|
||||
return self._insert_bot(obj, {})
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def _insert_bot(self, obj: object, memo: Dict[int, object]) -> object:
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
if isinstance(obj, Bot):
|
||||
memo[obj_id] = self.bot
|
||||
return self.bot
|
||||
if isinstance(obj, str) and obj == self.REPLACED_BOT:
|
||||
memo[obj_id] = self.bot
|
||||
return self.bot
|
||||
if isinstance(obj, (list, set)):
|
||||
# We copy the iterable here for thread safety, i.e. make sure the object we iterate
|
||||
# over doesn't change its length during the iteration
|
||||
temp_iterable = obj.copy()
|
||||
new_iterable = obj.__class__(self._insert_bot(item, memo) for item in temp_iterable)
|
||||
memo[obj_id] = new_iterable
|
||||
return new_iterable
|
||||
if isinstance(obj, (tuple, frozenset)):
|
||||
# tuples and frozensets are immutable so we don't need to worry about thread safety
|
||||
new_immutable = obj.__class__(self._insert_bot(item, memo) for item in obj)
|
||||
memo[obj_id] = new_immutable
|
||||
return new_immutable
|
||||
if isinstance(obj, type):
|
||||
# classes usually do have a __dict__, but it's not writable
|
||||
warn(
|
||||
f'BasePersistence.insert_bot does not handle classes such as {obj.__name__!r}. '
|
||||
'See the docs of BasePersistence.insert_bot for more information.',
|
||||
PTBRuntimeWarning,
|
||||
)
|
||||
return obj
|
||||
|
||||
try:
|
||||
new_obj = copy(obj)
|
||||
except Exception:
|
||||
warn(
|
||||
'BasePersistence.insert_bot does not handle objects that can not be copied. See '
|
||||
'the docs of BasePersistence.insert_bot for more information.',
|
||||
PTBRuntimeWarning,
|
||||
)
|
||||
memo[obj_id] = obj
|
||||
return obj
|
||||
|
||||
if isinstance(obj, dict):
|
||||
# We handle dicts via copy(obj) so we don't have to make a
|
||||
# difference between dict and defaultdict
|
||||
new_obj = cast(dict, new_obj)
|
||||
# We can't iterate over obj.items() due to thread safety, i.e. the dicts length may
|
||||
# change during the iteration
|
||||
temp_dict = new_obj.copy()
|
||||
new_obj.clear()
|
||||
for k, val in temp_dict.items():
|
||||
new_obj[self._insert_bot(k, memo)] = self._insert_bot(val, memo)
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
try:
|
||||
if hasattr(obj, '__slots__'):
|
||||
for attr_name in obj.__slots__:
|
||||
setattr(
|
||||
new_obj,
|
||||
attr_name,
|
||||
self._insert_bot(
|
||||
self._insert_bot(getattr(new_obj, attr_name), memo), memo
|
||||
),
|
||||
)
|
||||
if '__dict__' in obj.__slots__:
|
||||
# In this case, we have already covered the case that obj has __dict__
|
||||
# Note that obj may have a __dict__ even if it's not in __slots__!
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
if hasattr(obj, '__dict__'):
|
||||
for attr_name, attr in new_obj.__dict__.items():
|
||||
setattr(new_obj, attr_name, self._insert_bot(attr, memo))
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
except Exception as exception:
|
||||
warn(
|
||||
f'Parsing of an object failed with the following exception: {exception}. '
|
||||
f'See the docs of BasePersistence.insert_bot for more information.',
|
||||
PTBRuntimeWarning,
|
||||
)
|
||||
|
||||
memo[obj_id] = obj
|
||||
return obj
|
||||
|
||||
@abstractmethod
|
||||
def get_user_data(self) -> Dict[int, UD]:
|
||||
"""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
|
||||
|
@ -628,6 +360,3 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
|
|||
.. versionchanged:: 14.0
|
||||
Changed this method into an ``@abstractmethod``.
|
||||
"""
|
||||
|
||||
REPLACED_BOT: ClassVar[str] = 'bot_instance_replaced_by_ptb_persistence'
|
||||
""":obj:`str`: Placeholder for :class:`telegram.Bot` instances replaced in saved data."""
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
|
||||
from typing import Dict, Optional, Tuple, cast
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from telegram.ext import BasePersistence, PersistenceInput
|
||||
from telegram._utils.types import JSONDict
|
||||
from telegram.ext._utils.types import ConversationDict, CDCData
|
||||
|
@ -31,7 +33,7 @@ except ImportError:
|
|||
|
||||
|
||||
class DictPersistence(BasePersistence):
|
||||
"""Using Python's :obj:`dict` and ``json`` for making your bot persistent.
|
||||
"""Using Python's :obj:`dict` and :mod:`json` for making your bot persistent.
|
||||
|
||||
Attention:
|
||||
The interface provided by this class is intended to be accessed exclusively by
|
||||
|
@ -39,19 +41,13 @@ class DictPersistence(BasePersistence):
|
|||
interfere with the integration of persistence into :class:`~telegram.ext.Dispatcher`.
|
||||
|
||||
Note:
|
||||
This class does *not* implement a :meth:`flush` method, meaning that data managed by
|
||||
``DictPersistence`` is in-memory only and will be lost when the bot shuts down. This is,
|
||||
because ``DictPersistence`` is mainly intended as starting point for custom persistence
|
||||
classes that need to JSON-serialize the stored data before writing them to file/database.
|
||||
* Data managed by :class:`DictPersistence` is in-memory only and will be lost when the bot
|
||||
shuts down. This is, because :class:`DictPersistence` is mainly intended as starting
|
||||
point for custom persistence classes that need to JSON-serialize the stored data before
|
||||
writing them to file/database.
|
||||
|
||||
Warning:
|
||||
:class:`DictPersistence` will try to replace :class:`telegram.Bot` instances by
|
||||
:attr:`~telegram.ext.BasePersistence.REPLACED_BOT` and insert the bot set with
|
||||
:meth:`telegram.ext.BasePersistence.set_bot` upon loading of the data. This is to ensure
|
||||
that changes to the bot apply to the saved objects, too. If you change the bots token, this
|
||||
may lead to e.g. ``Chat not found`` errors. For the limitations on replacing bots see
|
||||
:meth:`telegram.ext.BasePersistence.replace_bot` and
|
||||
:meth:`telegram.ext.BasePersistence.insert_bot`.
|
||||
* This implementation of :class:`BasePersistence` does not handle data that cannot be
|
||||
serialized by :func:`json.dumps`.
|
||||
|
||||
.. versionchanged:: 14.0
|
||||
The parameters and attributes ``store_*_data`` were replaced by :attr:`store_data`.
|
||||
|
@ -244,7 +240,7 @@ class DictPersistence(BasePersistence):
|
|||
"""
|
||||
if self.user_data is None:
|
||||
self._user_data = {}
|
||||
return self.user_data # type: ignore[return-value]
|
||||
return deepcopy(self.user_data) # type: ignore[arg-type]
|
||||
|
||||
def get_chat_data(self) -> Dict[int, Dict[object, object]]:
|
||||
"""Returns the chat_data created from the ``chat_data_json`` or an empty :obj:`dict`.
|
||||
|
@ -254,7 +250,7 @@ class DictPersistence(BasePersistence):
|
|||
"""
|
||||
if self.chat_data is None:
|
||||
self._chat_data = {}
|
||||
return self.chat_data # type: ignore[return-value]
|
||||
return deepcopy(self.chat_data) # type: ignore[arg-type]
|
||||
|
||||
def get_bot_data(self) -> Dict[object, object]:
|
||||
"""Returns the bot_data created from the ``bot_data_json`` or an empty :obj:`dict`.
|
||||
|
@ -264,7 +260,7 @@ class DictPersistence(BasePersistence):
|
|||
"""
|
||||
if self.bot_data is None:
|
||||
self._bot_data = {}
|
||||
return self.bot_data # type: ignore[return-value]
|
||||
return deepcopy(self.bot_data) # type: ignore[arg-type]
|
||||
|
||||
def get_callback_data(self) -> Optional[CDCData]:
|
||||
"""Returns the callback_data created from the ``callback_data_json`` or :obj:`None`.
|
||||
|
@ -279,7 +275,7 @@ class DictPersistence(BasePersistence):
|
|||
if self.callback_data is None:
|
||||
self._callback_data = None
|
||||
return None
|
||||
return self.callback_data[0], self.callback_data[1].copy()
|
||||
return deepcopy((self.callback_data[0], self.callback_data[1].copy()))
|
||||
|
||||
def get_conversations(self, name: str) -> ConversationDict:
|
||||
"""Returns the conversations created from the ``conversations_json`` or an empty
|
||||
|
@ -320,7 +316,7 @@ class DictPersistence(BasePersistence):
|
|||
self._user_data = {}
|
||||
if self._user_data.get(user_id) == data:
|
||||
return
|
||||
self._user_data[user_id] = data
|
||||
self._user_data[user_id] = deepcopy(data)
|
||||
self._user_data_json = None
|
||||
|
||||
def update_chat_data(self, chat_id: int, data: Dict) -> None:
|
||||
|
@ -334,7 +330,7 @@ class DictPersistence(BasePersistence):
|
|||
self._chat_data = {}
|
||||
if self._chat_data.get(chat_id) == data:
|
||||
return
|
||||
self._chat_data[chat_id] = data
|
||||
self._chat_data[chat_id] = deepcopy(data)
|
||||
self._chat_data_json = None
|
||||
|
||||
def update_bot_data(self, data: Dict) -> None:
|
||||
|
@ -345,7 +341,7 @@ class DictPersistence(BasePersistence):
|
|||
"""
|
||||
if self._bot_data == data:
|
||||
return
|
||||
self._bot_data = data
|
||||
self._bot_data = deepcopy(data)
|
||||
self._bot_data_json = None
|
||||
|
||||
def update_callback_data(self, data: CDCData) -> None:
|
||||
|
@ -360,7 +356,7 @@ class DictPersistence(BasePersistence):
|
|||
"""
|
||||
if self._callback_data == data:
|
||||
return
|
||||
self._callback_data = (data[0], data[1].copy())
|
||||
self._callback_data = deepcopy((data[0], data[1].copy()))
|
||||
self._callback_data_json = None
|
||||
|
||||
def drop_chat_data(self, chat_id: int) -> None:
|
||||
|
|
|
@ -17,8 +17,11 @@
|
|||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
"""This module contains the PicklePersistence class."""
|
||||
import copyreg
|
||||
import pickle
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from sys import version_info as py_ver
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
|
@ -26,14 +29,105 @@ from typing import (
|
|||
Tuple,
|
||||
overload,
|
||||
cast,
|
||||
Type,
|
||||
Set,
|
||||
Callable,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from telegram import Bot, TelegramObject
|
||||
from telegram._utils.types import FilePathInput
|
||||
from telegram._utils.warnings import warn
|
||||
from telegram.ext import BasePersistence, PersistenceInput
|
||||
from telegram.ext._contexttypes import ContextTypes
|
||||
from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData
|
||||
|
||||
|
||||
_REPLACED_KNOWN_BOT = "a known bot replaced by PTB's PicklePersistence"
|
||||
_REPLACED_UNKNOWN_BOT = "an unknown bot replaced by PTB's PicklePersistence"
|
||||
|
||||
TO = TypeVar('TO', bound=TelegramObject)
|
||||
|
||||
|
||||
def _all_subclasses(cls: Type[TO]) -> Set[Type[TO]]:
|
||||
"""Gets all subclasses of the specified object, recursively. from
|
||||
https://stackoverflow.com/a/3862957/9706202
|
||||
"""
|
||||
subclasses = cls.__subclasses__()
|
||||
return set(subclasses).union([s for c in subclasses for s in _all_subclasses(c)])
|
||||
|
||||
|
||||
def _reconstruct_to(cls: Type[TO], kwargs: dict) -> TO:
|
||||
"""
|
||||
This method is used for unpickling. The data, which is in the form a dictionary, is
|
||||
converted back into a class. Works mostly the same as :meth:`TelegramObject.__setstate__`.
|
||||
This function should be kept in place for backwards compatibility even if the pickling logic
|
||||
is changed, since `_custom_reduction` places references to this function into the pickled data.
|
||||
"""
|
||||
obj = cls.__new__(cls)
|
||||
obj.__setstate__(kwargs)
|
||||
return obj # type: ignore[return-value]
|
||||
|
||||
|
||||
def _custom_reduction(cls: TO) -> Tuple[Callable, Tuple[Type[TO], dict]]:
|
||||
"""
|
||||
This method is used for pickling. The bot attribute is preserved so _BotPickler().persistent_id
|
||||
works as intended.
|
||||
"""
|
||||
data = cls._get_attrs(include_private=True) # pylint: disable=protected-access
|
||||
return _reconstruct_to, (cls.__class__, data)
|
||||
|
||||
|
||||
class _BotPickler(pickle.Pickler):
|
||||
__slots__ = ('_bot',)
|
||||
|
||||
def __init__(self, bot: Bot, *args: Any, **kwargs: Any):
|
||||
self._bot = bot
|
||||
if py_ver < (3, 8): # self.reducer_override is used above this version
|
||||
# Here we define a private dispatch_table, because we want to preserve the bot
|
||||
# attribute of objects so persistent_id works as intended. Otherwise, the bot attribute
|
||||
# is deleted in __getstate__, which is used during regular pickling (via pickle.dumps)
|
||||
self.dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined]
|
||||
for obj in _all_subclasses(TelegramObject):
|
||||
self.dispatch_table[obj] = _custom_reduction # type: ignore[index]
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reducer_override( # pylint: disable=no-self-use
|
||||
self, obj: TO
|
||||
) -> Tuple[Callable, Tuple[Type[TO], dict]]:
|
||||
if not isinstance(obj, TelegramObject):
|
||||
return NotImplemented
|
||||
|
||||
return _custom_reduction(obj)
|
||||
|
||||
def persistent_id(self, obj: object) -> Optional[str]:
|
||||
"""Used to 'mark' the Bot, so it can be replaced later. See
|
||||
https://docs.python.org/3/library/pickle.html#pickle.Pickler.persistent_id for more info
|
||||
"""
|
||||
if obj is self._bot:
|
||||
return _REPLACED_KNOWN_BOT
|
||||
if isinstance(obj, Bot):
|
||||
warn('Unknown bot instance found. Will be replaced by `None` during unpickling')
|
||||
return _REPLACED_UNKNOWN_BOT
|
||||
return None # pickles as usual
|
||||
|
||||
|
||||
class _BotUnpickler(pickle.Unpickler):
|
||||
__slots__ = ('_bot',)
|
||||
|
||||
def __init__(self, bot: Bot, *args: Any, **kwargs: Any):
|
||||
self._bot = bot
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def persistent_load(self, pid: str) -> Optional[Bot]:
|
||||
"""Replaces the bot with the current bot if known, else it is replaced by :obj:`None`."""
|
||||
if pid == _REPLACED_KNOWN_BOT:
|
||||
return self._bot
|
||||
if pid == _REPLACED_UNKNOWN_BOT:
|
||||
return None
|
||||
raise pickle.UnpicklingError("Found unknown persistent id when unpickling!")
|
||||
|
||||
|
||||
class PicklePersistence(BasePersistence[UD, CD, BD]):
|
||||
"""Using python's builtin pickle for making your bot persistent.
|
||||
|
||||
|
@ -42,14 +136,11 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
:class:`~telegram.ext.Dispatcher`. Calling any of the methods below manually might
|
||||
interfere with the integration of persistence into :class:`~telegram.ext.Dispatcher`.
|
||||
|
||||
Warning:
|
||||
:class:`PicklePersistence` will try to replace :class:`telegram.Bot` instances by
|
||||
:attr:`~telegram.ext.BasePersistence.REPLACED_BOT` and insert the bot set with
|
||||
:meth:`telegram.ext.BasePersistence.set_bot` upon loading of the data. This is to ensure
|
||||
that changes to the bot apply to the saved objects, too. If you change the bots token, this
|
||||
may lead to e.g. ``Chat not found`` errors. For the limitations on replacing bots see
|
||||
:meth:`telegram.ext.BasePersistence.replace_bot` and
|
||||
:meth:`telegram.ext.BasePersistence.insert_bot`.
|
||||
Note:
|
||||
This implementation of :class:`BasePersistence` uses the functionality of the pickle module
|
||||
to support serialization of bot instances. Specifically any reference to
|
||||
:attr:`~BasePersistence.bot` will be replaced by a placeholder before pickling and
|
||||
:attr:`~BasePersistence.bot` will be inserted back when loading the data.
|
||||
|
||||
.. versionchanged:: 14.0
|
||||
|
||||
|
@ -57,7 +148,6 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
* The parameter and attribute ``filename`` were replaced by :attr:`filepath`.
|
||||
* :attr:`filepath` now also accepts :obj:`pathlib.Path` as argument.
|
||||
|
||||
|
||||
Args:
|
||||
filepath (:obj:`str` | :obj:`pathlib.Path`): The filepath for storing the pickle files.
|
||||
When :attr:`single_file` is :obj:`False` this will be used as a prefix.
|
||||
|
@ -151,7 +241,8 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
def _load_singlefile(self) -> None:
|
||||
try:
|
||||
with self.filepath.open("rb") as file:
|
||||
data = pickle.load(file)
|
||||
data = _BotUnpickler(self.bot, file).load()
|
||||
|
||||
self.user_data = data['user_data']
|
||||
self.chat_data = data['chat_data']
|
||||
# For backwards compatibility with files not containing bot data
|
||||
|
@ -170,11 +261,11 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
except Exception as exc:
|
||||
raise TypeError(f"Something went wrong unpickling {self.filepath.name}") from exc
|
||||
|
||||
@staticmethod
|
||||
def _load_file(filepath: Path) -> Any:
|
||||
def _load_file(self, filepath: Path) -> Any:
|
||||
try:
|
||||
with filepath.open("rb") as file:
|
||||
return pickle.load(file)
|
||||
return _BotUnpickler(self.bot, file).load()
|
||||
|
||||
except OSError:
|
||||
return None
|
||||
except pickle.UnpicklingError as exc:
|
||||
|
@ -191,12 +282,11 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
'callback_data': self.callback_data,
|
||||
}
|
||||
with self.filepath.open("wb") as file:
|
||||
pickle.dump(data, file)
|
||||
_BotPickler(self.bot, file, protocol=pickle.HIGHEST_PROTOCOL).dump(data)
|
||||
|
||||
@staticmethod
|
||||
def _dump_file(filepath: Path, data: object) -> None:
|
||||
def _dump_file(self, filepath: Path, data: object) -> None:
|
||||
with filepath.open("wb") as file:
|
||||
pickle.dump(data, file)
|
||||
_BotPickler(self.bot, file, protocol=pickle.HIGHEST_PROTOCOL).dump(data)
|
||||
|
||||
def get_user_data(self) -> Dict[int, UD]:
|
||||
"""Returns the user_data from the pickle file if it exists or an empty :obj:`dict`.
|
||||
|
@ -213,7 +303,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
self.user_data = data
|
||||
else:
|
||||
self._load_singlefile()
|
||||
return self.user_data # type: ignore[return-value]
|
||||
return deepcopy(self.user_data) # type: ignore[arg-type]
|
||||
|
||||
def get_chat_data(self) -> Dict[int, CD]:
|
||||
"""Returns the chat_data from the pickle file if it exists or an empty :obj:`dict`.
|
||||
|
@ -230,7 +320,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
self.chat_data = data
|
||||
else:
|
||||
self._load_singlefile()
|
||||
return self.chat_data # type: ignore[return-value]
|
||||
return deepcopy(self.chat_data) # type: ignore[arg-type]
|
||||
|
||||
def get_bot_data(self) -> BD:
|
||||
"""Returns the bot_data from the pickle file if it exists or an empty object of type
|
||||
|
@ -248,7 +338,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
self.bot_data = data
|
||||
else:
|
||||
self._load_singlefile()
|
||||
return self.bot_data # type: ignore[return-value]
|
||||
return deepcopy(self.bot_data) # type: ignore[return-value]
|
||||
|
||||
def get_callback_data(self) -> Optional[CDCData]:
|
||||
"""Returns the callback data from the pickle file if it exists or :obj:`None`.
|
||||
|
@ -271,7 +361,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
self._load_singlefile()
|
||||
if self.callback_data is None:
|
||||
return None
|
||||
return self.callback_data[0], self.callback_data[1].copy()
|
||||
return deepcopy((self.callback_data[0], self.callback_data[1].copy()))
|
||||
|
||||
def get_conversations(self, name: str) -> ConversationDict:
|
||||
"""Returns the conversations from the pickle file if it exists or an empty dict.
|
||||
|
@ -326,7 +416,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
self.user_data = {}
|
||||
if self.user_data.get(user_id) == data:
|
||||
return
|
||||
self.user_data[user_id] = data
|
||||
self.user_data[user_id] = deepcopy(data)
|
||||
if not self.on_flush:
|
||||
if not self.single_file:
|
||||
self._dump_file(Path(f"{self.filepath}_user_data"), self.user_data)
|
||||
|
@ -344,7 +434,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
self.chat_data = {}
|
||||
if self.chat_data.get(chat_id) == data:
|
||||
return
|
||||
self.chat_data[chat_id] = data
|
||||
self.chat_data[chat_id] = deepcopy(data)
|
||||
if not self.on_flush:
|
||||
if not self.single_file:
|
||||
self._dump_file(Path(f"{self.filepath}_chat_data"), self.chat_data)
|
||||
|
@ -360,7 +450,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
"""
|
||||
if self.bot_data == data:
|
||||
return
|
||||
self.bot_data = data
|
||||
self.bot_data = deepcopy(data)
|
||||
if not self.on_flush:
|
||||
if not self.single_file:
|
||||
self._dump_file(Path(f"{self.filepath}_bot_data"), self.bot_data)
|
||||
|
@ -380,7 +470,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
|
|||
"""
|
||||
if self.callback_data == data:
|
||||
return
|
||||
self.callback_data = (data[0], data[1].copy())
|
||||
self.callback_data = deepcopy((data[0], data[1].copy()))
|
||||
if not self.on_flush:
|
||||
if not self.single_file:
|
||||
self._dump_file(Path(f"{self.filepath}_callback_data"), self.callback_data)
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
import datetime
|
||||
import inspect
|
||||
import logging
|
||||
import pickle
|
||||
import time
|
||||
import datetime as dtm
|
||||
from collections import defaultdict
|
||||
|
@ -241,6 +242,10 @@ class TestBot:
|
|||
if bot.last_name:
|
||||
assert to_dict_bot["last_name"] == bot.last_name
|
||||
|
||||
def test_bot_pickling_error(self, bot):
|
||||
with pytest.raises(pickle.PicklingError, match="Bot objects cannot be pickled"):
|
||||
pickle.dumps(bot)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'bot_method_name',
|
||||
argvalues=[
|
||||
|
|
|
@ -24,7 +24,7 @@ from uuid import uuid4
|
|||
import pytest
|
||||
import pytz
|
||||
|
||||
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery, Message, User
|
||||
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery, Message, User, Chat
|
||||
from telegram.ext._callbackdatacache import (
|
||||
CallbackDataCache,
|
||||
_KeyboardData,
|
||||
|
@ -159,7 +159,8 @@ class TestCallbackDataCache:
|
|||
if invalid:
|
||||
callback_data_cache.clear_callback_data()
|
||||
|
||||
effective_message = Message(message_id=1, date=None, chat=None, reply_markup=out)
|
||||
chat = Chat(1, 'private')
|
||||
effective_message = Message(message_id=1, date=datetime.now(), chat=chat, reply_markup=out)
|
||||
effective_message.reply_to_message = deepcopy(effective_message)
|
||||
effective_message.pinned_message = deepcopy(effective_message)
|
||||
cq_id = uuid4().hex
|
||||
|
|
|
@ -16,26 +16,25 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import gzip
|
||||
import signal
|
||||
import uuid
|
||||
from collections.abc import Container
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from threading import Lock
|
||||
|
||||
import pytest
|
||||
|
||||
from telegram.warnings import PTBUserWarning
|
||||
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
from telegram import Update, Message, User, Chat, MessageEntity, Bot
|
||||
from telegram import Update, Message, User, Chat, MessageEntity, Bot, TelegramObject
|
||||
from telegram.ext import (
|
||||
BasePersistence,
|
||||
ConversationHandler,
|
||||
|
@ -130,7 +129,7 @@ def base_persistence():
|
|||
@pytest.fixture(scope="function")
|
||||
def bot_persistence():
|
||||
class BotPersistence(BasePersistence):
|
||||
__slots__ = ()
|
||||
__slots__ = ('bot_data', 'chat_data', 'user_data', 'callback_data')
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -252,7 +251,6 @@ class TestBasePersistence:
|
|||
inst = bot_persistence
|
||||
for attr in inst.__slots__:
|
||||
assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'"
|
||||
# assert not inst.__dict__, f"got missing slot(s): {inst.__dict__}"
|
||||
# The below test fails if the child class doesn't define __slots__ (not a cause of concern)
|
||||
assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot"
|
||||
|
||||
|
@ -590,292 +588,6 @@ class TestBasePersistence:
|
|||
dp.process_update(MyUpdate())
|
||||
assert 'An uncaught error was raised while processing the update' not in caplog.text
|
||||
|
||||
def test_bot_replace_insert_bot(self, bot, bot_persistence):
|
||||
class CustomSlottedClass:
|
||||
__slots__ = ('bot', '__dict__')
|
||||
|
||||
def __init__(self):
|
||||
self.bot = bot
|
||||
self.not_in_slots = bot
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, CustomSlottedClass):
|
||||
return self.bot is other.bot and self.not_in_slots is other.not_in_slots
|
||||
return False
|
||||
|
||||
class DictNotInSlots(Container):
|
||||
"""This classes parent has slots, but __dict__ is not in those slots."""
|
||||
|
||||
def __init__(self):
|
||||
self.bot = bot
|
||||
|
||||
def __contains__(self, item):
|
||||
return True
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, DictNotInSlots):
|
||||
return self.bot is other.bot
|
||||
return False
|
||||
|
||||
class CustomClass:
|
||||
def __init__(self):
|
||||
self.bot = bot
|
||||
self.slotted_object = CustomSlottedClass()
|
||||
self.dict_not_in_slots_object = DictNotInSlots()
|
||||
self.list_ = [1, 2, bot]
|
||||
self.tuple_ = tuple(self.list_)
|
||||
self.set_ = set(self.list_)
|
||||
self.frozenset_ = frozenset(self.list_)
|
||||
self.dict_ = {item: item for item in self.list_}
|
||||
self.defaultdict_ = defaultdict(dict, self.dict_)
|
||||
|
||||
@staticmethod
|
||||
def replace_bot():
|
||||
cc = CustomClass()
|
||||
cc.bot = BasePersistence.REPLACED_BOT
|
||||
cc.slotted_object.bot = BasePersistence.REPLACED_BOT
|
||||
cc.slotted_object.not_in_slots = BasePersistence.REPLACED_BOT
|
||||
cc.dict_not_in_slots_object.bot = BasePersistence.REPLACED_BOT
|
||||
cc.list_ = [1, 2, BasePersistence.REPLACED_BOT]
|
||||
cc.tuple_ = tuple(cc.list_)
|
||||
cc.set_ = set(cc.list_)
|
||||
cc.frozenset_ = frozenset(cc.list_)
|
||||
cc.dict_ = {item: item for item in cc.list_}
|
||||
cc.defaultdict_ = defaultdict(dict, cc.dict_)
|
||||
return cc
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, CustomClass):
|
||||
return (
|
||||
self.bot is other.bot
|
||||
and self.slotted_object == other.slotted_object
|
||||
and self.dict_not_in_slots_object == other.dict_not_in_slots_object
|
||||
and self.list_ == other.list_
|
||||
and self.tuple_ == other.tuple_
|
||||
and self.set_ == other.set_
|
||||
and self.frozenset_ == other.frozenset_
|
||||
and self.dict_ == other.dict_
|
||||
and self.defaultdict_ == other.defaultdict_
|
||||
)
|
||||
return False
|
||||
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
cc = CustomClass()
|
||||
|
||||
persistence.update_bot_data({1: cc})
|
||||
assert persistence.bot_data[1].bot == BasePersistence.REPLACED_BOT
|
||||
assert persistence.bot_data[1] == cc.replace_bot()
|
||||
|
||||
persistence.update_chat_data(123, {1: cc})
|
||||
assert persistence.chat_data[123][1].bot == BasePersistence.REPLACED_BOT
|
||||
assert persistence.chat_data[123][1] == cc.replace_bot()
|
||||
|
||||
persistence.update_user_data(123, {1: cc})
|
||||
assert persistence.user_data[123][1].bot == BasePersistence.REPLACED_BOT
|
||||
assert persistence.user_data[123][1] == cc.replace_bot()
|
||||
|
||||
persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'}))
|
||||
assert persistence.callback_data[0][0][2][0].bot == BasePersistence.REPLACED_BOT
|
||||
assert persistence.callback_data[0][0][2][0] == cc.replace_bot()
|
||||
|
||||
assert persistence.get_bot_data()[1] == cc
|
||||
assert persistence.get_bot_data()[1].bot is bot
|
||||
assert persistence.get_chat_data()[123][1] == cc
|
||||
assert persistence.get_chat_data()[123][1].bot is bot
|
||||
assert persistence.get_user_data()[123][1] == cc
|
||||
assert persistence.get_user_data()[123][1].bot is bot
|
||||
assert persistence.get_callback_data()[0][0][2][0].bot is bot
|
||||
assert persistence.get_callback_data()[0][0][2][0] == cc
|
||||
|
||||
def test_bot_replace_insert_bot_unpickable_objects(self, bot, bot_persistence, recwarn):
|
||||
"""Here check that unpickable objects are just returned verbatim."""
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
|
||||
class CustomClass:
|
||||
def __copy__(self):
|
||||
raise TypeError('UnhandledException')
|
||||
|
||||
lock = Lock()
|
||||
|
||||
persistence.update_bot_data({1: lock})
|
||||
assert persistence.bot_data[1] is lock
|
||||
persistence.update_chat_data(123, {1: lock})
|
||||
assert persistence.chat_data[123][1] is lock
|
||||
persistence.update_user_data(123, {1: lock})
|
||||
assert persistence.user_data[123][1] is lock
|
||||
persistence.update_callback_data(([('1', 2, {0: lock})], {'1': '2'}))
|
||||
assert persistence.callback_data[0][0][2][0] is lock
|
||||
|
||||
assert persistence.get_bot_data()[1] is lock
|
||||
assert persistence.get_chat_data()[123][1] is lock
|
||||
assert persistence.get_user_data()[123][1] is lock
|
||||
assert persistence.get_callback_data()[0][0][2][0] is lock
|
||||
|
||||
cc = CustomClass()
|
||||
|
||||
persistence.update_bot_data({1: cc})
|
||||
assert persistence.bot_data[1] is cc
|
||||
persistence.update_chat_data(123, {1: cc})
|
||||
assert persistence.chat_data[123][1] is cc
|
||||
persistence.update_user_data(123, {1: cc})
|
||||
assert persistence.user_data[123][1] is cc
|
||||
persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'}))
|
||||
assert persistence.callback_data[0][0][2][0] is cc
|
||||
|
||||
assert persistence.get_bot_data()[1] is cc
|
||||
assert persistence.get_chat_data()[123][1] is cc
|
||||
assert persistence.get_user_data()[123][1] is cc
|
||||
assert persistence.get_callback_data()[0][0][2][0] is cc
|
||||
|
||||
assert len(recwarn) == 2
|
||||
assert str(recwarn[0].message).startswith(
|
||||
"BasePersistence.replace_bot does not handle objects that can not be copied."
|
||||
)
|
||||
assert str(recwarn[1].message).startswith(
|
||||
"BasePersistence.insert_bot does not handle objects that can not be copied."
|
||||
)
|
||||
|
||||
def test_bot_replace_insert_bot_unparsable_objects(self, bot, bot_persistence, recwarn):
|
||||
"""Here check that objects in __dict__ or __slots__ that can't
|
||||
be parsed are just returned verbatim."""
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
|
||||
uuid_obj = uuid.uuid4()
|
||||
|
||||
persistence.update_bot_data({1: uuid_obj})
|
||||
assert persistence.bot_data[1] is uuid_obj
|
||||
persistence.update_chat_data(123, {1: uuid_obj})
|
||||
assert persistence.chat_data[123][1] is uuid_obj
|
||||
persistence.update_user_data(123, {1: uuid_obj})
|
||||
assert persistence.user_data[123][1] is uuid_obj
|
||||
persistence.update_callback_data(([('1', 2, {0: uuid_obj})], {'1': '2'}))
|
||||
assert persistence.callback_data[0][0][2][0] is uuid_obj
|
||||
|
||||
assert persistence.get_bot_data()[1] is uuid_obj
|
||||
assert persistence.get_chat_data()[123][1] is uuid_obj
|
||||
assert persistence.get_user_data()[123][1] is uuid_obj
|
||||
assert persistence.get_callback_data()[0][0][2][0] is uuid_obj
|
||||
|
||||
assert len(recwarn) == 2
|
||||
assert str(recwarn[0].message).startswith(
|
||||
"Parsing of an object failed with the following exception: "
|
||||
)
|
||||
assert str(recwarn[1].message).startswith(
|
||||
"Parsing of an object failed with the following exception: "
|
||||
)
|
||||
|
||||
def test_bot_replace_insert_bot_classes(self, bot, bot_persistence, recwarn):
|
||||
"""Here check that classes are just returned verbatim."""
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
|
||||
class CustomClass:
|
||||
pass
|
||||
|
||||
persistence.update_bot_data({1: CustomClass})
|
||||
assert persistence.bot_data[1] is CustomClass
|
||||
persistence.update_chat_data(123, {1: CustomClass})
|
||||
assert persistence.chat_data[123][1] is CustomClass
|
||||
persistence.update_user_data(123, {1: CustomClass})
|
||||
assert persistence.user_data[123][1] is CustomClass
|
||||
|
||||
assert persistence.get_bot_data()[1] is CustomClass
|
||||
assert persistence.get_chat_data()[123][1] is CustomClass
|
||||
assert persistence.get_user_data()[123][1] is CustomClass
|
||||
|
||||
assert len(recwarn) == 2
|
||||
assert str(recwarn[0].message).startswith(
|
||||
"BasePersistence.replace_bot does not handle classes such as 'CustomClass'"
|
||||
)
|
||||
assert str(recwarn[1].message).startswith(
|
||||
"BasePersistence.insert_bot does not handle classes such as 'CustomClass'"
|
||||
)
|
||||
|
||||
def test_bot_replace_insert_bot_objects_with_faulty_equality(self, bot, bot_persistence):
|
||||
"""Here check that trying to compare obj == self.REPLACED_BOT doesn't lead to problems."""
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
|
||||
class CustomClass:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __eq__(self, other):
|
||||
raise RuntimeError("Can't be compared")
|
||||
|
||||
cc = CustomClass({1: bot, 2: 'foo'})
|
||||
expected = {1: BasePersistence.REPLACED_BOT, 2: 'foo'}
|
||||
|
||||
persistence.update_bot_data({1: cc})
|
||||
assert persistence.bot_data[1].data == expected
|
||||
persistence.update_chat_data(123, {1: cc})
|
||||
assert persistence.chat_data[123][1].data == expected
|
||||
persistence.update_user_data(123, {1: cc})
|
||||
assert persistence.user_data[123][1].data == expected
|
||||
persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'}))
|
||||
assert persistence.callback_data[0][0][2][0].data == expected
|
||||
|
||||
expected = {1: bot, 2: 'foo'}
|
||||
|
||||
assert persistence.get_bot_data()[1].data == expected
|
||||
assert persistence.get_chat_data()[123][1].data == expected
|
||||
assert persistence.get_user_data()[123][1].data == expected
|
||||
assert persistence.get_callback_data()[0][0][2][0].data == expected
|
||||
|
||||
@pytest.mark.filterwarnings('ignore:BasePersistence')
|
||||
def test_replace_insert_bot_item_identity(self, bot, bot_persistence):
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
|
||||
class CustomSlottedClass:
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self):
|
||||
self.value = 5
|
||||
|
||||
class CustomClass:
|
||||
pass
|
||||
|
||||
slot_object = CustomSlottedClass()
|
||||
dict_object = CustomClass()
|
||||
lock = Lock()
|
||||
list_ = [slot_object, dict_object, lock]
|
||||
tuple_ = (1, 2, 3)
|
||||
dict_ = {1: slot_object, 2: dict_object}
|
||||
|
||||
data = {
|
||||
'bot_1': bot,
|
||||
'bot_2': bot,
|
||||
'list_1': list_,
|
||||
'list_2': list_,
|
||||
'tuple_1': tuple_,
|
||||
'tuple_2': tuple_,
|
||||
'dict_1': dict_,
|
||||
'dict_2': dict_,
|
||||
}
|
||||
|
||||
def make_assertion(data_):
|
||||
return (
|
||||
data_['bot_1'] is data_['bot_2']
|
||||
and data_['list_1'] is data_['list_2']
|
||||
and data_['list_1'][0] is data_['list_2'][0]
|
||||
and data_['list_1'][1] is data_['list_2'][1]
|
||||
and data_['list_1'][2] is data_['list_2'][2]
|
||||
and data_['tuple_1'] is data_['tuple_2']
|
||||
and data_['dict_1'] is data_['dict_2']
|
||||
and data_['dict_1'][1] is data_['dict_2'][1]
|
||||
and data_['dict_1'][1] is data_['list_1'][0]
|
||||
and data_['dict_1'][2] is data_['list_1'][1]
|
||||
and data_['dict_1'][2] is data_['dict_2'][2]
|
||||
)
|
||||
|
||||
persistence.update_bot_data(data)
|
||||
assert make_assertion(persistence.bot_data)
|
||||
assert make_assertion(persistence.get_bot_data())
|
||||
|
||||
def test_set_bot_exception(self, bot):
|
||||
non_ext_bot = Bot(bot.token)
|
||||
persistence = OwnPersistence()
|
||||
|
@ -1033,11 +745,28 @@ def pickle_files_wo_callback_data(user_data, chat_data, bot_data, conversations)
|
|||
def update(bot):
|
||||
user = User(id=321, first_name='test_user', is_bot=False)
|
||||
chat = Chat(id=123, type='group')
|
||||
message = Message(1, None, chat, from_user=user, text="Hi there", bot=bot)
|
||||
message = Message(1, datetime.datetime.now(), chat, from_user=user, text="Hi there", bot=bot)
|
||||
return Update(0, message=message)
|
||||
|
||||
|
||||
class TestPicklePersistence:
|
||||
class DictSub(TelegramObject): # Used for testing our custom (Un)Pickler.
|
||||
def __init__(self, private, normal, b):
|
||||
self._private = private
|
||||
self.normal = normal
|
||||
self._bot = b
|
||||
|
||||
class SlotsSub(TelegramObject):
|
||||
__slots__ = ('new_var', '_private')
|
||||
|
||||
def __init__(self, new_var, private):
|
||||
self.new_var = new_var
|
||||
self._private = private
|
||||
|
||||
class NormalClass:
|
||||
def __init__(self, my_var):
|
||||
self.my_var = my_var
|
||||
|
||||
def test_slot_behaviour(self, mro_slots, pickle_persistence):
|
||||
inst = pickle_persistence
|
||||
for attr in inst.__slots__:
|
||||
|
@ -1046,7 +775,7 @@ class TestPicklePersistence:
|
|||
|
||||
def test_pickle_behaviour_with_slots(self, pickle_persistence):
|
||||
bot_data = pickle_persistence.get_bot_data()
|
||||
bot_data['message'] = Message(3, None, Chat(2, type='supergroup'))
|
||||
bot_data['message'] = Message(3, datetime.datetime.now(), Chat(2, type='supergroup'))
|
||||
pickle_persistence.update_bot_data(bot_data)
|
||||
retrieved = pickle_persistence.get_bot_data()
|
||||
assert retrieved == bot_data
|
||||
|
@ -1633,13 +1362,113 @@ class TestPicklePersistence:
|
|||
conversations_test = dict(pickle.load(f))['conversations']
|
||||
assert conversations_test['name1'] == conversation1
|
||||
|
||||
def test_custom_pickler_unpickler_simple(
|
||||
self, pickle_persistence, update, good_pickle_files, bot, recwarn
|
||||
):
|
||||
pickle_persistence.bot = bot # assign the current bot to the persistence
|
||||
data_with_bot = {'current_bot': update.message}
|
||||
pickle_persistence.update_chat_data(12345, data_with_bot) # also calls BotPickler.dumps()
|
||||
|
||||
# Test that regular pickle load fails -
|
||||
err_msg = (
|
||||
"A load persistent id instruction was encountered,\nbut no persistent_load "
|
||||
"function was specified."
|
||||
)
|
||||
with pytest.raises(pickle.UnpicklingError, match=err_msg):
|
||||
with open('pickletest_chat_data', 'rb') as f:
|
||||
pickle.load(f)
|
||||
|
||||
# Test that our custom unpickler works as intended -- inserts the current bot
|
||||
# We have to create a new instance otherwise unpickling is skipped
|
||||
pp = PicklePersistence("pickletest", single_file=False, on_flush=False)
|
||||
pp.bot = bot # Set the bot
|
||||
assert pp.get_chat_data()[12345]['current_bot'].get_bot() is bot
|
||||
|
||||
# Now test that pickling of unknown bots in TelegramObjects will be replaced by None-
|
||||
assert not len(recwarn)
|
||||
data_with_bot['unknown_bot_in_user'] = User(1, 'Dev', False, bot=Bot('1234:abcd'))
|
||||
pickle_persistence.update_chat_data(12345, data_with_bot)
|
||||
assert len(recwarn) == 1
|
||||
assert recwarn[-1].category is PTBUserWarning
|
||||
assert str(recwarn[-1].message).startswith("Unknown bot instance found.")
|
||||
pp = PicklePersistence("pickletest", single_file=False, on_flush=False)
|
||||
pp.bot = bot
|
||||
assert pp.get_chat_data()[12345]['unknown_bot_in_user']._bot is None
|
||||
|
||||
def test_custom_pickler_unpickler_with_custom_objects(
|
||||
self, bot, pickle_persistence, good_pickle_files
|
||||
):
|
||||
dict_s = self.DictSub("private", 'normal', bot)
|
||||
slot_s = self.SlotsSub("new_var", 'private_var')
|
||||
regular = self.NormalClass(12)
|
||||
|
||||
pickle_persistence.bot = bot
|
||||
pickle_persistence.update_user_data(
|
||||
1232, {'sub_dict': dict_s, 'sub_slots': slot_s, 'r': regular}
|
||||
)
|
||||
pp = PicklePersistence("pickletest", single_file=False, on_flush=False)
|
||||
pp.bot = bot # Set the bot
|
||||
data = pp.get_user_data()[1232]
|
||||
sub_dict = data['sub_dict']
|
||||
sub_slots = data['sub_slots']
|
||||
sub_regular = data['r']
|
||||
assert sub_dict._bot is bot
|
||||
assert sub_dict.normal == dict_s.normal
|
||||
assert sub_dict._private == dict_s._private
|
||||
assert sub_slots.new_var == slot_s.new_var
|
||||
assert sub_slots._private == slot_s._private
|
||||
assert sub_slots._bot is None # We didn't set the bot, so it shouldn't have it here.
|
||||
assert sub_regular.my_var == regular.my_var
|
||||
|
||||
def test_custom_pickler_unpickler_with_handler_integration(
|
||||
self, bot, update, pickle_persistence, good_pickle_files, recwarn
|
||||
):
|
||||
u = UpdaterBuilder().bot(bot).persistence(pickle_persistence).build()
|
||||
dp = u.dispatcher
|
||||
bot_id = None
|
||||
|
||||
def first(update, context):
|
||||
nonlocal bot_id
|
||||
bot_id = update.message.get_bot()
|
||||
# Test pickling a message object, which has the current bot
|
||||
context.user_data['msg'] = update.message
|
||||
# Test pickling a bot, which is not known. Directly serializing bots will fail.
|
||||
new_chat = Chat(1, 'private', bot=Bot('1234:abcd'))
|
||||
context.chat_data['unknown_bot_in_chat'] = new_chat
|
||||
|
||||
def second(_, context):
|
||||
msg = context.user_data['msg']
|
||||
assert bot_id is msg.get_bot() # Tests if the same bot is inserted by the unpickler
|
||||
new_none_bot = context.chat_data['unknown_bot_in_chat']._bot
|
||||
assert new_none_bot is None
|
||||
|
||||
h1 = MessageHandler(None, first)
|
||||
h2 = MessageHandler(None, second)
|
||||
dp.add_handler(h1)
|
||||
|
||||
assert not len(recwarn)
|
||||
dp.process_update(update)
|
||||
assert len(recwarn) == 1
|
||||
assert recwarn[-1].category is PTBUserWarning
|
||||
assert str(recwarn[-1].message).startswith("Unknown bot instance found.")
|
||||
|
||||
pickle_persistence_2 = PicklePersistence( # initialize a new persistence for unpickling
|
||||
filepath='pickletest',
|
||||
single_file=False,
|
||||
on_flush=False,
|
||||
)
|
||||
u = UpdaterBuilder().bot(bot).persistence(pickle_persistence_2).build()
|
||||
dp = u.dispatcher
|
||||
dp.add_handler(h2)
|
||||
dp.process_update(update)
|
||||
|
||||
def test_with_handler(self, bot, update, bot_data, pickle_persistence, good_pickle_files):
|
||||
u = UpdaterBuilder().bot(bot).persistence(pickle_persistence).build()
|
||||
dp = u.dispatcher
|
||||
bot.callback_data_cache.clear_callback_data()
|
||||
bot.callback_data_cache.clear_callback_queries()
|
||||
|
||||
def first(update, context):
|
||||
def first(_, context):
|
||||
if not context.user_data == {}:
|
||||
pytest.fail()
|
||||
if not context.chat_data == {}:
|
||||
|
@ -1653,7 +1482,7 @@ class TestPicklePersistence:
|
|||
context.bot_data['test1'] = 'test0'
|
||||
context.bot.callback_data_cache._callback_queries['test1'] = 'test0'
|
||||
|
||||
def second(update, context):
|
||||
def second(_, context):
|
||||
if not context.user_data['test1'] == 'test2':
|
||||
pytest.fail()
|
||||
if not context.chat_data['test3'] == 'test4':
|
||||
|
@ -1813,7 +1642,7 @@ class TestPicklePersistence:
|
|||
assert ch.conversations[ch._get_key(update)] == 0
|
||||
assert ch.conversations == pickle_persistence.conversations['name2']
|
||||
|
||||
def test_with_nested_conversationHandler(
|
||||
def test_with_nested_conversation_handler(
|
||||
self, dp, update, good_pickle_files, pickle_persistence
|
||||
):
|
||||
dp.persistence = pickle_persistence
|
||||
|
|
|
@ -25,7 +25,6 @@ import inspect
|
|||
|
||||
included = { # These modules/classes intentionally have __dict__.
|
||||
'CallbackContext',
|
||||
'BasePersistence',
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -16,20 +16,28 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
|
||||
import datetime
|
||||
import json as json_lib
|
||||
import pickle
|
||||
|
||||
import pytest
|
||||
from copy import deepcopy
|
||||
|
||||
try:
|
||||
import ujson
|
||||
except ImportError:
|
||||
ujson = None
|
||||
|
||||
from telegram import TelegramObject, Message, Chat, User
|
||||
from telegram import TelegramObject, Message, Chat, User, PhotoSize
|
||||
|
||||
|
||||
class TestTelegramObject:
|
||||
class Sub(TelegramObject):
|
||||
def __init__(self, private, normal, b):
|
||||
self._private = private
|
||||
self.normal = normal
|
||||
self._bot = b
|
||||
|
||||
def test_to_json_native(self, monkeypatch):
|
||||
if ujson:
|
||||
monkeypatch.setattr('ujson.dumps', json_lib.dumps)
|
||||
|
@ -145,3 +153,47 @@ class TestTelegramObject:
|
|||
assert message['from_user'] is user
|
||||
with pytest.raises(KeyError, match="Message don't have an attribute called `no_key`"):
|
||||
message['no_key']
|
||||
|
||||
def test_pickle(self, bot):
|
||||
chat = Chat(2, Chat.PRIVATE)
|
||||
user = User(3, 'first_name', False)
|
||||
date = datetime.datetime.now()
|
||||
photo = PhotoSize('file_id', 'unique', 21, 21, bot=bot)
|
||||
msg = Message(1, date, chat, from_user=user, text='foobar', bot=bot, photo=[photo])
|
||||
|
||||
# Test pickling of TGObjects, we choose Message since it's contains the most subclasses.
|
||||
assert msg.get_bot()
|
||||
unpickled = pickle.loads(pickle.dumps(msg))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
unpickled.get_bot() # There should be no bot when we pickle TGObjects
|
||||
|
||||
assert unpickled.chat == chat
|
||||
assert unpickled.from_user == user
|
||||
assert unpickled.date == date
|
||||
assert unpickled.photo[0] == photo
|
||||
|
||||
def test_deepcopy_telegram_obj(self, bot):
|
||||
chat = Chat(2, Chat.PRIVATE)
|
||||
user = User(3, 'first_name', False)
|
||||
date = datetime.datetime.now()
|
||||
photo = PhotoSize('file_id', 'unique', 21, 21, bot=bot)
|
||||
msg = Message(1, date, chat, from_user=user, text='foobar', bot=bot, photo=[photo])
|
||||
|
||||
new_msg = deepcopy(msg)
|
||||
|
||||
# The same bot should be present when deepcopying.
|
||||
assert new_msg.get_bot() == bot and new_msg.get_bot() is bot
|
||||
|
||||
assert new_msg.date == date and new_msg.date is not date
|
||||
assert new_msg.chat == chat and new_msg.chat is not chat
|
||||
assert new_msg.from_user == user and new_msg.from_user is not user
|
||||
assert new_msg.photo[0] == photo and new_msg.photo[0] is not photo
|
||||
|
||||
def test_deepcopy_subclass_telegram_obj(self, bot):
|
||||
s = self.Sub("private", 'normal', bot)
|
||||
d = deepcopy(s)
|
||||
assert d is not s
|
||||
assert d._private == s._private # Can't test for identity since two equal strings is True
|
||||
assert d._bot == s._bot and d._bot is s._bot
|
||||
assert d.normal == s.normal
|
||||
|
|
Loading…
Reference in a new issue