Fix Bug in BasePersistence.insert/replace_bot for Objects with __dict__ not in __slots__ (#2603)

* More special cases with slots

* Fix failing tests
This commit is contained in:
Bibo-Joshi 2021-07-24 17:17:25 +02:00 committed by GitHub
parent bcec6f03cb
commit 1fdaaac809
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 15 deletions

View file

@ -277,8 +277,6 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
new_obj[cls._replace_bot(k, memo)] = cls._replace_bot(val, memo)
memo[obj_id] = new_obj
return new_obj
# if '__dict__' in obj.__slots__, we already cover this here, that's why the
# __dict__ case comes below
try:
if hasattr(obj, '__slots__'):
for attr_name in new_obj.__slots__:
@ -289,8 +287,11 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
cls._replace_bot(getattr(new_obj, attr_name), memo), memo
),
)
memo[obj_id] = new_obj
return new_obj
if '__dict__' in obj.__slots__:
# In this case, we have already covered the case that obj has __dict__
# Note that obj may have a __dict__ even if it's not in __slots__!
memo[obj_id] = new_obj
return new_obj
if hasattr(obj, '__dict__'):
for attr_name, attr in new_obj.__dict__.items():
setattr(new_obj, attr_name, cls._replace_bot(attr, memo))
@ -302,9 +303,8 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
f'See the docs of BasePersistence.replace_bot for more information.',
RuntimeWarning,
)
memo[obj_id] = obj
return obj
memo[obj_id] = obj
return obj
def insert_bot(self, obj: object) -> object:
@ -379,8 +379,6 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
new_obj[self._insert_bot(k, memo)] = self._insert_bot(val, memo)
memo[obj_id] = new_obj
return new_obj
# if '__dict__' in obj.__slots__, we already cover this here, that's why the
# __dict__ case comes below
try:
if hasattr(obj, '__slots__'):
for attr_name in obj.__slots__:
@ -391,8 +389,11 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
self._insert_bot(getattr(new_obj, attr_name), memo), memo
),
)
memo[obj_id] = new_obj
return new_obj
if '__dict__' in obj.__slots__:
# In this case, we have already covered the case that obj has __dict__
# Note that obj may have a __dict__ even if it's not in __slots__!
memo[obj_id] = new_obj
return new_obj
if hasattr(obj, '__dict__'):
for attr_name, attr in new_obj.__dict__.items():
setattr(new_obj, attr_name, self._insert_bot(attr, memo))
@ -404,9 +405,8 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
f'See the docs of BasePersistence.insert_bot for more information.',
RuntimeWarning,
)
memo[obj_id] = obj
return obj
memo[obj_id] = obj
return obj
@abstractmethod

View file

@ -32,6 +32,7 @@ import logging
import os
import pickle
from collections import defaultdict
from collections.abc import Container
from time import sleep
from sys import version_info as py_ver
@ -566,17 +567,32 @@ class TestBasePersistence:
def __init__(self):
self.bot = bot
self.not_in_dict = bot
self.not_in_slots = bot
def __eq__(self, other):
if isinstance(other, CustomSlottedClass):
return self.bot is other.bot and self.not_in_dict is other.not_in_dict
return self.bot is other.bot and self.not_in_slots is other.not_in_slots
return False
class DictNotInSlots(Container):
"""This classes parent has slots, but __dict__ is not in those slots."""
def __init__(self):
self.bot = bot
def __contains__(self, item):
return True
def __eq__(self, other):
if isinstance(other, DictNotInSlots):
return self.bot is other.bot
return False
class CustomClass:
def __init__(self):
self.bot = bot
self.slotted_object = CustomSlottedClass()
self.dict_not_in_slots_object = DictNotInSlots()
self.list_ = [1, 2, bot]
self.tuple_ = tuple(self.list_)
self.set_ = set(self.list_)
@ -589,7 +605,8 @@ class TestBasePersistence:
cc = CustomClass()
cc.bot = BasePersistence.REPLACED_BOT
cc.slotted_object.bot = BasePersistence.REPLACED_BOT
cc.slotted_object.not_in_dict = BasePersistence.REPLACED_BOT
cc.slotted_object.not_in_slots = BasePersistence.REPLACED_BOT
cc.dict_not_in_slots_object.bot = BasePersistence.REPLACED_BOT
cc.list_ = [1, 2, BasePersistence.REPLACED_BOT]
cc.tuple_ = tuple(cc.list_)
cc.set_ = set(cc.list_)
@ -603,6 +620,7 @@ class TestBasePersistence:
return (
self.bot is other.bot
and self.slotted_object == other.slotted_object
and self.dict_not_in_slots_object == other.dict_not_in_slots_object
and self.list_ == other.list_
and self.tuple_ == other.tuple_
and self.set_ == other.set_