Add Method drop_chat/user_data to Dispatcher and Persistence (#2852)

This commit is contained in:
Harshil 2022-01-19 20:27:02 +04:00 committed by Hinrich Mahler
parent e442782d8f
commit 0ccd7d40ac
7 changed files with 336 additions and 118 deletions

View file

@ -457,7 +457,7 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
Returns: Returns:
Optional[Tuple[List[Tuple[:obj:`str`, :obj:`float`, \ Optional[Tuple[List[Tuple[:obj:`str`, :obj:`float`, \
Dict[:obj:`str`, :obj:`Any`]]], Dict[:obj:`str`, :obj:`str`]]: Dict[:obj:`str`, :obj:`Any`]]], Dict[:obj:`str`, :obj:`str`]]]:
The restored meta data or :obj:`None`, if no data was stored. The restored meta data or :obj:`None`, if no data was stored.
""" """
@ -520,6 +520,44 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
The :attr:`telegram.ext.Dispatcher.bot_data`. The :attr:`telegram.ext.Dispatcher.bot_data`.
""" """
@abstractmethod
def update_callback_data(self, data: CDCData) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
.. versionadded:: 13.6
.. versionchanged:: 14.0
Changed this method into an ``@abstractmethod``.
Args:
data (Optional[Tuple[List[Tuple[:obj:`str`, :obj:`float`, \
Dict[:obj:`str`, :obj:`Any`]]], Dict[:obj:`str`, :obj:`str`]]]):
The relevant data to restore :class:`telegram.ext.CallbackDataCache`.
"""
@abstractmethod
def drop_chat_data(self, chat_id: int) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher`, when using
:meth:`~telegram.ext.Dispatcher.drop_chat_data`.
.. versionadded:: 14.0
Args:
chat_id (:obj:`int`): The chat id to delete from the persistence.
"""
@abstractmethod
def drop_user_data(self, user_id: int) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher`, when using
:meth:`~telegram.ext.Dispatcher.drop_user_data`.
.. versionadded:: 14.0
Args:
user_id (:obj:`int`): The user id to delete from the persistence.
"""
@abstractmethod @abstractmethod
def refresh_user_data(self, user_id: int, user_data: UD) -> None: def refresh_user_data(self, user_id: int, user_data: UD) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` before passing the """Will be called by the :class:`telegram.ext.Dispatcher` before passing the
@ -570,22 +608,6 @@ class BasePersistence(Generic[UD, CD, BD], ABC):
The ``bot_data``. The ``bot_data``.
""" """
@abstractmethod
def update_callback_data(self, data: CDCData) -> None:
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
.. versionadded:: 13.6
.. versionchanged:: 14.0
Changed this method into an ``@abstractmethod``.
Args:
data (Optional[Tuple[List[Tuple[:obj:`str`, :obj:`float`, \
Dict[:obj:`str`, :obj:`Any`]]], Dict[:obj:`str`, :obj:`str`]]):
The relevant data to restore :class:`telegram.ext.CallbackDataCache`.
"""
@abstractmethod @abstractmethod
def flush(self) -> None: def flush(self) -> None:
"""Will be called by :class:`telegram.ext.Updater` upon receiving a stop signal. Gives the """Will be called by :class:`telegram.ext.Updater` upon receiving a stop signal. Gives the

View file

@ -361,6 +361,32 @@ class DictPersistence(BasePersistence):
self._callback_data = (data[0], data[1].copy()) self._callback_data = (data[0], data[1].copy())
self._callback_data_json = None self._callback_data_json = None
def drop_chat_data(self, chat_id: int) -> None:
"""Will delete the specified key from the :attr:`chat_data`.
.. versionadded:: 14.0
Args:
chat_id (:obj:`int`): The chat id to delete from the persistence.
"""
if self._chat_data is None:
return
self._chat_data.pop(chat_id, None)
self._chat_data_json = None
def drop_user_data(self, user_id: int) -> None:
"""Will delete the specified key from the :attr:`user_data`.
.. versionadded:: 14.0
Args:
user_id (:obj:`int`): The user id to delete from the persistence.
"""
if self._user_data is None:
return
self._user_data.pop(user_id, None)
self._user_data_json = None
def refresh_user_data(self, user_id: int, user_data: Dict) -> None: def refresh_user_data(self, user_id: int, user_data: Dict) -> None:
"""Does nothing. """Does nothing.

View file

@ -37,7 +37,9 @@ from typing import (
TypeVar, TypeVar,
TYPE_CHECKING, TYPE_CHECKING,
Tuple, Tuple,
Mapping,
) )
from types import MappingProxyType
from uuid import uuid4 from uuid import uuid4
from telegram import Update from telegram import Update
@ -78,11 +80,11 @@ class DispatcherHandlerStop(Exception):
Note: Note:
Has no effect, if the handler or error handler is run asynchronously. Has no effect, if the handler or error handler is run asynchronously.
Attributes:
state (:obj:`object`): Optional. The next state of the conversation.
Args: Args:
state (:obj:`object`, optional): The next state of the conversation. state (:obj:`object`, optional): The next state of the conversation.
Attributes:
state (:obj:`object`): Optional. The next state of the conversation.
""" """
__slots__ = ('state',) __slots__ = ('state',)
@ -111,8 +113,24 @@ class Dispatcher(Generic[BT, CCT, UD, CD, BD, JQ, PT]):
instance to pass onto handler callbacks. instance to pass onto handler callbacks.
workers (:obj:`int`, optional): Number of maximum concurrent worker threads for the workers (:obj:`int`, optional): Number of maximum concurrent worker threads for the
``@run_async`` decorator and :meth:`run_async`. ``@run_async`` decorator and :meth:`run_async`.
user_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the user. chat_data (:obj:`types.MappingProxyType`): A dictionary handlers can use to store data for
chat_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the chat. the chat.
.. versionchanged:: 14.0
:attr:`chat_data` is now read-only
.. tip::
Manually modifying :attr:`chat_data` is almost never needed and unadvisable.
user_data (:obj:`types.MappingProxyType`): A dictionary handlers can use to store data for
the user.
.. versionchanged:: 14.0
:attr:`user_data` is now read-only
.. tip::
Manually modifying :attr:`user_data` is almost never needed and unadvisable.
bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot. bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot.
persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to
store data that should be persistent over restarts. store data that should be persistent over restarts.
@ -144,7 +162,9 @@ class Dispatcher(Generic[BT, CCT, UD, CD, BD, JQ, PT]):
'persistence', 'persistence',
'update_queue', 'update_queue',
'job_queue', 'job_queue',
'_user_data',
'user_data', 'user_data',
'_chat_data',
'chat_data', 'chat_data',
'bot_data', 'bot_data',
'_update_persistence_lock', '_update_persistence_lock',
@ -198,10 +218,15 @@ class Dispatcher(Generic[BT, CCT, UD, CD, BD, JQ, PT]):
stacklevel=stack_level, stacklevel=stack_level,
) )
self.user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data) self._user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data)
self.chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data) self._chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data)
# Read only mapping-
self.user_data: Mapping[int, UD] = MappingProxyType(self._user_data)
self.chat_data: Mapping[int, CD] = MappingProxyType(self._chat_data)
self.bot_data = self.context_types.bot_data() self.bot_data = self.context_types.bot_data()
self.persistence: Optional[BasePersistence] = None
self.persistence: Optional[BasePersistence]
self._update_persistence_lock = Lock() self._update_persistence_lock = Lock()
if persistence: if persistence:
if not isinstance(persistence, BasePersistence): if not isinstance(persistence, BasePersistence):
@ -213,13 +238,9 @@ class Dispatcher(Generic[BT, CCT, UD, CD, BD, JQ, PT]):
self.persistence.set_bot(self.bot) self.persistence.set_bot(self.bot)
if self.persistence.store_data.user_data: if self.persistence.store_data.user_data:
self.user_data = self.persistence.get_user_data() self._user_data.update(self.persistence.get_user_data())
if not isinstance(self.user_data, defaultdict):
raise ValueError("user_data must be of type defaultdict")
if self.persistence.store_data.chat_data: if self.persistence.store_data.chat_data:
self.chat_data = self.persistence.get_chat_data() self._chat_data.update(self.persistence.get_chat_data())
if not isinstance(self.chat_data, defaultdict):
raise ValueError("chat_data must be of type defaultdict")
if self.persistence.store_data.bot_data: if self.persistence.store_data.bot_data:
self.bot_data = self.persistence.get_bot_data() self.bot_data = self.persistence.get_bot_data()
if not isinstance(self.bot_data, self.context_types.bot_data): if not isinstance(self.bot_data, self.context_types.bot_data):
@ -230,7 +251,7 @@ class Dispatcher(Generic[BT, CCT, UD, CD, BD, JQ, PT]):
persistent_data = self.persistence.get_callback_data() persistent_data = self.persistence.get_callback_data()
if persistent_data is not None: if persistent_data is not None:
if not isinstance(persistent_data, tuple) and len(persistent_data) != 2: if not isinstance(persistent_data, tuple) and len(persistent_data) != 2:
raise ValueError('callback_data must be a 2-tuple') raise ValueError('callback_data must be a tuple of length 2')
# Mypy doesn't know that persistence.set_bot (see above) already checks that # Mypy doesn't know that persistence.set_bot (see above) already checks that
# self.bot is an instance of ExtBot if callback_data should be stored ... # self.bot is an instance of ExtBot if callback_data should be stored ...
self.bot.callback_data_cache = CallbackDataCache( # type: ignore[attr-defined] self.bot.callback_data_cache = CallbackDataCache( # type: ignore[attr-defined]
@ -631,6 +652,34 @@ class Dispatcher(Generic[BT, CCT, UD, CD, BD, JQ, PT]):
if not self.handlers[group]: if not self.handlers[group]:
del self.handlers[group] del self.handlers[group]
def drop_chat_data(self, chat_id: int) -> None:
"""Used for deleting a key from the :attr:`chat_data`.
.. versionadded:: 14.0
Args:
chat_id (:obj:`int`): The chat id to delete from the persistence. The entry
will be deleted even if it is not empty.
"""
self._chat_data.pop(chat_id, None) # type: ignore[arg-type]
if self.persistence:
self.persistence.drop_chat_data(chat_id)
def drop_user_data(self, user_id: int) -> None:
"""Used for deleting a key from the :attr:`user_data`.
.. versionadded:: 14.0
Args:
user_id (:obj:`int`): The user id to delete from the persistence. The entry
will be deleted even if it is not empty.
"""
self._user_data.pop(user_id, None) # type: ignore[arg-type]
if self.persistence:
self.persistence.drop_user_data(user_id)
def update_persistence(self, update: object = None) -> None: def update_persistence(self, update: object = None) -> None:
"""Update :attr:`user_data`, :attr:`chat_data` and :attr:`bot_data` in :attr:`persistence`. """Update :attr:`user_data`, :attr:`chat_data` and :attr:`bot_data` in :attr:`persistence`.
@ -643,7 +692,7 @@ class Dispatcher(Generic[BT, CCT, UD, CD, BD, JQ, PT]):
def __update_persistence(self, update: object = None) -> None: def __update_persistence(self, update: object = None) -> None:
if self.persistence: if self.persistence:
# We use list() here in order to decouple chat_ids from self.chat_data, as dict view # We use list() here in order to decouple chat_ids from self._chat_data, as dict view
# objects will change, when the dict does and we want to loop over chat_ids # objects will change, when the dict does and we want to loop over chat_ids
chat_ids = list(self.chat_data.keys()) chat_ids = list(self.chat_data.keys())
user_ids = list(self.user_data.keys()) user_ids = list(self.user_data.keys())

View file

@ -392,6 +392,44 @@ class PicklePersistence(BasePersistence[UD, CD, BD]):
else: else:
self._dump_singlefile() self._dump_singlefile()
def drop_chat_data(self, chat_id: int) -> None:
"""Will delete the specified key from the :attr:`chat_data` and depending on
:attr:`on_flush` save the pickle file.
.. versionadded:: 14.0
Args:
chat_id (:obj:`int`): The chat id to delete from the persistence.
"""
if self.chat_data is None:
return
self.chat_data.pop(chat_id, None) # type: ignore[arg-type]
if not self.on_flush:
if not self.single_file:
self._dump_file(Path(f"{self.filepath}_chat_data"), self.chat_data)
else:
self._dump_singlefile()
def drop_user_data(self, user_id: int) -> None:
"""Will delete the specified key from the :attr:`user_data` and depending on
:attr:`on_flush` save the pickle file.
.. versionadded:: 14.0
Args:
user_id (:obj:`int`): The user id to delete from the persistence.
"""
if self.user_data is None:
return
self.user_data.pop(user_id, None) # type: ignore[arg-type]
if not self.on_flush:
if not self.single_file:
self._dump_file(Path(f"{self.filepath}_user_data"), self.user_data)
else:
self._dump_singlefile()
def refresh_user_data(self, user_id: int, user_data: UD) -> None: def refresh_user_data(self, user_id: int, user_data: UD) -> None:
"""Does nothing. """Does nothing.

View file

@ -28,6 +28,7 @@ from queue import Queue
from threading import Thread, Event from threading import Thread, Event
from time import sleep from time import sleep
from typing import Callable, List, Iterable, Any from typing import Callable, List, Iterable, Any
from types import MappingProxyType
import pytest import pytest
import pytz import pytz
@ -194,10 +195,11 @@ def dp(_dp):
# Reset the dispatcher first # Reset the dispatcher first
while not _dp.update_queue.empty(): while not _dp.update_queue.empty():
_dp.update_queue.get(False) _dp.update_queue.get(False)
_dp.chat_data = defaultdict(dict) _dp._chat_data = defaultdict(dict)
_dp.user_data = defaultdict(dict) _dp._user_data = defaultdict(dict)
_dp.chat_data = MappingProxyType(_dp._chat_data) # Rebuild the mapping so it updates
_dp.user_data = MappingProxyType(_dp._user_data)
_dp.bot_data = {} _dp.bot_data = {}
_dp.persistence = None
_dp.handlers = {} _dp.handlers = {}
_dp.error_handlers = {} _dp.error_handlers = {}
_dp.exception_event = Event() _dp.exception_event = Event()

View file

@ -125,6 +125,15 @@ class TestDispatcher:
) )
assert recwarn[0].filename == __file__, "stacklevel is incorrect!" assert recwarn[0].filename == __file__, "stacklevel is incorrect!"
@pytest.mark.parametrize("data", ["chat_data", "user_data"])
def test_chat_user_data_read_only(self, dp, data):
read_only_data = getattr(dp, data)
writable_data = getattr(dp, f"_{data}")
writable_data[123] = 321
assert read_only_data == writable_data
with pytest.raises(TypeError):
read_only_data[111] = 123
@pytest.mark.parametrize( @pytest.mark.parametrize(
'builder', 'builder',
(DispatcherBuilder(), UpdaterBuilder()), (DispatcherBuilder(), UpdaterBuilder()),
@ -621,6 +630,12 @@ class TestDispatcher:
def update_bot_data(self, data): def update_bot_data(self, data):
raise Exception raise Exception
def drop_chat_data(self, chat_id):
pass
def drop_user_data(self, user_id):
pass
def get_chat_data(self): def get_chat_data(self):
return defaultdict(dict) return defaultdict(dict)
@ -747,6 +762,12 @@ class TestDispatcher:
def update_user_data(self, user_id, data): def update_user_data(self, user_id, data):
self.update(data) self.update(data)
def drop_user_data(self, user_id):
pass
def drop_chat_data(self, chat_id):
pass
def get_chat_data(self): def get_chat_data(self):
pass pass
@ -825,6 +846,12 @@ class TestDispatcher:
def update_conversation(self, name, key, new_state): def update_conversation(self, name, key, new_state):
pass pass
def drop_chat_data(self, chat_id):
pass
def drop_user_data(self, user_id):
pass
def get_conversations(self, name): def get_conversations(self, name):
pass pass
@ -879,6 +906,26 @@ class TestDispatcher:
assert not dp.persistence.test_flag_user_data assert not dp.persistence.test_flag_user_data
assert dp.persistence.test_flag_chat_data assert dp.persistence.test_flag_chat_data
@pytest.mark.parametrize(
"c_id,expected",
[(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})],
ids=["test chat_id removal", "test no key in data (no error)"],
)
def test_drop_chat_data(self, dp, c_id, expected):
dp._chat_data.update({321: {'not_empty': 'no'}, 222: "remove_me"})
dp.drop_chat_data(c_id)
assert dp.chat_data == expected
@pytest.mark.parametrize(
"u_id,expected",
[(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})],
ids=["test user_id removal", "test no key in data (no error)"],
)
def test_drop_user_data(self, dp, u_id, expected):
dp._user_data.update({321: {'not_empty': 'no'}, 222: "remove_me"})
dp.drop_user_data(u_id)
assert dp.user_data == expected
def test_update_persistence_once_per_update(self, monkeypatch, dp): def test_update_persistence_once_per_update(self, monkeypatch, dp):
def update_persistence(*args, **kwargs): def update_persistence(*args, **kwargs):
self.count += 1 self.count += 1

View file

@ -50,14 +50,14 @@ from telegram.ext import (
JobQueue, JobQueue,
ContextTypes, ContextTypes,
) )
from telegram.ext._callbackdatacache import _KeyboardData
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def change_directory(tmp_path): def change_directory(tmp_path: Path):
orig_dir = Path.cwd() orig_dir = Path.cwd()
# Switch to a temporary directory so we don't have to worry about cleaning up files # Switch to a temporary directory, so we don't have to worry about cleaning up files
# (str() for py<3.6) os.chdir(tmp_path)
os.chdir(str(tmp_path))
yield yield
# Go back to original directory # Go back to original directory
os.chdir(orig_dir) os.chdir(orig_dir)
@ -99,6 +99,12 @@ class OwnPersistence(BasePersistence):
def get_callback_data(self): def get_callback_data(self):
raise NotImplementedError raise NotImplementedError
def drop_user_data(self, user_id):
raise NotImplementedError
def drop_chat_data(self, chat_id):
raise NotImplementedError
def refresh_user_data(self, user_id, user_data): def refresh_user_data(self, user_id, user_data):
raise NotImplementedError raise NotImplementedError
@ -159,6 +165,12 @@ def bot_persistence():
def update_callback_data(self, data): def update_callback_data(self, data):
self.callback_data = data self.callback_data = data
def drop_user_data(self, user_id):
pass
def drop_chat_data(self, chat_id):
pass
def update_conversation(self, name, key, new_state): def update_conversation(self, name, key, new_state):
raise NotImplementedError raise NotImplementedError
@ -257,7 +269,7 @@ class TestBasePersistence:
with pytest.raises( with pytest.raises(
TypeError, TypeError,
match=( match=(
'flush, get_bot_data, get_callback_data, ' 'drop_chat_data, drop_user_data, flush, get_bot_data, get_callback_data, '
'get_chat_data, get_conversations, ' 'get_chat_data, get_conversations, '
'get_user_data, refresh_bot_data, refresh_chat_data, ' 'get_user_data, refresh_bot_data, refresh_chat_data, '
'refresh_user_data, update_bot_data, update_callback_data, ' 'refresh_user_data, update_bot_data, update_callback_data, '
@ -284,52 +296,40 @@ class TestBasePersistence:
def test_dispatcher_integration_init( def test_dispatcher_integration_init(
self, bot, base_persistence, chat_data, user_data, bot_data, callback_data self, bot, base_persistence, chat_data, user_data, bot_data, callback_data
): ):
def get_user_data(): # Bad data testing-
def bad_get_bot_data():
return "test" return "test"
def get_chat_data(): def bad_get_callback_data():
return "test" return "test"
def get_bot_data(): # Good data testing-
return "test" def good_get_user_data():
def get_callback_data():
return "test"
base_persistence.get_user_data = get_user_data
base_persistence.get_chat_data = get_chat_data
base_persistence.get_bot_data = get_bot_data
base_persistence.get_callback_data = get_callback_data
with pytest.raises(ValueError, match="user_data must be of type defaultdict"):
UpdaterBuilder().bot(bot).persistence(base_persistence).build()
def get_user_data():
return user_data return user_data
base_persistence.get_user_data = get_user_data def good_get_chat_data():
with pytest.raises(ValueError, match="chat_data must be of type defaultdict"):
UpdaterBuilder().bot(bot).persistence(base_persistence).build()
def get_chat_data():
return chat_data return chat_data
base_persistence.get_chat_data = get_chat_data def good_get_bot_data():
return bot_data
def good_get_callback_data():
return callback_data
base_persistence.get_user_data = good_get_user_data # No errors to be tested so
base_persistence.get_chat_data = good_get_chat_data
base_persistence.get_bot_data = bad_get_bot_data
base_persistence.get_callback_data = bad_get_callback_data
with pytest.raises(ValueError, match="bot_data must be of type dict"): with pytest.raises(ValueError, match="bot_data must be of type dict"):
UpdaterBuilder().bot(bot).persistence(base_persistence).build() UpdaterBuilder().bot(bot).persistence(base_persistence).build()
def get_bot_data(): base_persistence.get_bot_data = good_get_bot_data
return bot_data with pytest.raises(ValueError, match="callback_data must be a tuple of length 2"):
base_persistence.get_bot_data = get_bot_data
with pytest.raises(ValueError, match="callback_data must be a 2-tuple"):
UpdaterBuilder().bot(bot).persistence(base_persistence).build() UpdaterBuilder().bot(bot).persistence(base_persistence).build()
def get_callback_data():
return callback_data
base_persistence.bot = None base_persistence.bot = None
base_persistence.get_callback_data = get_callback_data base_persistence.get_callback_data = good_get_callback_data
u = UpdaterBuilder().bot(bot).persistence(base_persistence).build() u = UpdaterBuilder().bot(bot).persistence(base_persistence).build()
assert u.dispatcher.bot is base_persistence.bot assert u.dispatcher.bot is base_persistence.bot
assert u.dispatcher.bot_data == bot_data assert u.dispatcher.bot_data == bot_data
@ -339,7 +339,7 @@ class TestBasePersistence:
u.dispatcher.chat_data[442233]['test5'] = 'test6' u.dispatcher.chat_data[442233]['test5'] = 'test6'
assert u.dispatcher.chat_data[442233]['test5'] == 'test6' assert u.dispatcher.chat_data[442233]['test5'] == 'test6'
@pytest.mark.parametrize('run_async', [True, False], ids=['synchronous', 'run_async']) @pytest.mark.parametrize('run_async', [True, False], ids=['run_async', 'synchronous'])
def test_dispatcher_integration_handlers( def test_dispatcher_integration_handlers(
self, self,
dp, dp,
@ -368,8 +368,6 @@ class TestBasePersistence:
base_persistence.get_chat_data = get_chat_data base_persistence.get_chat_data = get_chat_data
base_persistence.get_bot_data = get_bot_data base_persistence.get_bot_data = get_bot_data
base_persistence.get_callback_data = get_callback_data base_persistence.get_callback_data = get_callback_data
# base_persistence.update_chat_data = lambda x: x
# base_persistence.update_user_data = lambda x: x
base_persistence.refresh_bot_data = lambda x: x base_persistence.refresh_bot_data = lambda x: x
base_persistence.refresh_chat_data = lambda x, y: x base_persistence.refresh_chat_data = lambda x, y: x
base_persistence.refresh_user_data = lambda x, y: x base_persistence.refresh_user_data = lambda x, y: x
@ -383,7 +381,7 @@ class TestBasePersistence:
pytest.fail('bot_data corrupt') pytest.fail('bot_data corrupt')
def callback_known_chat(update, context): def callback_known_chat(update, context):
if not context.chat_data['test3'] == 'test4': if not context.chat_data[3] == 'test4':
pytest.fail('chat_data corrupt') pytest.fail('chat_data corrupt')
if not context.bot_data == bot_data: if not context.bot_data == bot_data:
pytest.fail('bot_data corrupt') pytest.fail('bot_data corrupt')
@ -398,20 +396,17 @@ class TestBasePersistence:
context.user_data[1] = 'test7' context.user_data[1] = 'test7'
context.chat_data[2] = 'test8' context.chat_data[2] = 'test8'
context.bot_data['test0'] = 'test0' context.bot_data['test0'] = 'test0'
context.bot.callback_data_cache.put('test0') # Let's now delete user1 and chat1
context.dispatcher.drop_chat_data(-67890)
context.dispatcher.drop_user_data(12345)
# Test setting new keyboard callback data-
context.bot.callback_data_cache._keyboard_data['id'] = _KeyboardData(
'id', button_data={'button3': 'test3'}
)
known_user = MessageHandler( known_user = MessageHandler(filters.User(user_id=12345), callback_known_user) # user1
filters.User(user_id=12345), known_chat = MessageHandler(filters.Chat(chat_id=-67890), callback_known_chat) # chat1
callback_known_user, unknown = MessageHandler(filters.ALL, callback_unknown_user_or_chat) # user2 and chat2
)
known_chat = MessageHandler(
filters.Chat(chat_id=-67890),
callback_known_chat,
)
unknown = MessageHandler(
filters.ALL,
callback_unknown_user_or_chat,
)
dp.add_handler(known_user) dp.add_handler(known_user)
dp.add_handler(known_chat) dp.add_handler(known_chat)
dp.add_handler(unknown) dp.add_handler(unknown)
@ -420,51 +415,64 @@ class TestBasePersistence:
chat1 = Chat(id=-67890, type='group') chat1 = Chat(id=-67890, type='group')
chat2 = Chat(id=-987654, type='group') chat2 = Chat(id=-987654, type='group')
m = Message(1, None, chat2, from_user=user1) m = Message(1, None, chat2, from_user=user1)
u = Update(0, m) u_known_user = Update(0, m)
with caplog.at_level(logging.ERROR): dp.process_update(u_known_user)
dp.process_update(u) # 4 errors which arise since update_*_data are raising NotImplementedError here.
rec = caplog.records[-1] assert len(caplog.records) == 4
assert rec.getMessage() == 'No error handlers are registered, logging exception.'
assert rec.levelname == 'ERROR'
rec = caplog.records[-2]
assert rec.getMessage() == 'No error handlers are registered, logging exception.'
assert rec.levelname == 'ERROR'
rec = caplog.records[-3]
assert rec.getMessage() == 'No error handlers are registered, logging exception.'
assert rec.levelname == 'ERROR'
m.from_user = user2 m.from_user = user2
m.chat = chat1 m.chat = chat1
u = Update(1, m) u_known_chat = Update(1, m)
dp.process_update(u) dp.process_update(u_known_chat)
m.chat = chat2 m.chat = chat2
u = Update(2, m) u_unknown_user_or_chat = Update(2, m)
def save_bot_data(data): def save_bot_data(data):
if 'test0' not in data: if 'test0' not in data:
pytest.fail() pytest.fail()
def save_chat_data(data): def save_chat_data(_id, data):
if -987654 not in data: if 2 not in data: # data should be: {2: 'test8'}
pytest.fail() pytest.fail()
def save_user_data(data): def save_user_data(_id, data):
if 54321 not in data: if 1 not in data: # data should be: {1: 'test7'}
pytest.fail() pytest.fail()
def save_callback_data(data): def save_callback_data(data):
if not assert_data_in_cache(dp.bot.callback_data, 'test0'): if not assert_data_in_cache(dp.bot.callback_data_cache, 'test3'):
pytest.fail() pytest.fail()
# Functions to check deletion-
def delete_user_data(user_id):
if 12345 != user_id:
pytest.fail("The id being deleted is not of user1's")
user_data.pop(user_id, None)
def delete_chat_data(chat_id):
if -67890 != chat_id:
pytest.fail("The chat id being deleted is not of chat1's")
chat_data.pop(chat_id, None)
base_persistence.update_chat_data = save_chat_data base_persistence.update_chat_data = save_chat_data
base_persistence.update_user_data = save_user_data base_persistence.update_user_data = save_user_data
base_persistence.update_bot_data = save_bot_data base_persistence.update_bot_data = save_bot_data
base_persistence.update_callback_data = save_callback_data base_persistence.update_callback_data = save_callback_data
dp.process_update(u) base_persistence.drop_chat_data = delete_chat_data
base_persistence.drop_user_data = delete_user_data
dp.process_update(u_unknown_user_or_chat)
# Test callback_unknown_user_or_chat worked correctly-
assert dp.user_data[54321][1] == 'test7' assert dp.user_data[54321][1] == 'test7'
assert dp.chat_data[-987654][2] == 'test8' assert dp.chat_data[-987654][2] == 'test8'
assert dp.bot_data['test0'] == 'test0' assert dp.bot_data['test0'] == 'test0'
assert assert_data_in_cache(dp.bot.callback_data_cache, 'test0') assert assert_data_in_cache(dp.bot.callback_data_cache, 'test3')
assert 12345 not in dp.user_data # Tests if dp.drop_user_data worked or not
assert -67890 not in dp.chat_data
assert len(caplog.records) == 8 # Errors double since new update is processed.
for r in caplog.records:
assert issubclass(r.exc_info[0], NotImplementedError)
assert r.getMessage() == 'No error handlers are registered, logging exception.'
assert r.levelname == 'ERROR'
@pytest.mark.parametrize( @pytest.mark.parametrize(
'store_user_data', [True, False], ids=['store_user_data-True', 'store_user_data-False'] 'store_user_data', [True, False], ids=['store_user_data-True', 'store_user_data-False']
@ -475,7 +483,7 @@ class TestBasePersistence:
@pytest.mark.parametrize( @pytest.mark.parametrize(
'store_bot_data', [True, False], ids=['store_bot_data-True', 'store_bot_data-False'] 'store_bot_data', [True, False], ids=['store_bot_data-True', 'store_bot_data-False']
) )
@pytest.mark.parametrize('run_async', [True, False], ids=['synchronous', 'run_async']) @pytest.mark.parametrize('run_async', [True, False], ids=['run_async', 'synchronous'])
def test_persistence_dispatcher_integration_refresh_data( def test_persistence_dispatcher_integration_refresh_data(
self, self,
dp, dp,
@ -1042,15 +1050,11 @@ class TestPicklePersistence:
assert retrieved == bot_data assert retrieved == bot_data
def test_no_files_present_multi_file(self, pickle_persistence): def test_no_files_present_multi_file(self, pickle_persistence):
assert pickle_persistence.get_user_data() == defaultdict(dict)
assert pickle_persistence.get_user_data() == defaultdict(dict) assert pickle_persistence.get_user_data() == defaultdict(dict)
assert pickle_persistence.get_chat_data() == defaultdict(dict) assert pickle_persistence.get_chat_data() == defaultdict(dict)
assert pickle_persistence.get_chat_data() == defaultdict(dict)
assert pickle_persistence.get_bot_data() == {}
assert pickle_persistence.get_bot_data() == {} assert pickle_persistence.get_bot_data() == {}
assert pickle_persistence.get_callback_data() is None assert pickle_persistence.get_callback_data() is None
assert pickle_persistence.get_conversations('noname') == {} assert pickle_persistence.get_conversations('noname') == {}
assert pickle_persistence.get_conversations('noname') == {}
def test_no_files_present_single_file(self, pickle_persistence): def test_no_files_present_single_file(self, pickle_persistence):
pickle_persistence.single_file = True pickle_persistence.single_file = True
@ -1342,6 +1346,8 @@ class TestPicklePersistence:
with Path('pickletest_user_data').open('rb') as f: with Path('pickletest_user_data').open('rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)) user_data_test = defaultdict(dict, pickle.load(f))
assert user_data_test == user_data assert user_data_test == user_data
pickle_persistence.drop_user_data(67890)
assert 67890 not in pickle_persistence.get_user_data()
chat_data = pickle_persistence.get_chat_data() chat_data = pickle_persistence.get_chat_data()
chat_data[-12345]['test3']['test4'] = 'test6' chat_data[-12345]['test3']['test4'] = 'test6'
@ -1354,6 +1360,8 @@ class TestPicklePersistence:
with Path('pickletest_chat_data').open('rb') as f: with Path('pickletest_chat_data').open('rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)) chat_data_test = defaultdict(dict, pickle.load(f))
assert chat_data_test == chat_data assert chat_data_test == chat_data
pickle_persistence.drop_chat_data(-67890)
assert -67890 not in pickle_persistence.get_chat_data()
bot_data = pickle_persistence.get_bot_data() bot_data = pickle_persistence.get_bot_data()
bot_data['test3']['test4'] = 'test6' bot_data['test3']['test4'] = 'test6'
@ -1408,6 +1416,8 @@ class TestPicklePersistence:
with Path('pickletest').open('rb') as f: with Path('pickletest').open('rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)['user_data']) user_data_test = defaultdict(dict, pickle.load(f)['user_data'])
assert user_data_test == user_data assert user_data_test == user_data
pickle_persistence.drop_user_data(67890)
assert 67890 not in pickle_persistence.get_user_data()
chat_data = pickle_persistence.get_chat_data() chat_data = pickle_persistence.get_chat_data()
chat_data[-12345]['test3']['test4'] = 'test6' chat_data[-12345]['test3']['test4'] = 'test6'
@ -1420,6 +1430,8 @@ class TestPicklePersistence:
with Path('pickletest').open('rb') as f: with Path('pickletest').open('rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)['chat_data']) chat_data_test = defaultdict(dict, pickle.load(f)['chat_data'])
assert chat_data_test == chat_data assert chat_data_test == chat_data
pickle_persistence.drop_chat_data(-67890)
assert -67890 not in pickle_persistence.get_chat_data()
bot_data = pickle_persistence.get_bot_data() bot_data = pickle_persistence.get_bot_data()
bot_data['test3']['test4'] = 'test6' bot_data['test3']['test4'] = 'test6'
@ -1487,6 +1499,9 @@ class TestPicklePersistence:
pickle_persistence.update_user_data(54321, user_data[54321]) pickle_persistence.update_user_data(54321, user_data[54321])
assert pickle_persistence.user_data == user_data assert pickle_persistence.user_data == user_data
pickle_persistence.drop_user_data(0)
assert pickle_persistence.user_data == user_data
with Path('pickletest_user_data').open('rb') as f: with Path('pickletest_user_data').open('rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)) user_data_test = defaultdict(dict, pickle.load(f))
assert not user_data_test == user_data assert not user_data_test == user_data
@ -1498,6 +1513,9 @@ class TestPicklePersistence:
pickle_persistence.update_chat_data(54321, chat_data[54321]) pickle_persistence.update_chat_data(54321, chat_data[54321])
assert pickle_persistence.chat_data == chat_data assert pickle_persistence.chat_data == chat_data
pickle_persistence.drop_chat_data(0)
assert pickle_persistence.user_data == user_data
with Path('pickletest_chat_data').open('rb') as f: with Path('pickletest_chat_data').open('rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)) chat_data_test = defaultdict(dict, pickle.load(f))
assert not chat_data_test == chat_data assert not chat_data_test == chat_data
@ -1905,6 +1923,12 @@ class TestPicklePersistence:
assert isinstance(persistence.get_bot_data(), bd) assert isinstance(persistence.get_bot_data(), bd)
assert persistence.get_bot_data() == 0 assert persistence.get_bot_data() == 0
persistence.user_data = None
persistence.chat_data = None
persistence.drop_user_data(123)
persistence.drop_chat_data(123)
assert isinstance(persistence.get_user_data(), defaultdict)
assert isinstance(persistence.get_chat_data(), defaultdict)
persistence.user_data = None persistence.user_data = None
persistence.chat_data = None persistence.chat_data = None
persistence.update_user_data(1, ud(1)) persistence.update_user_data(1, ud(1))
@ -2132,6 +2156,11 @@ class TestDictPersistence:
dict_persistence.update_user_data(12345, user_data[12345]) dict_persistence.update_user_data(12345, user_data[12345])
assert dict_persistence.user_data == user_data assert dict_persistence.user_data == user_data
assert dict_persistence.user_data_json == json.dumps(user_data) assert dict_persistence.user_data_json == json.dumps(user_data)
dict_persistence.drop_user_data(67890)
assert 67890 not in dict_persistence.user_data
dict_persistence._user_data = None
dict_persistence.drop_user_data(123)
assert isinstance(dict_persistence.get_user_data(), defaultdict)
chat_data = dict_persistence.get_chat_data() chat_data = dict_persistence.get_chat_data()
chat_data[-12345]['test3']['test4'] = 'test6' chat_data[-12345]['test3']['test4'] = 'test6'
@ -2144,6 +2173,11 @@ class TestDictPersistence:
dict_persistence.update_chat_data(-12345, chat_data[-12345]) dict_persistence.update_chat_data(-12345, chat_data[-12345])
assert dict_persistence.chat_data == chat_data assert dict_persistence.chat_data == chat_data
assert dict_persistence.chat_data_json == json.dumps(chat_data) assert dict_persistence.chat_data_json == json.dumps(chat_data)
dict_persistence.drop_chat_data(-67890)
assert -67890 not in dict_persistence.chat_data
dict_persistence._chat_data = None
dict_persistence.drop_chat_data(123)
assert isinstance(dict_persistence.get_chat_data(), defaultdict)
bot_data = dict_persistence.get_bot_data() bot_data = dict_persistence.get_bot_data()
bot_data['test3']['test4'] = 'test6' bot_data['test3']['test4'] = 'test6'