mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-10-24 01:46:22 +02:00
Improve Handling of Custom Objects in BasePersistence.insert/replace_bot (#2151)
* Handle unpickable objects * Improve coverage * Add user warning * make comparison to REPLACED_BOT safe * make pre-commit happy * Shorten warning
This commit is contained in:
parent
d1438a9b23
commit
8d9bb26cca
2 changed files with 136 additions and 42 deletions
|
@ -17,7 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
"""This module contains the BasePersistence class."""
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
|
@ -128,12 +128,12 @@ class BasePersistence(ABC):
|
|||
self.bot = bot
|
||||
|
||||
@classmethod
|
||||
def replace_bot(cls, obj: object) -> object:
|
||||
def replace_bot(cls, obj: object) -> object: # pylint: disable=R0911
|
||||
"""
|
||||
Replaces all instances of :class:`telegram.Bot` that occur within the passed object with
|
||||
:attr:`REPLACED_BOT`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
|
||||
``frozenset``, ``dict``, ``defaultdict`` and objects that have a ``__dict__`` or
|
||||
``__slot__`` attribute.
|
||||
``__slot__`` attribute, excluding objects that can't be copied with `copy.copy`.
|
||||
|
||||
Args:
|
||||
obj (:obj:`object`): The object
|
||||
|
@ -146,7 +146,16 @@ class BasePersistence(ABC):
|
|||
if isinstance(obj, (list, tuple, set, frozenset)):
|
||||
return obj.__class__(cls.replace_bot(item) for item in obj)
|
||||
|
||||
new_obj = copy(obj)
|
||||
try:
|
||||
new_obj = copy(obj)
|
||||
except Exception:
|
||||
warnings.warn(
|
||||
'BasePersistence.replace_bot does not handle objects that can not be copied. See '
|
||||
'the docs of BasePersistence.replace_bot for more information.',
|
||||
RuntimeWarning,
|
||||
)
|
||||
return obj
|
||||
|
||||
if isinstance(obj, (dict, defaultdict)):
|
||||
new_obj = cast(dict, new_obj)
|
||||
new_obj.clear()
|
||||
|
@ -173,7 +182,7 @@ class BasePersistence(ABC):
|
|||
Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with
|
||||
:attr:`bot`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
|
||||
``frozenset``, ``dict``, ``defaultdict`` and objects that have a ``__dict__`` or
|
||||
``__slot__`` attribute.
|
||||
``__slot__`` attribute, excluding objects that can't be copied with `copy.copy`.
|
||||
|
||||
Args:
|
||||
obj (:obj:`object`): The object
|
||||
|
@ -183,12 +192,21 @@ class BasePersistence(ABC):
|
|||
"""
|
||||
if isinstance(obj, Bot):
|
||||
return self.bot
|
||||
if obj == self.REPLACED_BOT:
|
||||
if isinstance(obj, str) and obj == self.REPLACED_BOT:
|
||||
return self.bot
|
||||
if isinstance(obj, (list, tuple, set, frozenset)):
|
||||
return obj.__class__(self.insert_bot(item) for item in obj)
|
||||
|
||||
new_obj = copy(obj)
|
||||
try:
|
||||
new_obj = copy(obj)
|
||||
except Exception:
|
||||
warnings.warn(
|
||||
'BasePersistence.insert_bot does not handle objects that can not be copied. See '
|
||||
'the docs of BasePersistence.insert_bot for more information.',
|
||||
RuntimeWarning,
|
||||
)
|
||||
return obj
|
||||
|
||||
if isinstance(obj, (dict, defaultdict)):
|
||||
new_obj = cast(dict, new_obj)
|
||||
new_obj.clear()
|
||||
|
@ -207,6 +225,7 @@ class BasePersistence(ABC):
|
|||
self.insert_bot(self.insert_bot(getattr(new_obj, attr_name))),
|
||||
)
|
||||
return new_obj
|
||||
|
||||
return obj
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
import signal
|
||||
from threading import Lock
|
||||
|
||||
from telegram.utils.helpers import encode_conversations_to_json
|
||||
|
||||
|
@ -88,6 +89,42 @@ def base_persistence():
|
|||
return OwnPersistence(store_chat_data=True, store_user_data=True, store_bot_data=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def bot_persistence():
|
||||
class BotPersistence(BasePersistence):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_data = None
|
||||
self.chat_data = defaultdict(dict)
|
||||
self.user_data = defaultdict(dict)
|
||||
|
||||
def get_bot_data(self):
|
||||
return self.bot_data
|
||||
|
||||
def get_chat_data(self):
|
||||
return self.chat_data
|
||||
|
||||
def get_user_data(self):
|
||||
return self.user_data
|
||||
|
||||
def get_conversations(self, name):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_bot_data(self, data):
|
||||
self.bot_data = data
|
||||
|
||||
def update_chat_data(self, chat_id, data):
|
||||
self.chat_data[chat_id] = data
|
||||
|
||||
def update_user_data(self, user_id, data):
|
||||
self.user_data[user_id] = data
|
||||
|
||||
def update_conversation(self, name, key, new_state):
|
||||
raise NotImplementedError
|
||||
|
||||
return BotPersistence()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def bot_data():
|
||||
return {'test1': 'test2', 'test3': {'test4': 'test5'}}
|
||||
|
@ -437,38 +474,7 @@ class TestBasePersistence:
|
|||
dp.process_update(MyUpdate())
|
||||
assert 'An uncaught error was raised while processing the update' not in caplog.text
|
||||
|
||||
def test_bot_replace_insert_bot(self, bot):
|
||||
class BotPersistence(BasePersistence):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_data = None
|
||||
self.chat_data = defaultdict(dict)
|
||||
self.user_data = defaultdict(dict)
|
||||
|
||||
def get_bot_data(self):
|
||||
return self.bot_data
|
||||
|
||||
def get_chat_data(self):
|
||||
return self.chat_data
|
||||
|
||||
def get_user_data(self):
|
||||
return self.user_data
|
||||
|
||||
def get_conversations(self, name):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_bot_data(self, data):
|
||||
self.bot_data = data
|
||||
|
||||
def update_chat_data(self, chat_id, data):
|
||||
self.chat_data[chat_id] = data
|
||||
|
||||
def update_user_data(self, user_id, data):
|
||||
self.user_data[user_id] = data
|
||||
|
||||
def update_conversation(self, name, key, new_state):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_bot_replace_insert_bot(self, bot, bot_persistence):
|
||||
class CustomSlottedClass:
|
||||
__slots__ = ('bot',)
|
||||
|
||||
|
@ -506,8 +512,6 @@ class TestBasePersistence:
|
|||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, CustomClass):
|
||||
# print(self.__dict__)
|
||||
# print(other.__dict__)
|
||||
return (
|
||||
self.bot is other.bot
|
||||
and self.slotted_object == other.slotted_object
|
||||
|
@ -520,7 +524,7 @@ class TestBasePersistence:
|
|||
)
|
||||
return False
|
||||
|
||||
persistence = BotPersistence()
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
cc = CustomClass()
|
||||
|
||||
|
@ -543,6 +547,77 @@ class TestBasePersistence:
|
|||
assert persistence.get_user_data()[123][1] == cc
|
||||
assert persistence.get_user_data()[123][1].bot is bot
|
||||
|
||||
def test_bot_replace_insert_bot_unpickable_objects(self, bot, bot_persistence, recwarn):
|
||||
"""Here check that unpickable objects are just returned verbatim."""
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
|
||||
class CustomClass:
|
||||
def __copy__(self):
|
||||
raise TypeError('UnhandledException')
|
||||
|
||||
lock = Lock()
|
||||
|
||||
persistence.update_bot_data({1: lock})
|
||||
assert persistence.bot_data[1] is lock
|
||||
persistence.update_chat_data(123, {1: lock})
|
||||
assert persistence.chat_data[123][1] is lock
|
||||
persistence.update_user_data(123, {1: lock})
|
||||
assert persistence.user_data[123][1] is lock
|
||||
|
||||
assert persistence.get_bot_data()[1] is lock
|
||||
assert persistence.get_chat_data()[123][1] is lock
|
||||
assert persistence.get_user_data()[123][1] is lock
|
||||
|
||||
cc = CustomClass()
|
||||
|
||||
persistence.update_bot_data({1: cc})
|
||||
assert persistence.bot_data[1] is cc
|
||||
persistence.update_chat_data(123, {1: cc})
|
||||
assert persistence.chat_data[123][1] is cc
|
||||
persistence.update_user_data(123, {1: cc})
|
||||
assert persistence.user_data[123][1] is cc
|
||||
|
||||
assert persistence.get_bot_data()[1] is cc
|
||||
assert persistence.get_chat_data()[123][1] is cc
|
||||
assert persistence.get_user_data()[123][1] is cc
|
||||
|
||||
assert len(recwarn) == 2
|
||||
assert str(recwarn[0].message).startswith(
|
||||
"BasePersistence.replace_bot does not handle objects that can not be copied."
|
||||
)
|
||||
assert str(recwarn[1].message).startswith(
|
||||
"BasePersistence.insert_bot does not handle objects that can not be copied."
|
||||
)
|
||||
|
||||
def test_bot_replace_insert_bot_objects_with_faulty_equality(self, bot, bot_persistence):
|
||||
"""Here check that trying to compare obj == self.REPLACED_BOT doesn't lead to problems."""
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
|
||||
class CustomClass:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __eq__(self, other):
|
||||
raise RuntimeError("Can't be compared")
|
||||
|
||||
cc = CustomClass({1: bot, 2: 'foo'})
|
||||
expected = {1: BasePersistence.REPLACED_BOT, 2: 'foo'}
|
||||
|
||||
persistence.update_bot_data({1: cc})
|
||||
assert persistence.bot_data[1].data == expected
|
||||
persistence.update_chat_data(123, {1: cc})
|
||||
assert persistence.chat_data[123][1].data == expected
|
||||
persistence.update_user_data(123, {1: cc})
|
||||
assert persistence.user_data[123][1].data == expected
|
||||
|
||||
expected = {1: bot, 2: 'foo'}
|
||||
|
||||
assert persistence.get_bot_data()[1].data == expected
|
||||
assert persistence.get_chat_data()[123][1].data == expected
|
||||
assert persistence.get_user_data()[123][1].data == expected
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def pickle_persistence():
|
||||
|
|
Loading…
Reference in a new issue