Fix bugs in replace/insert_bot (#2218)

* Fix bugs in replace/insert_bot

* Some tweaks
This commit is contained in:
Bibo-Joshi 2020-11-22 11:08:46 +01:00 committed by GitHub
parent 425716f966
commit df6d5f0840
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 124 additions and 17 deletions

View file

@ -19,7 +19,6 @@
"""This module contains the BasePersistence class.""" """This module contains the BasePersistence class."""
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict
from copy import copy from copy import copy
from typing import Any, DefaultDict, Dict, Optional, Tuple, cast, ClassVar from typing import Any, DefaultDict, Dict, Optional, Tuple, cast, ClassVar
@ -128,7 +127,7 @@ class BasePersistence(ABC):
self.bot = bot self.bot = bot
@classmethod @classmethod
def replace_bot(cls, obj: object) -> object: # pylint: disable=R0911 def replace_bot(cls, obj: object) -> object:
""" """
Replaces all instances of :class:`telegram.Bot` that occur within the passed object with Replaces all instances of :class:`telegram.Bot` that occur within the passed object with
:attr:`REPLACED_BOT`. Currently, this handles objects of type ``list``, ``tuple``, ``set``, :attr:`REPLACED_BOT`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
@ -141,43 +140,72 @@ class BasePersistence(ABC):
Returns: Returns:
:obj:`obj`: Copy of the object with Bot instances replaced. :obj:`obj`: Copy of the object with Bot instances replaced.
""" """
return cls._replace_bot(obj, {})
@classmethod
def _replace_bot(cls, obj: object, memo: Dict[int, Any]) -> object: # pylint: disable=R0911
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if isinstance(obj, Bot): if isinstance(obj, Bot):
memo[obj_id] = cls.REPLACED_BOT
return cls.REPLACED_BOT return cls.REPLACED_BOT
if isinstance(obj, (list, tuple, set, frozenset)): if isinstance(obj, (list, set)):
return obj.__class__(cls.replace_bot(item) for item in obj) # We copy the iterable here for thread safety, i.e. make sure the object we iterate
# over doesn't change its length during the iteration
temp_iterable = obj.copy()
new_iterable = obj.__class__(cls._replace_bot(item, memo) for item in temp_iterable)
memo[obj_id] = new_iterable
return new_iterable
if isinstance(obj, (tuple, frozenset)):
# tuples and frozensets are immutable so we don't need to worry about thread safety
new_immutable = obj.__class__(cls._replace_bot(item, memo) for item in obj)
memo[obj_id] = new_immutable
return new_immutable
try: try:
new_obj = copy(obj) new_obj = copy(obj)
memo[obj_id] = new_obj
except Exception: except Exception:
warnings.warn( warnings.warn(
'BasePersistence.replace_bot does not handle objects that can not be copied. See ' 'BasePersistence.replace_bot does not handle objects that can not be copied. See '
'the docs of BasePersistence.replace_bot for more information.', 'the docs of BasePersistence.replace_bot for more information.',
RuntimeWarning, RuntimeWarning,
) )
memo[obj_id] = obj
return obj return obj
if isinstance(obj, (dict, defaultdict)): if isinstance(obj, dict):
# We handle dicts via copy(obj) so we don't have to make a
# difference between dict and defaultdict
new_obj = cast(dict, new_obj) new_obj = cast(dict, new_obj)
# We can't iterate over obj.items() due to thread safety, i.e. the dicts length may
# change during the iteration
temp_dict = new_obj.copy()
new_obj.clear() new_obj.clear()
for k, val in obj.items(): for k, val in temp_dict.items():
new_obj[cls.replace_bot(k)] = cls.replace_bot(val) new_obj[cls._replace_bot(k, memo)] = cls._replace_bot(val, memo)
memo[obj_id] = new_obj
return new_obj return new_obj
if hasattr(obj, '__dict__'): if hasattr(obj, '__dict__'):
for attr_name, attr in new_obj.__dict__.items(): for attr_name, attr in new_obj.__dict__.items():
setattr(new_obj, attr_name, cls.replace_bot(attr)) setattr(new_obj, attr_name, cls._replace_bot(attr, memo))
memo[obj_id] = new_obj
return new_obj return new_obj
if hasattr(obj, '__slots__'): if hasattr(obj, '__slots__'):
for attr_name in new_obj.__slots__: for attr_name in new_obj.__slots__:
setattr( setattr(
new_obj, new_obj,
attr_name, attr_name,
cls.replace_bot(cls.replace_bot(getattr(new_obj, attr_name))), cls._replace_bot(cls._replace_bot(getattr(new_obj, attr_name), memo), memo),
) )
memo[obj_id] = new_obj
return new_obj return new_obj
return obj return obj
def insert_bot(self, obj: object) -> object: # pylint: disable=R0911 def insert_bot(self, obj: object) -> object:
""" """
Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with
:attr:`bot`. Currently, this handles objects of type ``list``, ``tuple``, ``set``, :attr:`bot`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
@ -190,12 +218,31 @@ class BasePersistence(ABC):
Returns: Returns:
:obj:`obj`: Copy of the object with Bot instances inserted. :obj:`obj`: Copy of the object with Bot instances inserted.
""" """
return self._insert_bot(obj, {})
def _insert_bot(self, obj: object, memo: Dict[int, Any]) -> object: # pylint: disable=R0911
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if isinstance(obj, Bot): if isinstance(obj, Bot):
memo[obj_id] = self.bot
return self.bot return self.bot
if isinstance(obj, str) and obj == self.REPLACED_BOT: if isinstance(obj, str) and obj == self.REPLACED_BOT:
memo[obj_id] = self.bot
return self.bot return self.bot
if isinstance(obj, (list, tuple, set, frozenset)): if isinstance(obj, (list, set)):
return obj.__class__(self.insert_bot(item) for item in obj) # We copy the iterable here for thread safety, i.e. make sure the object we iterate
# over doesn't change its length during the iteration
temp_iterable = obj.copy()
new_iterable = obj.__class__(self._insert_bot(item, memo) for item in temp_iterable)
memo[obj_id] = new_iterable
return new_iterable
if isinstance(obj, (tuple, frozenset)):
# tuples and frozensets are immutable so we don't need to worry about thread safety
new_immutable = obj.__class__(self._insert_bot(item, memo) for item in obj)
memo[obj_id] = new_immutable
return new_immutable
try: try:
new_obj = copy(obj) new_obj = copy(obj)
@ -205,25 +252,34 @@ class BasePersistence(ABC):
'the docs of BasePersistence.insert_bot for more information.', 'the docs of BasePersistence.insert_bot for more information.',
RuntimeWarning, RuntimeWarning,
) )
memo[obj_id] = obj
return obj return obj
if isinstance(obj, (dict, defaultdict)): if isinstance(obj, dict):
# We handle dicts via copy(obj) so we don't have to make a
# difference between dict and defaultdict
new_obj = cast(dict, new_obj) new_obj = cast(dict, new_obj)
# We can't iterate over obj.items() due to thread safety, i.e. the dicts length may
# change during the iteration
temp_dict = new_obj.copy()
new_obj.clear() new_obj.clear()
for k, val in obj.items(): for k, val in temp_dict.items():
new_obj[self.insert_bot(k)] = self.insert_bot(val) new_obj[self._insert_bot(k, memo)] = self._insert_bot(val, memo)
memo[obj_id] = new_obj
return new_obj return new_obj
if hasattr(obj, '__dict__'): if hasattr(obj, '__dict__'):
for attr_name, attr in new_obj.__dict__.items(): for attr_name, attr in new_obj.__dict__.items():
setattr(new_obj, attr_name, self.insert_bot(attr)) setattr(new_obj, attr_name, self._insert_bot(attr, memo))
memo[obj_id] = new_obj
return new_obj return new_obj
if hasattr(obj, '__slots__'): if hasattr(obj, '__slots__'):
for attr_name in obj.__slots__: for attr_name in obj.__slots__:
setattr( setattr(
new_obj, new_obj,
attr_name, attr_name,
self.insert_bot(self.insert_bot(getattr(new_obj, attr_name))), self._insert_bot(self._insert_bot(getattr(new_obj, attr_name), memo), memo),
) )
memo[obj_id] = new_obj
return new_obj return new_obj
return obj return obj

View file

@ -618,6 +618,57 @@ class TestBasePersistence:
assert persistence.get_chat_data()[123][1].data == expected assert persistence.get_chat_data()[123][1].data == expected
assert persistence.get_user_data()[123][1].data == expected assert persistence.get_user_data()[123][1].data == expected
@pytest.mark.filterwarnings('ignore:BasePersistence')
def test_replace_insert_bot_item_identity(self, bot, bot_persistence):
persistence = bot_persistence
persistence.set_bot(bot)
class CustomSlottedClass:
__slots__ = ('value',)
def __init__(self):
self.value = 5
class CustomClass:
pass
slot_object = CustomSlottedClass()
dict_object = CustomClass()
lock = Lock()
list_ = [slot_object, dict_object, lock]
tuple_ = (1, 2, 3)
dict_ = {1: slot_object, 2: dict_object}
data = {
'bot_1': bot,
'bot_2': bot,
'list_1': list_,
'list_2': list_,
'tuple_1': tuple_,
'tuple_2': tuple_,
'dict_1': dict_,
'dict_2': dict_,
}
def make_assertion(data_):
return (
data_['bot_1'] is data_['bot_2']
and data_['list_1'] is data_['list_2']
and data_['list_1'][0] is data_['list_2'][0]
and data_['list_1'][1] is data_['list_2'][1]
and data_['list_1'][2] is data_['list_2'][2]
and data_['tuple_1'] is data_['tuple_2']
and data_['dict_1'] is data_['dict_2']
and data_['dict_1'][1] is data_['dict_2'][1]
and data_['dict_1'][1] is data_['list_1'][0]
and data_['dict_1'][2] is data_['list_1'][1]
and data_['dict_1'][2] is data_['dict_2'][2]
)
persistence.update_bot_data(data)
assert make_assertion(persistence.bot_data)
assert make_assertion(persistence.get_bot_data())
@pytest.fixture(scope='function') @pytest.fixture(scope='function')
def pickle_persistence(): def pickle_persistence():