Improve Handling of Custom Objects in BasePersistence.insert/replace_bot (#2151)

* Handle unpickable objects

* Improve coverage

* Add user warning

* make comparison to REPLACED_BOT safe

* make pre-commit happy

* Shorten warning
This commit is contained in:
Bibo-Joshi 2020-11-14 03:08:18 +01:00 committed by GitHub
parent d1438a9b23
commit 8d9bb26cca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 136 additions and 42 deletions

View file

@ -17,7 +17,7 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the BasePersistence class."""
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import copy
@ -128,12 +128,12 @@ class BasePersistence(ABC):
self.bot = bot
@classmethod
def replace_bot(cls, obj: object) -> object:
def replace_bot(cls, obj: object) -> object: # pylint: disable=R0911
"""
Replaces all instances of :class:`telegram.Bot` that occur within the passed object with
:attr:`REPLACED_BOT`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
``frozenset``, ``dict``, ``defaultdict`` and objects that have a ``__dict__`` or
``__slot__`` attribute.
``__slot__`` attribute, excluding objects that can't be copied with `copy.copy`.
Args:
obj (:obj:`object`): The object
@ -146,7 +146,16 @@ class BasePersistence(ABC):
if isinstance(obj, (list, tuple, set, frozenset)):
return obj.__class__(cls.replace_bot(item) for item in obj)
new_obj = copy(obj)
try:
new_obj = copy(obj)
except Exception:
warnings.warn(
'BasePersistence.replace_bot does not handle objects that can not be copied. See '
'the docs of BasePersistence.replace_bot for more information.',
RuntimeWarning,
)
return obj
if isinstance(obj, (dict, defaultdict)):
new_obj = cast(dict, new_obj)
new_obj.clear()
@ -173,7 +182,7 @@ class BasePersistence(ABC):
Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with
:attr:`bot`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
``frozenset``, ``dict``, ``defaultdict`` and objects that have a ``__dict__`` or
``__slot__`` attribute.
``__slot__`` attribute, excluding objects that can't be copied with `copy.copy`.
Args:
obj (:obj:`object`): The object
@ -183,12 +192,21 @@ class BasePersistence(ABC):
"""
if isinstance(obj, Bot):
return self.bot
if obj == self.REPLACED_BOT:
if isinstance(obj, str) and obj == self.REPLACED_BOT:
return self.bot
if isinstance(obj, (list, tuple, set, frozenset)):
return obj.__class__(self.insert_bot(item) for item in obj)
new_obj = copy(obj)
try:
new_obj = copy(obj)
except Exception:
warnings.warn(
'BasePersistence.insert_bot does not handle objects that can not be copied. See '
'the docs of BasePersistence.insert_bot for more information.',
RuntimeWarning,
)
return obj
if isinstance(obj, (dict, defaultdict)):
new_obj = cast(dict, new_obj)
new_obj.clear()
@ -207,6 +225,7 @@ class BasePersistence(ABC):
self.insert_bot(self.insert_bot(getattr(new_obj, attr_name))),
)
return new_obj
return obj
@abstractmethod

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import signal
from threading import Lock
from telegram.utils.helpers import encode_conversations_to_json
@ -88,6 +89,42 @@ def base_persistence():
return OwnPersistence(store_chat_data=True, store_user_data=True, store_bot_data=True)
@pytest.fixture(scope="function")
def bot_persistence():
class BotPersistence(BasePersistence):
def __init__(self):
super().__init__()
self.bot_data = None
self.chat_data = defaultdict(dict)
self.user_data = defaultdict(dict)
def get_bot_data(self):
return self.bot_data
def get_chat_data(self):
return self.chat_data
def get_user_data(self):
return self.user_data
def get_conversations(self, name):
raise NotImplementedError
def update_bot_data(self, data):
self.bot_data = data
def update_chat_data(self, chat_id, data):
self.chat_data[chat_id] = data
def update_user_data(self, user_id, data):
self.user_data[user_id] = data
def update_conversation(self, name, key, new_state):
raise NotImplementedError
return BotPersistence()
@pytest.fixture(scope="function")
def bot_data():
return {'test1': 'test2', 'test3': {'test4': 'test5'}}
@ -437,38 +474,7 @@ class TestBasePersistence:
dp.process_update(MyUpdate())
assert 'An uncaught error was raised while processing the update' not in caplog.text
def test_bot_replace_insert_bot(self, bot):
class BotPersistence(BasePersistence):
def __init__(self):
super().__init__()
self.bot_data = None
self.chat_data = defaultdict(dict)
self.user_data = defaultdict(dict)
def get_bot_data(self):
return self.bot_data
def get_chat_data(self):
return self.chat_data
def get_user_data(self):
return self.user_data
def get_conversations(self, name):
raise NotImplementedError
def update_bot_data(self, data):
self.bot_data = data
def update_chat_data(self, chat_id, data):
self.chat_data[chat_id] = data
def update_user_data(self, user_id, data):
self.user_data[user_id] = data
def update_conversation(self, name, key, new_state):
raise NotImplementedError
def test_bot_replace_insert_bot(self, bot, bot_persistence):
class CustomSlottedClass:
__slots__ = ('bot',)
@ -506,8 +512,6 @@ class TestBasePersistence:
def __eq__(self, other):
if isinstance(other, CustomClass):
# print(self.__dict__)
# print(other.__dict__)
return (
self.bot is other.bot
and self.slotted_object == other.slotted_object
@ -520,7 +524,7 @@ class TestBasePersistence:
)
return False
persistence = BotPersistence()
persistence = bot_persistence
persistence.set_bot(bot)
cc = CustomClass()
@ -543,6 +547,77 @@ class TestBasePersistence:
assert persistence.get_user_data()[123][1] == cc
assert persistence.get_user_data()[123][1].bot is bot
def test_bot_replace_insert_bot_unpickable_objects(self, bot, bot_persistence, recwarn):
"""Here check that unpickable objects are just returned verbatim."""
persistence = bot_persistence
persistence.set_bot(bot)
class CustomClass:
def __copy__(self):
raise TypeError('UnhandledException')
lock = Lock()
persistence.update_bot_data({1: lock})
assert persistence.bot_data[1] is lock
persistence.update_chat_data(123, {1: lock})
assert persistence.chat_data[123][1] is lock
persistence.update_user_data(123, {1: lock})
assert persistence.user_data[123][1] is lock
assert persistence.get_bot_data()[1] is lock
assert persistence.get_chat_data()[123][1] is lock
assert persistence.get_user_data()[123][1] is lock
cc = CustomClass()
persistence.update_bot_data({1: cc})
assert persistence.bot_data[1] is cc
persistence.update_chat_data(123, {1: cc})
assert persistence.chat_data[123][1] is cc
persistence.update_user_data(123, {1: cc})
assert persistence.user_data[123][1] is cc
assert persistence.get_bot_data()[1] is cc
assert persistence.get_chat_data()[123][1] is cc
assert persistence.get_user_data()[123][1] is cc
assert len(recwarn) == 2
assert str(recwarn[0].message).startswith(
"BasePersistence.replace_bot does not handle objects that can not be copied."
)
assert str(recwarn[1].message).startswith(
"BasePersistence.insert_bot does not handle objects that can not be copied."
)
def test_bot_replace_insert_bot_objects_with_faulty_equality(self, bot, bot_persistence):
"""Here check that trying to compare obj == self.REPLACED_BOT doesn't lead to problems."""
persistence = bot_persistence
persistence.set_bot(bot)
class CustomClass:
def __init__(self, data):
self.data = data
def __eq__(self, other):
raise RuntimeError("Can't be compared")
cc = CustomClass({1: bot, 2: 'foo'})
expected = {1: BasePersistence.REPLACED_BOT, 2: 'foo'}
persistence.update_bot_data({1: cc})
assert persistence.bot_data[1].data == expected
persistence.update_chat_data(123, {1: cc})
assert persistence.chat_data[123][1].data == expected
persistence.update_user_data(123, {1: cc})
assert persistence.user_data[123][1].data == expected
expected = {1: bot, 2: 'foo'}
assert persistence.get_bot_data()[1].data == expected
assert persistence.get_chat_data()[123][1].data == expected
assert persistence.get_user_data()[123][1].data == expected
@pytest.fixture(scope='function')
def pickle_persistence():