From 8d9bb26cca85efb69461f2bd2eeb701ee180c011 Mon Sep 17 00:00:00 2001 From: Bibo-Joshi Date: Sat, 14 Nov 2020 03:08:18 +0100 Subject: [PATCH] Improve Handling of Custom Objects in BasePersistence.insert/replace_bot (#2151) * Handle unpickable objects * Improve coverage * Add user warning * make comparison to REPLACED_BOT safe * make pre-commit happy * Shorten warning --- telegram/ext/basepersistence.py | 33 ++++++-- tests/test_persistence.py | 145 ++++++++++++++++++++++++-------- 2 files changed, 136 insertions(+), 42 deletions(-) diff --git a/telegram/ext/basepersistence.py b/telegram/ext/basepersistence.py index 1624e4ebd..d36e7893c 100644 --- a/telegram/ext/basepersistence.py +++ b/telegram/ext/basepersistence.py @@ -17,7 +17,7 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the BasePersistence class.""" - +import warnings from abc import ABC, abstractmethod from collections import defaultdict from copy import copy @@ -128,12 +128,12 @@ class BasePersistence(ABC): self.bot = bot @classmethod - def replace_bot(cls, obj: object) -> object: + def replace_bot(cls, obj: object) -> object: # pylint: disable=R0911 """ Replaces all instances of :class:`telegram.Bot` that occur within the passed object with :attr:`REPLACED_BOT`. Currently, this handles objects of type ``list``, ``tuple``, ``set``, ``frozenset``, ``dict``, ``defaultdict`` and objects that have a ``__dict__`` or - ``__slot__`` attribute. + ``__slot__`` attribute, excluding objects that can't be copied with `copy.copy`. Args: obj (:obj:`object`): The object @@ -146,7 +146,16 @@ class BasePersistence(ABC): if isinstance(obj, (list, tuple, set, frozenset)): return obj.__class__(cls.replace_bot(item) for item in obj) - new_obj = copy(obj) + try: + new_obj = copy(obj) + except Exception: + warnings.warn( + 'BasePersistence.replace_bot does not handle objects that can not be copied. See ' + 'the docs of BasePersistence.replace_bot for more information.', + RuntimeWarning, + ) + return obj + if isinstance(obj, (dict, defaultdict)): new_obj = cast(dict, new_obj) new_obj.clear() @@ -173,7 +182,7 @@ class BasePersistence(ABC): Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with :attr:`bot`. Currently, this handles objects of type ``list``, ``tuple``, ``set``, ``frozenset``, ``dict``, ``defaultdict`` and objects that have a ``__dict__`` or - ``__slot__`` attribute. + ``__slot__`` attribute, excluding objects that can't be copied with `copy.copy`. Args: obj (:obj:`object`): The object @@ -183,12 +192,21 @@ class BasePersistence(ABC): """ if isinstance(obj, Bot): return self.bot - if obj == self.REPLACED_BOT: + if isinstance(obj, str) and obj == self.REPLACED_BOT: return self.bot if isinstance(obj, (list, tuple, set, frozenset)): return obj.__class__(self.insert_bot(item) for item in obj) - new_obj = copy(obj) + try: + new_obj = copy(obj) + except Exception: + warnings.warn( + 'BasePersistence.insert_bot does not handle objects that can not be copied. See ' + 'the docs of BasePersistence.insert_bot for more information.', + RuntimeWarning, + ) + return obj + if isinstance(obj, (dict, defaultdict)): new_obj = cast(dict, new_obj) new_obj.clear() @@ -207,6 +225,7 @@ class BasePersistence(ABC): self.insert_bot(self.insert_bot(getattr(new_obj, attr_name))), ) return new_obj + return obj @abstractmethod diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 009dacfd5..212d1cea3 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import signal +from threading import Lock from telegram.utils.helpers import encode_conversations_to_json @@ -88,6 +89,42 @@ def base_persistence(): return OwnPersistence(store_chat_data=True, store_user_data=True, store_bot_data=True) +@pytest.fixture(scope="function") +def bot_persistence(): + class BotPersistence(BasePersistence): + def __init__(self): + super().__init__() + self.bot_data = None + self.chat_data = defaultdict(dict) + self.user_data = defaultdict(dict) + + def get_bot_data(self): + return self.bot_data + + def get_chat_data(self): + return self.chat_data + + def get_user_data(self): + return self.user_data + + def get_conversations(self, name): + raise NotImplementedError + + def update_bot_data(self, data): + self.bot_data = data + + def update_chat_data(self, chat_id, data): + self.chat_data[chat_id] = data + + def update_user_data(self, user_id, data): + self.user_data[user_id] = data + + def update_conversation(self, name, key, new_state): + raise NotImplementedError + + return BotPersistence() + + @pytest.fixture(scope="function") def bot_data(): return {'test1': 'test2', 'test3': {'test4': 'test5'}} @@ -437,38 +474,7 @@ class TestBasePersistence: dp.process_update(MyUpdate()) assert 'An uncaught error was raised while processing the update' not in caplog.text - def test_bot_replace_insert_bot(self, bot): - class BotPersistence(BasePersistence): - def __init__(self): - super().__init__() - self.bot_data = None - self.chat_data = defaultdict(dict) - self.user_data = defaultdict(dict) - - def get_bot_data(self): - return self.bot_data - - def get_chat_data(self): - return self.chat_data - - def get_user_data(self): - return self.user_data - - def get_conversations(self, name): - raise NotImplementedError - - def update_bot_data(self, data): - self.bot_data = data - - def update_chat_data(self, chat_id, data): - self.chat_data[chat_id] = data - - def update_user_data(self, user_id, data): - self.user_data[user_id] = data - - def update_conversation(self, name, key, new_state): - raise NotImplementedError - + def test_bot_replace_insert_bot(self, bot, bot_persistence): class CustomSlottedClass: __slots__ = ('bot',) @@ -506,8 +512,6 @@ class TestBasePersistence: def __eq__(self, other): if isinstance(other, CustomClass): - # print(self.__dict__) - # print(other.__dict__) return ( self.bot is other.bot and self.slotted_object == other.slotted_object @@ -520,7 +524,7 @@ class TestBasePersistence: ) return False - persistence = BotPersistence() + persistence = bot_persistence persistence.set_bot(bot) cc = CustomClass() @@ -543,6 +547,77 @@ class TestBasePersistence: assert persistence.get_user_data()[123][1] == cc assert persistence.get_user_data()[123][1].bot is bot + def test_bot_replace_insert_bot_unpickable_objects(self, bot, bot_persistence, recwarn): + """Here check that unpickable objects are just returned verbatim.""" + persistence = bot_persistence + persistence.set_bot(bot) + + class CustomClass: + def __copy__(self): + raise TypeError('UnhandledException') + + lock = Lock() + + persistence.update_bot_data({1: lock}) + assert persistence.bot_data[1] is lock + persistence.update_chat_data(123, {1: lock}) + assert persistence.chat_data[123][1] is lock + persistence.update_user_data(123, {1: lock}) + assert persistence.user_data[123][1] is lock + + assert persistence.get_bot_data()[1] is lock + assert persistence.get_chat_data()[123][1] is lock + assert persistence.get_user_data()[123][1] is lock + + cc = CustomClass() + + persistence.update_bot_data({1: cc}) + assert persistence.bot_data[1] is cc + persistence.update_chat_data(123, {1: cc}) + assert persistence.chat_data[123][1] is cc + persistence.update_user_data(123, {1: cc}) + assert persistence.user_data[123][1] is cc + + assert persistence.get_bot_data()[1] is cc + assert persistence.get_chat_data()[123][1] is cc + assert persistence.get_user_data()[123][1] is cc + + assert len(recwarn) == 2 + assert str(recwarn[0].message).startswith( + "BasePersistence.replace_bot does not handle objects that can not be copied." + ) + assert str(recwarn[1].message).startswith( + "BasePersistence.insert_bot does not handle objects that can not be copied." + ) + + def test_bot_replace_insert_bot_objects_with_faulty_equality(self, bot, bot_persistence): + """Here check that trying to compare obj == self.REPLACED_BOT doesn't lead to problems.""" + persistence = bot_persistence + persistence.set_bot(bot) + + class CustomClass: + def __init__(self, data): + self.data = data + + def __eq__(self, other): + raise RuntimeError("Can't be compared") + + cc = CustomClass({1: bot, 2: 'foo'}) + expected = {1: BasePersistence.REPLACED_BOT, 2: 'foo'} + + persistence.update_bot_data({1: cc}) + assert persistence.bot_data[1].data == expected + persistence.update_chat_data(123, {1: cc}) + assert persistence.chat_data[123][1].data == expected + persistence.update_user_data(123, {1: cc}) + assert persistence.user_data[123][1].data == expected + + expected = {1: bot, 2: 'foo'} + + assert persistence.get_bot_data()[1].data == expected + assert persistence.get_chat_data()[123][1].data == expected + assert persistence.get_user_data()[123][1].data == expected + @pytest.fixture(scope='function') def pickle_persistence():