mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-11-25 00:27:46 +01:00
Fix bugs in replace/insert_bot (#2218)
* Fix bugs in replace/insert_bot * Some tweaks
This commit is contained in:
parent
425716f966
commit
df6d5f0840
2 changed files with 124 additions and 17 deletions
|
@ -19,7 +19,6 @@
|
|||
"""This module contains the BasePersistence class."""
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
from typing import Any, DefaultDict, Dict, Optional, Tuple, cast, ClassVar
|
||||
|
||||
|
@ -128,7 +127,7 @@ class BasePersistence(ABC):
|
|||
self.bot = bot
|
||||
|
||||
@classmethod
|
||||
def replace_bot(cls, obj: object) -> object: # pylint: disable=R0911
|
||||
def replace_bot(cls, obj: object) -> object:
|
||||
"""
|
||||
Replaces all instances of :class:`telegram.Bot` that occur within the passed object with
|
||||
:attr:`REPLACED_BOT`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
|
||||
|
@ -141,43 +140,72 @@ class BasePersistence(ABC):
|
|||
Returns:
|
||||
:obj:`obj`: Copy of the object with Bot instances replaced.
|
||||
"""
|
||||
return cls._replace_bot(obj, {})
|
||||
|
||||
@classmethod
|
||||
def _replace_bot(cls, obj: object, memo: Dict[int, Any]) -> object: # pylint: disable=R0911
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
if isinstance(obj, Bot):
|
||||
memo[obj_id] = cls.REPLACED_BOT
|
||||
return cls.REPLACED_BOT
|
||||
if isinstance(obj, (list, tuple, set, frozenset)):
|
||||
return obj.__class__(cls.replace_bot(item) for item in obj)
|
||||
if isinstance(obj, (list, set)):
|
||||
# We copy the iterable here for thread safety, i.e. make sure the object we iterate
|
||||
# over doesn't change its length during the iteration
|
||||
temp_iterable = obj.copy()
|
||||
new_iterable = obj.__class__(cls._replace_bot(item, memo) for item in temp_iterable)
|
||||
memo[obj_id] = new_iterable
|
||||
return new_iterable
|
||||
if isinstance(obj, (tuple, frozenset)):
|
||||
# tuples and frozensets are immutable so we don't need to worry about thread safety
|
||||
new_immutable = obj.__class__(cls._replace_bot(item, memo) for item in obj)
|
||||
memo[obj_id] = new_immutable
|
||||
return new_immutable
|
||||
|
||||
try:
|
||||
new_obj = copy(obj)
|
||||
memo[obj_id] = new_obj
|
||||
except Exception:
|
||||
warnings.warn(
|
||||
'BasePersistence.replace_bot does not handle objects that can not be copied. See '
|
||||
'the docs of BasePersistence.replace_bot for more information.',
|
||||
RuntimeWarning,
|
||||
)
|
||||
memo[obj_id] = obj
|
||||
return obj
|
||||
|
||||
if isinstance(obj, (dict, defaultdict)):
|
||||
if isinstance(obj, dict):
|
||||
# We handle dicts via copy(obj) so we don't have to make a
|
||||
# difference between dict and defaultdict
|
||||
new_obj = cast(dict, new_obj)
|
||||
# We can't iterate over obj.items() due to thread safety, i.e. the dicts length may
|
||||
# change during the iteration
|
||||
temp_dict = new_obj.copy()
|
||||
new_obj.clear()
|
||||
for k, val in obj.items():
|
||||
new_obj[cls.replace_bot(k)] = cls.replace_bot(val)
|
||||
for k, val in temp_dict.items():
|
||||
new_obj[cls._replace_bot(k, memo)] = cls._replace_bot(val, memo)
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
if hasattr(obj, '__dict__'):
|
||||
for attr_name, attr in new_obj.__dict__.items():
|
||||
setattr(new_obj, attr_name, cls.replace_bot(attr))
|
||||
setattr(new_obj, attr_name, cls._replace_bot(attr, memo))
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
if hasattr(obj, '__slots__'):
|
||||
for attr_name in new_obj.__slots__:
|
||||
setattr(
|
||||
new_obj,
|
||||
attr_name,
|
||||
cls.replace_bot(cls.replace_bot(getattr(new_obj, attr_name))),
|
||||
cls._replace_bot(cls._replace_bot(getattr(new_obj, attr_name), memo), memo),
|
||||
)
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
|
||||
return obj
|
||||
|
||||
def insert_bot(self, obj: object) -> object: # pylint: disable=R0911
|
||||
def insert_bot(self, obj: object) -> object:
|
||||
"""
|
||||
Replaces all instances of :attr:`REPLACED_BOT` that occur within the passed object with
|
||||
:attr:`bot`. Currently, this handles objects of type ``list``, ``tuple``, ``set``,
|
||||
|
@ -190,12 +218,31 @@ class BasePersistence(ABC):
|
|||
Returns:
|
||||
:obj:`obj`: Copy of the object with Bot instances inserted.
|
||||
"""
|
||||
return self._insert_bot(obj, {})
|
||||
|
||||
def _insert_bot(self, obj: object, memo: Dict[int, Any]) -> object: # pylint: disable=R0911
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
if isinstance(obj, Bot):
|
||||
memo[obj_id] = self.bot
|
||||
return self.bot
|
||||
if isinstance(obj, str) and obj == self.REPLACED_BOT:
|
||||
memo[obj_id] = self.bot
|
||||
return self.bot
|
||||
if isinstance(obj, (list, tuple, set, frozenset)):
|
||||
return obj.__class__(self.insert_bot(item) for item in obj)
|
||||
if isinstance(obj, (list, set)):
|
||||
# We copy the iterable here for thread safety, i.e. make sure the object we iterate
|
||||
# over doesn't change its length during the iteration
|
||||
temp_iterable = obj.copy()
|
||||
new_iterable = obj.__class__(self._insert_bot(item, memo) for item in temp_iterable)
|
||||
memo[obj_id] = new_iterable
|
||||
return new_iterable
|
||||
if isinstance(obj, (tuple, frozenset)):
|
||||
# tuples and frozensets are immutable so we don't need to worry about thread safety
|
||||
new_immutable = obj.__class__(self._insert_bot(item, memo) for item in obj)
|
||||
memo[obj_id] = new_immutable
|
||||
return new_immutable
|
||||
|
||||
try:
|
||||
new_obj = copy(obj)
|
||||
|
@ -205,25 +252,34 @@ class BasePersistence(ABC):
|
|||
'the docs of BasePersistence.insert_bot for more information.',
|
||||
RuntimeWarning,
|
||||
)
|
||||
memo[obj_id] = obj
|
||||
return obj
|
||||
|
||||
if isinstance(obj, (dict, defaultdict)):
|
||||
if isinstance(obj, dict):
|
||||
# We handle dicts via copy(obj) so we don't have to make a
|
||||
# difference between dict and defaultdict
|
||||
new_obj = cast(dict, new_obj)
|
||||
# We can't iterate over obj.items() due to thread safety, i.e. the dicts length may
|
||||
# change during the iteration
|
||||
temp_dict = new_obj.copy()
|
||||
new_obj.clear()
|
||||
for k, val in obj.items():
|
||||
new_obj[self.insert_bot(k)] = self.insert_bot(val)
|
||||
for k, val in temp_dict.items():
|
||||
new_obj[self._insert_bot(k, memo)] = self._insert_bot(val, memo)
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
if hasattr(obj, '__dict__'):
|
||||
for attr_name, attr in new_obj.__dict__.items():
|
||||
setattr(new_obj, attr_name, self.insert_bot(attr))
|
||||
setattr(new_obj, attr_name, self._insert_bot(attr, memo))
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
if hasattr(obj, '__slots__'):
|
||||
for attr_name in obj.__slots__:
|
||||
setattr(
|
||||
new_obj,
|
||||
attr_name,
|
||||
self.insert_bot(self.insert_bot(getattr(new_obj, attr_name))),
|
||||
self._insert_bot(self._insert_bot(getattr(new_obj, attr_name), memo), memo),
|
||||
)
|
||||
memo[obj_id] = new_obj
|
||||
return new_obj
|
||||
|
||||
return obj
|
||||
|
|
|
@ -618,6 +618,57 @@ class TestBasePersistence:
|
|||
assert persistence.get_chat_data()[123][1].data == expected
|
||||
assert persistence.get_user_data()[123][1].data == expected
|
||||
|
||||
@pytest.mark.filterwarnings('ignore:BasePersistence')
|
||||
def test_replace_insert_bot_item_identity(self, bot, bot_persistence):
|
||||
persistence = bot_persistence
|
||||
persistence.set_bot(bot)
|
||||
|
||||
class CustomSlottedClass:
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self):
|
||||
self.value = 5
|
||||
|
||||
class CustomClass:
|
||||
pass
|
||||
|
||||
slot_object = CustomSlottedClass()
|
||||
dict_object = CustomClass()
|
||||
lock = Lock()
|
||||
list_ = [slot_object, dict_object, lock]
|
||||
tuple_ = (1, 2, 3)
|
||||
dict_ = {1: slot_object, 2: dict_object}
|
||||
|
||||
data = {
|
||||
'bot_1': bot,
|
||||
'bot_2': bot,
|
||||
'list_1': list_,
|
||||
'list_2': list_,
|
||||
'tuple_1': tuple_,
|
||||
'tuple_2': tuple_,
|
||||
'dict_1': dict_,
|
||||
'dict_2': dict_,
|
||||
}
|
||||
|
||||
def make_assertion(data_):
|
||||
return (
|
||||
data_['bot_1'] is data_['bot_2']
|
||||
and data_['list_1'] is data_['list_2']
|
||||
and data_['list_1'][0] is data_['list_2'][0]
|
||||
and data_['list_1'][1] is data_['list_2'][1]
|
||||
and data_['list_1'][2] is data_['list_2'][2]
|
||||
and data_['tuple_1'] is data_['tuple_2']
|
||||
and data_['dict_1'] is data_['dict_2']
|
||||
and data_['dict_1'][1] is data_['dict_2'][1]
|
||||
and data_['dict_1'][1] is data_['list_1'][0]
|
||||
and data_['dict_1'][2] is data_['list_1'][1]
|
||||
and data_['dict_1'][2] is data_['dict_2'][2]
|
||||
)
|
||||
|
||||
persistence.update_bot_data(data)
|
||||
assert make_assertion(persistence.bot_data)
|
||||
assert make_assertion(persistence.get_bot_data())
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def pickle_persistence():
|
||||
|
|
Loading…
Reference in a new issue