mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-03-17 04:39:55 +01:00
Tweak persistence handling (#1827)
* Unify persistence updates in dispatcher * Ensure user/chat_data is not None when updating it * Update persistence after job runs * Increase coverage
This commit is contained in:
parent
bdf0cb91f3
commit
f379f54d5a
6 changed files with 192 additions and 56 deletions
|
@ -226,6 +226,8 @@ class DictPersistence(BasePersistence):
|
|||
user_id (:obj:`int`): The user the data might have been changed for.
|
||||
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id].
|
||||
"""
|
||||
if self._user_data is None:
|
||||
self._user_data = defaultdict(dict)
|
||||
if self._user_data.get(user_id) == data:
|
||||
return
|
||||
self._user_data[user_id] = data
|
||||
|
@ -238,6 +240,8 @@ class DictPersistence(BasePersistence):
|
|||
chat_id (:obj:`int`): The chat the data might have been changed for.
|
||||
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id].
|
||||
"""
|
||||
if self._chat_data is None:
|
||||
self._chat_data = defaultdict(dict)
|
||||
if self._chat_data.get(chat_id) == data:
|
||||
return
|
||||
self._chat_data[chat_id] = data
|
||||
|
|
|
@ -323,53 +323,6 @@ class Dispatcher(object):
|
|||
|
||||
"""
|
||||
|
||||
def persist_update(update):
|
||||
"""Persist a single update.
|
||||
|
||||
Args:
|
||||
update (:class:`telegram.Update`):
|
||||
The update to process.
|
||||
|
||||
"""
|
||||
if self.persistence and isinstance(update, Update):
|
||||
if self.persistence.store_bot_data:
|
||||
try:
|
||||
self.persistence.update_bot_data(self.bot_data)
|
||||
except Exception as e:
|
||||
try:
|
||||
self.dispatch_error(update, e)
|
||||
except Exception:
|
||||
message = 'Saving bot data raised an error and an ' \
|
||||
'uncaught error was raised while handling ' \
|
||||
'the error with an error_handler'
|
||||
self.logger.exception(message)
|
||||
if self.persistence.store_chat_data and update.effective_chat:
|
||||
chat_id = update.effective_chat.id
|
||||
try:
|
||||
self.persistence.update_chat_data(chat_id,
|
||||
self.chat_data[chat_id])
|
||||
except Exception as e:
|
||||
try:
|
||||
self.dispatch_error(update, e)
|
||||
except Exception:
|
||||
message = 'Saving chat data raised an error and an ' \
|
||||
'uncaught error was raised while handling ' \
|
||||
'the error with an error_handler'
|
||||
self.logger.exception(message)
|
||||
if self.persistence.store_user_data and update.effective_user:
|
||||
user_id = update.effective_user.id
|
||||
try:
|
||||
self.persistence.update_user_data(user_id,
|
||||
self.user_data[user_id])
|
||||
except Exception as e:
|
||||
try:
|
||||
self.dispatch_error(update, e)
|
||||
except Exception:
|
||||
message = 'Saving user data raised an error and an ' \
|
||||
'uncaught error was raised while handling ' \
|
||||
'the error with an error_handler'
|
||||
self.logger.exception(message)
|
||||
|
||||
# An error happened while polling
|
||||
if isinstance(update, TelegramError):
|
||||
try:
|
||||
|
@ -388,13 +341,13 @@ class Dispatcher(object):
|
|||
if not context and self.use_context:
|
||||
context = CallbackContext.from_update(update, self)
|
||||
handler.handle_update(update, self, check, context)
|
||||
persist_update(update)
|
||||
self.update_persistence(update=update)
|
||||
break
|
||||
|
||||
# Stop processing with any other handler.
|
||||
except DispatcherHandlerStop:
|
||||
self.logger.debug('Stopping further handlers due to DispatcherHandlerStop')
|
||||
persist_update(update)
|
||||
self.update_persistence(update=update)
|
||||
break
|
||||
|
||||
# Dispatch any error.
|
||||
|
@ -471,18 +424,62 @@ class Dispatcher(object):
|
|||
del self.handlers[group]
|
||||
self.groups.remove(group)
|
||||
|
||||
def update_persistence(self):
|
||||
def update_persistence(self, update=None):
|
||||
"""Update :attr:`user_data`, :attr:`chat_data` and :attr:`bot_data` in :attr:`persistence`.
|
||||
|
||||
Args:
|
||||
update (:class:`telegram.Update`, optional): The update to process. If passed, only the
|
||||
corresponding ``user_data`` and ``chat_data`` will be updated.
|
||||
"""
|
||||
if self.persistence:
|
||||
chat_ids = self.chat_data.keys()
|
||||
user_ids = self.user_data.keys()
|
||||
|
||||
if isinstance(update, Update):
|
||||
if update.effective_chat:
|
||||
chat_ids = [update.effective_chat.id]
|
||||
else:
|
||||
chat_ids = []
|
||||
if update.effective_user:
|
||||
user_ids = [update.effective_user.id]
|
||||
else:
|
||||
user_ids = []
|
||||
|
||||
if self.persistence.store_bot_data:
|
||||
self.persistence.update_bot_data(self.bot_data)
|
||||
try:
|
||||
self.persistence.update_bot_data(self.bot_data)
|
||||
except Exception as e:
|
||||
try:
|
||||
self.dispatch_error(update, e)
|
||||
except Exception:
|
||||
message = 'Saving bot data raised an error and an ' \
|
||||
'uncaught error was raised while handling ' \
|
||||
'the error with an error_handler'
|
||||
self.logger.exception(message)
|
||||
if self.persistence.store_chat_data:
|
||||
for chat_id in self.chat_data:
|
||||
self.persistence.update_chat_data(chat_id, self.chat_data[chat_id])
|
||||
for chat_id in chat_ids:
|
||||
try:
|
||||
self.persistence.update_chat_data(chat_id, self.chat_data[chat_id])
|
||||
except Exception as e:
|
||||
try:
|
||||
self.dispatch_error(update, e)
|
||||
except Exception:
|
||||
message = 'Saving chat data raised an error and an ' \
|
||||
'uncaught error was raised while handling ' \
|
||||
'the error with an error_handler'
|
||||
self.logger.exception(message)
|
||||
if self.persistence.store_user_data:
|
||||
for user_id in self.user_data:
|
||||
self.persistence.update_user_data(user_id, self.user_data[user_id])
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
self.persistence.update_user_data(user_id, self.user_data[user_id])
|
||||
except Exception as e:
|
||||
try:
|
||||
self.dispatch_error(update, e)
|
||||
except Exception:
|
||||
message = 'Saving user data raised an error and an ' \
|
||||
'uncaught error was raised while handling ' \
|
||||
'the error with an error_handler'
|
||||
self.logger.exception(message)
|
||||
|
||||
def add_error_handler(self, callback):
|
||||
"""Registers an error handler in the Dispatcher. This handler will receive every error
|
||||
|
|
|
@ -290,6 +290,7 @@ class JobQueue(object):
|
|||
if current_week_day in job.days:
|
||||
self.logger.debug('Running job %s', job.name)
|
||||
job.run(self._dispatcher)
|
||||
self._dispatcher.update_persistence()
|
||||
|
||||
except Exception:
|
||||
self.logger.exception('An uncaught error was raised while executing job %s',
|
||||
|
|
|
@ -224,6 +224,8 @@ class PicklePersistence(BasePersistence):
|
|||
user_id (:obj:`int`): The user the data might have been changed for.
|
||||
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id].
|
||||
"""
|
||||
if self.user_data is None:
|
||||
self.user_data = defaultdict(dict)
|
||||
if self.user_data.get(user_id) == data:
|
||||
return
|
||||
self.user_data[user_id] = data
|
||||
|
@ -242,6 +244,8 @@ class PicklePersistence(BasePersistence):
|
|||
chat_id (:obj:`int`): The chat the data might have been changed for.
|
||||
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id].
|
||||
"""
|
||||
if self.chat_data is None:
|
||||
self.chat_data = defaultdict(dict)
|
||||
if self.chat_data.get(chat_id) == data:
|
||||
return
|
||||
self.chat_data[chat_id] = data
|
||||
|
|
|
@ -449,3 +449,88 @@ class TestDispatcher(object):
|
|||
with pytest.warns(TelegramDeprecationWarning):
|
||||
Dispatcher(dp.bot, dp.update_queue, job_queue=dp.job_queue, workers=0,
|
||||
use_context=False)
|
||||
|
||||
def test_error_while_persisting(self, cdp, monkeypatch):
|
||||
class OwnPersistence(BasePersistence):
|
||||
def __init__(self):
|
||||
super(OwnPersistence, self).__init__()
|
||||
self.store_user_data = True
|
||||
self.store_chat_data = True
|
||||
self.store_bot_data = True
|
||||
|
||||
def update(self, data):
|
||||
raise Exception('PersistenceError')
|
||||
|
||||
def update_bot_data(self, data):
|
||||
self.update(data)
|
||||
|
||||
def update_chat_data(self, chat_id, data):
|
||||
self.update(data)
|
||||
|
||||
def update_user_data(self, user_id, data):
|
||||
self.update(data)
|
||||
|
||||
def callback(update, context):
|
||||
pass
|
||||
|
||||
test_flag = False
|
||||
|
||||
def error(update, context):
|
||||
nonlocal test_flag
|
||||
test_flag = str(context.error) == 'PersistenceError'
|
||||
raise Exception('ErrorHandlingError')
|
||||
|
||||
def logger(message):
|
||||
assert 'uncaught error was raised while handling' in message
|
||||
|
||||
update = Update(1, message=Message(1, User(1, '', False), None, Chat(1, ''), text='Text'))
|
||||
handler = MessageHandler(Filters.all, callback)
|
||||
cdp.add_handler(handler)
|
||||
cdp.add_error_handler(error)
|
||||
monkeypatch.setattr(cdp.logger, 'exception', logger)
|
||||
|
||||
cdp.persistence = OwnPersistence()
|
||||
cdp.process_update(update)
|
||||
assert test_flag
|
||||
|
||||
def test_persisting_no_user_no_chat(self, cdp):
|
||||
class OwnPersistence(BasePersistence):
|
||||
def __init__(self):
|
||||
super(OwnPersistence, self).__init__()
|
||||
self.store_user_data = True
|
||||
self.store_chat_data = True
|
||||
self.store_bot_data = True
|
||||
self.test_flag_bot_data = False
|
||||
self.test_flag_chat_data = False
|
||||
self.test_flag_user_data = False
|
||||
|
||||
def update_bot_data(self, data):
|
||||
self.test_flag_bot_data = True
|
||||
|
||||
def update_chat_data(self, chat_id, data):
|
||||
self.test_flag_chat_data = True
|
||||
|
||||
def update_user_data(self, user_id, data):
|
||||
self.test_flag_user_data = True
|
||||
|
||||
def callback(update, context):
|
||||
pass
|
||||
|
||||
handler = MessageHandler(Filters.all, callback)
|
||||
cdp.add_handler(handler)
|
||||
cdp.persistence = OwnPersistence()
|
||||
|
||||
update = Update(1, message=Message(1, User(1, '', False), None, None, text='Text'))
|
||||
cdp.process_update(update)
|
||||
assert cdp.persistence.test_flag_bot_data
|
||||
assert cdp.persistence.test_flag_user_data
|
||||
assert not cdp.persistence.test_flag_chat_data
|
||||
|
||||
cdp.persistence.test_flag_bot_data = False
|
||||
cdp.persistence.test_flag_user_data = False
|
||||
cdp.persistence.test_flag_chat_data = False
|
||||
update = Update(1, message=Message(1, None, None, Chat(1, ''), text='Text'))
|
||||
cdp.process_update(update)
|
||||
assert cdp.persistence.test_flag_bot_data
|
||||
assert not cdp.persistence.test_flag_user_data
|
||||
assert cdp.persistence.test_flag_chat_data
|
||||
|
|
|
@ -29,12 +29,13 @@ import logging
|
|||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
|
||||
from telegram import Update, Message, User, Chat, MessageEntity
|
||||
from telegram.ext import BasePersistence, Updater, ConversationHandler, MessageHandler, Filters, \
|
||||
PicklePersistence, CommandHandler, DictPersistence, TypeHandler
|
||||
PicklePersistence, CommandHandler, DictPersistence, TypeHandler, JobQueue
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@ -87,6 +88,13 @@ def updater(bot, base_persistence):
|
|||
return u
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def job_queue(bot):
|
||||
jq = JobQueue()
|
||||
yield jq
|
||||
jq.stop()
|
||||
|
||||
|
||||
class TestBasePersistence(object):
|
||||
|
||||
def test_creation(self, base_persistence):
|
||||
|
@ -920,6 +928,24 @@ class TestPickelPersistence(object):
|
|||
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||
assert nested_ch.conversations == pickle_persistence.conversations['name3']
|
||||
|
||||
def test_with_job(self, job_queue, cdp, pickle_persistence):
|
||||
def job_callback(context):
|
||||
context.bot_data['test1'] = '456'
|
||||
context.dispatcher.chat_data[123]['test2'] = '789'
|
||||
context.dispatcher.user_data[789]['test3'] = '123'
|
||||
|
||||
cdp.persistence = pickle_persistence
|
||||
job_queue.set_dispatcher(cdp)
|
||||
job_queue.start()
|
||||
job_queue.run_once(job_callback, 0.01)
|
||||
sleep(0.05)
|
||||
bot_data = pickle_persistence.get_bot_data()
|
||||
assert bot_data == {'test1': '456'}
|
||||
chat_data = pickle_persistence.get_chat_data()
|
||||
assert chat_data[123] == {'test2': '789'}
|
||||
user_data = pickle_persistence.get_user_data()
|
||||
assert user_data[789] == {'test3': '123'}
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def user_data_json(user_data):
|
||||
|
@ -1202,3 +1228,22 @@ class TestDictPersistence(object):
|
|||
assert ch.conversations == dict_persistence.conversations['name2']
|
||||
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||
assert nested_ch.conversations == dict_persistence.conversations['name3']
|
||||
|
||||
def test_with_job(self, job_queue, cdp):
|
||||
def job_callback(context):
|
||||
context.bot_data['test1'] = '456'
|
||||
context.dispatcher.chat_data[123]['test2'] = '789'
|
||||
context.dispatcher.user_data[789]['test3'] = '123'
|
||||
|
||||
dict_persistence = DictPersistence()
|
||||
cdp.persistence = dict_persistence
|
||||
job_queue.set_dispatcher(cdp)
|
||||
job_queue.start()
|
||||
job_queue.run_once(job_callback, 0.01)
|
||||
sleep(0.05)
|
||||
bot_data = dict_persistence.get_bot_data()
|
||||
assert bot_data == {'test1': '456'}
|
||||
chat_data = dict_persistence.get_chat_data()
|
||||
assert chat_data[123] == {'test2': '789'}
|
||||
user_data = dict_persistence.get_user_data()
|
||||
assert user_data[789] == {'test3': '123'}
|
||||
|
|
Loading…
Add table
Reference in a new issue