Tweak persistence handling (#1827)

* Unify persistence updates in dispatcher

* Ensure user/chat_data is not None when updating it

* Update persistence after job runs

* Increase coverage
This commit is contained in:
Bibo-Joshi 2020-04-10 13:23:13 +02:00 committed by GitHub
parent bdf0cb91f3
commit f379f54d5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 192 additions and 56 deletions

View file

@ -226,6 +226,8 @@ class DictPersistence(BasePersistence):
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id].
"""
if self._user_data is None:
self._user_data = defaultdict(dict)
if self._user_data.get(user_id) == data:
return
self._user_data[user_id] = data
@ -238,6 +240,8 @@ class DictPersistence(BasePersistence):
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id].
"""
if self._chat_data is None:
self._chat_data = defaultdict(dict)
if self._chat_data.get(chat_id) == data:
return
self._chat_data[chat_id] = data

View file

@ -323,53 +323,6 @@ class Dispatcher(object):
"""
def persist_update(update):
"""Persist a single update.
Args:
update (:class:`telegram.Update`):
The update to process.
"""
if self.persistence and isinstance(update, Update):
if self.persistence.store_bot_data:
try:
self.persistence.update_bot_data(self.bot_data)
except Exception as e:
try:
self.dispatch_error(update, e)
except Exception:
message = 'Saving bot data raised an error and an ' \
'uncaught error was raised while handling ' \
'the error with an error_handler'
self.logger.exception(message)
if self.persistence.store_chat_data and update.effective_chat:
chat_id = update.effective_chat.id
try:
self.persistence.update_chat_data(chat_id,
self.chat_data[chat_id])
except Exception as e:
try:
self.dispatch_error(update, e)
except Exception:
message = 'Saving chat data raised an error and an ' \
'uncaught error was raised while handling ' \
'the error with an error_handler'
self.logger.exception(message)
if self.persistence.store_user_data and update.effective_user:
user_id = update.effective_user.id
try:
self.persistence.update_user_data(user_id,
self.user_data[user_id])
except Exception as e:
try:
self.dispatch_error(update, e)
except Exception:
message = 'Saving user data raised an error and an ' \
'uncaught error was raised while handling ' \
'the error with an error_handler'
self.logger.exception(message)
# An error happened while polling
if isinstance(update, TelegramError):
try:
@ -388,13 +341,13 @@ class Dispatcher(object):
if not context and self.use_context:
context = CallbackContext.from_update(update, self)
handler.handle_update(update, self, check, context)
persist_update(update)
self.update_persistence(update=update)
break
# Stop processing with any other handler.
except DispatcherHandlerStop:
self.logger.debug('Stopping further handlers due to DispatcherHandlerStop')
persist_update(update)
self.update_persistence(update=update)
break
# Dispatch any error.
@ -471,18 +424,62 @@ class Dispatcher(object):
del self.handlers[group]
self.groups.remove(group)
def update_persistence(self):
def update_persistence(self, update=None):
"""Update :attr:`user_data`, :attr:`chat_data` and :attr:`bot_data` in :attr:`persistence`.
Args:
update (:class:`telegram.Update`, optional): The update to process. If passed, only the
corresponding ``user_data`` and ``chat_data`` will be updated.
"""
if self.persistence:
chat_ids = self.chat_data.keys()
user_ids = self.user_data.keys()
if isinstance(update, Update):
if update.effective_chat:
chat_ids = [update.effective_chat.id]
else:
chat_ids = []
if update.effective_user:
user_ids = [update.effective_user.id]
else:
user_ids = []
if self.persistence.store_bot_data:
self.persistence.update_bot_data(self.bot_data)
try:
self.persistence.update_bot_data(self.bot_data)
except Exception as e:
try:
self.dispatch_error(update, e)
except Exception:
message = 'Saving bot data raised an error and an ' \
'uncaught error was raised while handling ' \
'the error with an error_handler'
self.logger.exception(message)
if self.persistence.store_chat_data:
for chat_id in self.chat_data:
self.persistence.update_chat_data(chat_id, self.chat_data[chat_id])
for chat_id in chat_ids:
try:
self.persistence.update_chat_data(chat_id, self.chat_data[chat_id])
except Exception as e:
try:
self.dispatch_error(update, e)
except Exception:
message = 'Saving chat data raised an error and an ' \
'uncaught error was raised while handling ' \
'the error with an error_handler'
self.logger.exception(message)
if self.persistence.store_user_data:
for user_id in self.user_data:
self.persistence.update_user_data(user_id, self.user_data[user_id])
for user_id in user_ids:
try:
self.persistence.update_user_data(user_id, self.user_data[user_id])
except Exception as e:
try:
self.dispatch_error(update, e)
except Exception:
message = 'Saving user data raised an error and an ' \
'uncaught error was raised while handling ' \
'the error with an error_handler'
self.logger.exception(message)
def add_error_handler(self, callback):
"""Registers an error handler in the Dispatcher. This handler will receive every error

View file

@ -290,6 +290,7 @@ class JobQueue(object):
if current_week_day in job.days:
self.logger.debug('Running job %s', job.name)
job.run(self._dispatcher)
self._dispatcher.update_persistence()
except Exception:
self.logger.exception('An uncaught error was raised while executing job %s',

View file

@ -224,6 +224,8 @@ class PicklePersistence(BasePersistence):
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id].
"""
if self.user_data is None:
self.user_data = defaultdict(dict)
if self.user_data.get(user_id) == data:
return
self.user_data[user_id] = data
@ -242,6 +244,8 @@ class PicklePersistence(BasePersistence):
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id].
"""
if self.chat_data is None:
self.chat_data = defaultdict(dict)
if self.chat_data.get(chat_id) == data:
return
self.chat_data[chat_id] = data

View file

@ -449,3 +449,88 @@ class TestDispatcher(object):
with pytest.warns(TelegramDeprecationWarning):
Dispatcher(dp.bot, dp.update_queue, job_queue=dp.job_queue, workers=0,
use_context=False)
def test_error_while_persisting(self, cdp, monkeypatch):
class OwnPersistence(BasePersistence):
def __init__(self):
super(OwnPersistence, self).__init__()
self.store_user_data = True
self.store_chat_data = True
self.store_bot_data = True
def update(self, data):
raise Exception('PersistenceError')
def update_bot_data(self, data):
self.update(data)
def update_chat_data(self, chat_id, data):
self.update(data)
def update_user_data(self, user_id, data):
self.update(data)
def callback(update, context):
pass
test_flag = False
def error(update, context):
nonlocal test_flag
test_flag = str(context.error) == 'PersistenceError'
raise Exception('ErrorHandlingError')
def logger(message):
assert 'uncaught error was raised while handling' in message
update = Update(1, message=Message(1, User(1, '', False), None, Chat(1, ''), text='Text'))
handler = MessageHandler(Filters.all, callback)
cdp.add_handler(handler)
cdp.add_error_handler(error)
monkeypatch.setattr(cdp.logger, 'exception', logger)
cdp.persistence = OwnPersistence()
cdp.process_update(update)
assert test_flag
def test_persisting_no_user_no_chat(self, cdp):
class OwnPersistence(BasePersistence):
def __init__(self):
super(OwnPersistence, self).__init__()
self.store_user_data = True
self.store_chat_data = True
self.store_bot_data = True
self.test_flag_bot_data = False
self.test_flag_chat_data = False
self.test_flag_user_data = False
def update_bot_data(self, data):
self.test_flag_bot_data = True
def update_chat_data(self, chat_id, data):
self.test_flag_chat_data = True
def update_user_data(self, user_id, data):
self.test_flag_user_data = True
def callback(update, context):
pass
handler = MessageHandler(Filters.all, callback)
cdp.add_handler(handler)
cdp.persistence = OwnPersistence()
update = Update(1, message=Message(1, User(1, '', False), None, None, text='Text'))
cdp.process_update(update)
assert cdp.persistence.test_flag_bot_data
assert cdp.persistence.test_flag_user_data
assert not cdp.persistence.test_flag_chat_data
cdp.persistence.test_flag_bot_data = False
cdp.persistence.test_flag_user_data = False
cdp.persistence.test_flag_chat_data = False
update = Update(1, message=Message(1, None, None, Chat(1, ''), text='Text'))
cdp.process_update(update)
assert cdp.persistence.test_flag_bot_data
assert not cdp.persistence.test_flag_user_data
assert cdp.persistence.test_flag_chat_data

View file

@ -29,12 +29,13 @@ import logging
import os
import pickle
from collections import defaultdict
from time import sleep
import pytest
from telegram import Update, Message, User, Chat, MessageEntity
from telegram.ext import BasePersistence, Updater, ConversationHandler, MessageHandler, Filters, \
PicklePersistence, CommandHandler, DictPersistence, TypeHandler
PicklePersistence, CommandHandler, DictPersistence, TypeHandler, JobQueue
@pytest.fixture(autouse=True)
@ -87,6 +88,13 @@ def updater(bot, base_persistence):
return u
@pytest.fixture(scope='function')
def job_queue(bot):
jq = JobQueue()
yield jq
jq.stop()
class TestBasePersistence(object):
def test_creation(self, base_persistence):
@ -920,6 +928,24 @@ class TestPickelPersistence(object):
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
assert nested_ch.conversations == pickle_persistence.conversations['name3']
def test_with_job(self, job_queue, cdp, pickle_persistence):
def job_callback(context):
context.bot_data['test1'] = '456'
context.dispatcher.chat_data[123]['test2'] = '789'
context.dispatcher.user_data[789]['test3'] = '123'
cdp.persistence = pickle_persistence
job_queue.set_dispatcher(cdp)
job_queue.start()
job_queue.run_once(job_callback, 0.01)
sleep(0.05)
bot_data = pickle_persistence.get_bot_data()
assert bot_data == {'test1': '456'}
chat_data = pickle_persistence.get_chat_data()
assert chat_data[123] == {'test2': '789'}
user_data = pickle_persistence.get_user_data()
assert user_data[789] == {'test3': '123'}
@pytest.fixture(scope='function')
def user_data_json(user_data):
@ -1202,3 +1228,22 @@ class TestDictPersistence(object):
assert ch.conversations == dict_persistence.conversations['name2']
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
assert nested_ch.conversations == dict_persistence.conversations['name3']
def test_with_job(self, job_queue, cdp):
def job_callback(context):
context.bot_data['test1'] = '456'
context.dispatcher.chat_data[123]['test2'] = '789'
context.dispatcher.user_data[789]['test3'] = '123'
dict_persistence = DictPersistence()
cdp.persistence = dict_persistence
job_queue.set_dispatcher(cdp)
job_queue.start()
job_queue.run_once(job_callback, 0.01)
sleep(0.05)
bot_data = dict_persistence.get_bot_data()
assert bot_data == {'test1': '456'}
chat_data = dict_persistence.get_chat_data()
assert chat_data[123] == {'test2': '789'}
user_data = dict_persistence.get_user_data()
assert user_data[789] == {'test3': '123'}