Persistence (#1017)

* BasePersistence

* basic construct

* Keep working

* Continue work

Add tests for Basepersistence

* Finish up BasePersistence and implementation

* PickelPersistence and start tests

* Finishing up

* Oops, left in some typings

* Compatibilty issues regarding py2 solved

For Py2 compatibility

* increasing coverage

* Small changes due to CR

* All persistence tests in one file

* add DictPersistence

* Last changes per CR

* forgot change

* changes per CR

* call update_* only with relevant data

As discussed with @jsmnbom

* Add conversationbot Example

* should not have committed API-key
This commit is contained in:
Eldinnie 2018-09-20 22:50:40 +02:00 committed by Jasmin Bom
parent b9f56ca479
commit 439790375e
17 changed files with 1677 additions and 9 deletions

View file

@ -0,0 +1,6 @@
telegram.ext.BasePersistence
============================
.. autoclass:: telegram.ext.BasePersistence
:members:
:show-inheritance:

View file

@ -0,0 +1,6 @@
telegram.ext.DictPersistence
============================
.. autoclass:: telegram.ext.DictPersistence
:members:
:show-inheritance:

View file

@ -0,0 +1,6 @@
telegram.ext.PicklePersistence
==============================
.. autoclass:: telegram.ext.PicklePersistence
:members:
:show-inheritance:

View file

@ -29,3 +29,12 @@ Handlers
telegram.ext.stringcommandhandler
telegram.ext.stringregexhandler
telegram.ext.typehandler
Persistence
-----------
.. toctree::
telegram.ext.basepersistence
telegram.ext.picklepersistence
telegram.ext.dictpersistence

View file

@ -25,5 +25,8 @@ A basic example of an [inline bot](https://core.telegram.org/bots/inline). Don't
### [`paymentbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/paymentbot.py)
A basic example of a bot that can accept payments. Don't forget to enable and configure payments with [@BotFather](https://telegram.me/BotFather).
### [`persistentconversationbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/persistentconversationbot.py)
A basic example of a bot store conversation state and user_data over multiple restarts.
## Pure API
The [`echobot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/echobot.py) example uses only the pure, "bare-metal" API wrapper.

View file

@ -0,0 +1,170 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Simple Bot to reply to Telegram messages
# This program is dedicated to the public domain under the CC0 license.
"""
This Bot uses the Updater class to handle the bot.
First, a few callback functions are defined. Then, those functions are passed to
the Dispatcher and registered at their respective places.
Then, the bot is started and runs until we press Ctrl-C on the command line.
Usage:
Example of a bot-user conversation using ConversationHandler.
Send /start to initiate the conversation.
Press Ctrl-C on the command line or send a signal to the process to stop the
bot.
"""
from telegram import ReplyKeyboardMarkup
from telegram.ext import (Updater, CommandHandler, MessageHandler, Filters, RegexHandler,
ConversationHandler, PicklePersistence)
import logging
# Enable logging
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
level=logging.INFO)
logger = logging.getLogger(__name__)
CHOOSING, TYPING_REPLY, TYPING_CHOICE = range(3)
reply_keyboard = [['Age', 'Favourite colour'],
['Number of siblings', 'Something else...'],
['Done']]
markup = ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True)
def facts_to_str(user_data):
facts = list()
for key, value in user_data.items():
facts.append('{} - {}'.format(key, value))
return "\n".join(facts).join(['\n', '\n'])
def start(bot, update, user_data):
reply_text = "Hi! My name is Doctor Botter."
if user_data:
reply_text += " You already told me your {}. Why don't you tell me something more " \
"about yourself? Or change enything I " \
"already know.".format(", ".join(user_data.keys()))
else:
reply_text += " I will hold a more complex conversation with you. Why don't you tell me " \
"something about yourself?"
update.message.reply_text(reply_text, reply_markup=markup)
return CHOOSING
def regular_choice(bot, update, user_data):
text = update.message.text
user_data['choice'] = text
if user_data.get(text):
reply_text = 'Your {}, I already know the following ' \
'about that: {}'.format(text.lower(), user_data[text.lower()])
else:
reply_text = 'Your {}? Yes, I would love to hear about that!'.format(text.lower())
update.message.reply_text(reply_text)
return TYPING_REPLY
def custom_choice(bot, update):
update.message.reply_text('Alright, please send me the category first, '
'for example "Most impressive skill"')
return TYPING_CHOICE
def received_information(bot, update, user_data):
text = update.message.text
category = user_data['choice']
user_data[category] = text.lower()
del user_data['choice']
update.message.reply_text("Neat! Just so you know, this is what you already told me:"
"{}"
"You can tell me more, or change your opinion on "
"something.".format(facts_to_str(user_data)), reply_markup=markup)
return CHOOSING
def show_data(bot, update, user_data):
update.message.reply_text("This is what you already told me:"
"{}".format(facts_to_str(user_data)))
def done(bot, update, user_data):
if 'choice' in user_data:
del user_data['choice']
update.message.reply_text("I learned these facts about you:"
"{}"
"Until next time!".format(facts_to_str(user_data)))
return ConversationHandler.END
def error(bot, update, error):
"""Log Errors caused by Updates."""
logger.warning('Update "%s" caused error "%s"', update, error)
def main():
# Create the Updater and pass it your bot's token.
pp = PicklePersistence(filename='conversationbot')
updater = Updater("TOKEN", persistence=pp)
# Get the dispatcher to register handlers
dp = updater.dispatcher
# Add conversation handler with the states CHOOSING, TYPING_CHOICE and TYPING_REPLY
conv_handler = ConversationHandler(
entry_points=[CommandHandler('start', start, pass_user_data=True)],
states={
CHOOSING: [RegexHandler('^(Age|Favourite colour|Number of siblings)$',
regular_choice,
pass_user_data=True),
RegexHandler('^Something else...$',
custom_choice),
],
TYPING_CHOICE: [MessageHandler(Filters.text,
regular_choice,
pass_user_data=True),
],
TYPING_REPLY: [MessageHandler(Filters.text,
received_information,
pass_user_data=True),
],
},
fallbacks=[RegexHandler('^Done$', done, pass_user_data=True)],
name="my_conversation",
persistent=True
)
dp.add_handler(conv_handler)
show_data_handler = CommandHandler('show_data', show_data, pass_user_data=True)
dp.add_handler(show_data_handler)
# log all errors
dp.add_error_handler(error)
# Start the Bot
updater.start_polling()
# Run the bot until you press Ctrl-C or the process receives SIGINT,
# SIGTERM or SIGABRT. This should be used most of the time, since
# start_polling() is non-blocking and will stop the bot gracefully.
updater.idle()
if __name__ == '__main__':
main()

View file

@ -18,6 +18,9 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""Extensions over the Telegram Bot API to facilitate bot making"""
from .basepersistence import BasePersistence
from .picklepersistence import PicklePersistence
from .dictpersistence import DictPersistence
from .dispatcher import Dispatcher, DispatcherHandlerStop, run_async
from .jobqueue import JobQueue, Job
from .updater import Updater
@ -43,4 +46,5 @@ __all__ = ('Dispatcher', 'JobQueue', 'Job', 'Updater', 'CallbackQueryHandler',
'MessageHandler', 'BaseFilter', 'Filters', 'RegexHandler', 'StringCommandHandler',
'StringRegexHandler', 'TypeHandler', 'ConversationHandler',
'PreCheckoutQueryHandler', 'ShippingQueryHandler', 'MessageQueue', 'DelayQueue',
'DispatcherHandlerStop', 'run_async')
'DispatcherHandlerStop', 'run_async', 'BasePersistence', 'PicklePersistence',
'DictPersistence')

View file

@ -0,0 +1,123 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2018
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the BasePersistence class."""
class BasePersistence(object):
"""Interface class for adding persistence to your bot.
Subclass this object for different implementations of a persistent bot.
All relevant methods must be overwritten. This means:
* If :attr:`store_chat_data` is ``True`` you must overwrite :meth:`get_chat_data` and
:meth:`update_chat_data`.
* If :attr:`store_user_data` is ``True`` you must overwrite :meth:`get_user_data` and
:meth:`update_user_data`.
* If you want to store conversation data with :class:`telegram.ext.ConversationHandler`, you
must overwrite :meth:`get_conversations` and :meth:`update_conversation`.
* :meth:`flush` will be called when the bot is shutdown.
Attributes:
store_user_data (:obj:`bool`): Optional, Whether user_data should be saved by this
persistence class.
store_chat_data (:obj:`bool`): Optional. Whether chat_data should be saved by this
persistence class.
Args:
store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
store_chat_data (:obj:`bool`, optional): Whether chat_data should be saved by this
persistence class. Default is ``True`` .
"""
def __init__(self, store_user_data=True, store_chat_data=True):
self.store_user_data = store_user_data
self.store_chat_data = store_chat_data
def get_user_data(self):
""""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the user_data if stored, or an empty
``defaultdict(dict)``.
Returns:
:obj:`defaultdict`: The restored user data.
"""
raise NotImplementedError
def get_chat_data(self):
""""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the chat_data if stored, or an empty
``defaultdict(dict)``.
Returns:
:obj:`defaultdict`: The restored chat data.
"""
raise NotImplementedError
def get_conversations(self, name):
""""Will be called by :class:`telegram.ext.Dispatcher` when a
:class:`telegram.ext.ConversationHandler` is added if
:attr:`telegram.ext.ConversationHandler.persistent` is ``True``.
It should return the conversations for the handler with `name` or an empty ``dict``
Args:
name (:obj:`str`): The handlers name.
Returns:
:obj:`dict`: The restored conversations for the handler.
"""
raise NotImplementedError
def update_conversation(self, name, key, new_state):
"""Will be called when a :attr:`telegram.ext.ConversationHandler.update_state`
is called. this allows the storeage of the new state in the persistence.
Args:
name (:obj:`str`): The handlers name.
key (:obj:`tuple`): The key the state is changed for.
new_state (:obj:`tuple` | :obj:`any`): The new state for the given key.
"""
raise NotImplementedError
def update_user_data(self, user_id, data):
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
Args:
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id].
"""
raise NotImplementedError
def update_chat_data(self, chat_id, data):
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
Args:
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[user_id].
"""
raise NotImplementedError
def flush(self):
"""Will be called by :class:`telegram.ext.Updater` upon receiving a stop signal. Gives the
persistence a chance to finish up saving or close a database connection gracefully. If this
is not of any importance just pass will be sufficient.
"""
pass

View file

@ -79,6 +79,10 @@ class ConversationHandler(Handler):
conversation_timeout (:obj:`float`|:obj:`datetime.timedelta`): Optional. When this handler
is inactive more than this timeout (in seconds), it will be automatically ended. If
this value is 0 (default), there will be no timeout.
name (:obj:`str`): Optional. The name for this conversationhandler. Required for
persistence
persistent (:obj:`bool`): Optional. If the conversations dict for this handler should be
saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater`
Args:
entry_points (List[:class:`telegram.ext.Handler`]): A list of ``Handler`` objects that can
@ -113,6 +117,10 @@ class ConversationHandler(Handler):
conversation_timeout (:obj:`float`|:obj:`datetime.timedelta`, optional): When this handler
is inactive more than this timeout (in seconds), it will be automatically ended. If
this value is 0 or None (default), there will be no timeout.
name (:obj:`str`, optional): The name for this conversationhandler. Required for
persistence
persistent (:obj:`bool`, optional): If the conversations dict for this handler should be
saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater`
Raises:
ValueError
@ -131,7 +139,9 @@ class ConversationHandler(Handler):
per_chat=True,
per_user=True,
per_message=False,
conversation_timeout=None):
conversation_timeout=None,
name=None,
persistent=False):
self.entry_points = entry_points
self.states = states
@ -144,6 +154,13 @@ class ConversationHandler(Handler):
self.per_chat = per_chat
self.per_message = per_message
self.conversation_timeout = conversation_timeout
self.name = name
if persistent and not self.name:
raise ValueError("Conversations can't be persistent when handler is unnamed.")
self.persistent = persistent
self.persistence = None
""":obj:`telegram.ext.BasePersistance`: The persistence used to store conversations.
Set by dispatcher"""
self.timeout_jobs = dict()
self.conversations = dict()
@ -318,14 +335,21 @@ class ConversationHandler(Handler):
if new_state == self.END:
if key in self.conversations:
del self.conversations[key]
if self.persistent:
self.persistence.update_conversation(self.name, key, None)
else:
pass
elif isinstance(new_state, Promise):
self.conversations[key] = (self.conversations.get(key), new_state)
if self.persistent:
self.persistence.update_conversation(self.name, key,
(self.conversations.get(key), new_state))
elif new_state is not None:
self.conversations[key] = new_state
if self.persistent:
self.persistence.update_conversation(self.name, key, new_state)
def _trigger_timeout(self, bot, job):
del self.timeout_jobs[job.context]

View file

@ -0,0 +1,194 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2018
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the DictPersistence class."""
from telegram.utils.helpers import decode_user_chat_data_from_json,\
decode_conversations_from_json, enocde_conversations_to_json
try:
import ujson as json
except ImportError:
import json
from collections import defaultdict
from telegram.ext import BasePersistence
class DictPersistence(BasePersistence):
"""Using python's dicts and json for making you bot persistent.
Attributes:
store_user_data (:obj:`bool`): Whether user_data should be saved by this
persistence class.
store_chat_data (:obj:`bool`): Whether chat_data should be saved by this
persistence class.
Args:
store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
store_chat_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
user_data_json (:obj:`str`, optional): Json string that will be used to reconstruct
user_data on creating this persistence. Default is ``""``.
chat_data_json (:obj:`str`, optional): Json string that will be used to reconstruct
chat_data on creating this persistence. Default is ``""``.
conversations_json (:obj:`str`, optional): Json string that will be used to reconstruct
conversation on creating this persistence. Default is ``""``.
"""
def __init__(self, store_user_data=True, store_chat_data=True, user_data_json='',
chat_data_json='', conversations_json=''):
self.store_user_data = store_user_data
self.store_chat_data = store_chat_data
self._user_data = None
self._chat_data = None
self._conversations = None
self._user_data_json = None
self._chat_data_json = None
self._conversations_json = None
if user_data_json:
try:
self._user_data = decode_user_chat_data_from_json(user_data_json)
self._user_data_json = user_data_json
except (ValueError, AttributeError):
raise TypeError("Unable to deserialize user_data_json. Not valid JSON")
if chat_data_json:
try:
self._chat_data = decode_user_chat_data_from_json(chat_data_json)
self._chat_data_json = chat_data_json
except (ValueError, AttributeError):
raise TypeError("Unable to deserialize chat_data_json. Not valid JSON")
if conversations_json:
try:
self._conversations = decode_conversations_from_json(conversations_json)
self._conversations_json = conversations_json
except (ValueError, AttributeError):
raise TypeError("Unable to deserialize conversations_json. Not valid JSON")
@property
def user_data(self):
""":obj:`dict`: The user_data as a dict"""
return self._user_data
@property
def user_data_json(self):
""":obj:`str`: The user_data serialized as a JSON-string."""
if self._user_data_json:
return self._user_data_json
else:
return json.dumps(self.user_data)
@property
def chat_data(self):
""":obj:`dict`: The chat_data as a dict"""
return self._chat_data
@property
def chat_data_json(self):
""":obj:`str`: The chat_data serialized as a JSON-string."""
if self._chat_data_json:
return self._chat_data_json
else:
return json.dumps(self.chat_data)
@property
def conversations(self):
""":obj:`dict`: The conversations as a dict"""
return self._conversations
@property
def conversations_json(self):
""":obj:`str`: The conversations serialized as a JSON-string."""
if self._conversations_json:
return self._conversations_json
else:
return enocde_conversations_to_json(self.conversations)
def get_user_data(self):
"""Returns the user_data created from the ``user_data_json`` or an empty defaultdict.
Returns:
:obj:`defaultdict`: The restored user data.
"""
if self.user_data:
pass
else:
self._user_data = defaultdict(dict)
return self.user_data.copy()
def get_chat_data(self):
"""Returns the chat_data created from the ``chat_data_json`` or an empty defaultdict.
Returns:
:obj:`defaultdict`: The restored user data.
"""
if self.chat_data:
pass
else:
self._chat_data = defaultdict(dict)
return self.chat_data.copy()
def get_conversations(self, name):
"""Returns the conversations created from the ``conversations_json`` or an empty
defaultdict.
Returns:
:obj:`defaultdict`: The restored user data.
"""
if self.conversations:
pass
else:
self._conversations = {}
return self.conversations.get(name, {}).copy()
def update_conversation(self, name, key, new_state):
"""Will update the conversations for the given handler.
Args:
name (:obj:`str`): The handlers name.
key (:obj:`tuple`): The key the state is changed for.
new_state (:obj:`tuple` | :obj:`any`): The new state for the given key.
"""
if self._conversations.setdefault(name, {}).get(key) == new_state:
return
self._conversations[name][key] = new_state
self._conversations_json = None
def update_user_data(self, user_id, data):
"""Will update the user_data (if changed).
Args:
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id].
"""
if self._user_data.get(user_id) == data:
return
self._user_data[user_id] = data
self._user_data_json = None
def update_chat_data(self, chat_id, data):
"""Will update the chat_data (if changed).
Args:
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id].
"""
if self._chat_data.get(chat_id) == data:
return
self._chat_data[chat_id] = data
self._chat_data_json = None

View file

@ -33,6 +33,7 @@ from future.builtins import range
from telegram import TelegramError
from telegram.ext.handler import Handler
from telegram.utils.promise import Promise
from telegram.ext import BasePersistence
logging.getLogger(__name__).addHandler(logging.NullHandler())
DEFAULT_GROUP = 0
@ -48,6 +49,7 @@ def run_async(func):
Note: Use this decorator to run handlers asynchronously.
"""
@wraps(func)
def async_func(*args, **kwargs):
return Dispatcher.get_instance().run_async(func, *args, **kwargs)
@ -70,6 +72,10 @@ class Dispatcher(object):
instance to pass onto handler callbacks.
workers (:obj:`int`): Number of maximum concurrent worker threads for the ``@run_async``
decorator.
user_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the user.
chat_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the chat.
persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to
store data that should be persistent over restarts
Args:
bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers.
@ -78,6 +84,8 @@ class Dispatcher(object):
instance to pass onto handler callbacks.
workers (:obj:`int`, optional): Number of maximum concurrent worker threads for the
``@run_async`` decorator. defaults to 4.
persistence (:class:`telegram.ext.BasePersistence`, optional): The persistence class to
store data that should be persistent over restarts
"""
@ -86,16 +94,30 @@ class Dispatcher(object):
__singleton = None
logger = logging.getLogger(__name__)
def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None):
def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None,
persistence=None):
self.bot = bot
self.update_queue = update_queue
self.job_queue = job_queue
self.workers = workers
self.user_data = defaultdict(dict)
""":obj:`dict`: A dictionary handlers can use to store data for the user."""
self.chat_data = defaultdict(dict)
""":obj:`dict`: A dictionary handlers can use to store data for the chat."""
if persistence:
if not isinstance(persistence, BasePersistence):
raise TypeError("persistence should be based on telegram.ext.BasePersistence")
self.persistence = persistence
if self.persistence.store_user_data:
self.user_data = self.persistence.get_user_data()
if not isinstance(self.user_data, defaultdict):
raise ValueError("user_data must be of type defaultdict")
if self.persistence.store_chat_data:
self.chat_data = self.persistence.get_chat_data()
if not isinstance(self.chat_data, defaultdict):
raise ValueError("chat_data must be of type defaultdict")
else:
self.persistence = None
self.job_queue = job_queue
self.handlers = {}
"""Dict[:obj:`int`, List[:class:`telegram.ext.Handler`]]: Holds the handlers per group."""
self.groups = []
@ -277,6 +299,19 @@ class Dispatcher(object):
try:
for handler in (x for x in self.handlers[group] if x.check_update(update)):
handler.handle_update(update, self)
if self.persistence:
if self.persistence.store_chat_data and update.effective_chat.id:
chat_id = update.effective_chat.id
try:
self.persistence.update_chat_data(chat_id, self.chat_data[chat_id])
except Exception:
self.logger.exception('Saving chat data raised an error')
if self.persistence.store_user_data and update.effective_user.id:
user_id = update.effective_user.id
try:
self.persistence.update_user_data(user_id, self.user_data[user_id])
except Exception:
self.logger.exception('Saving user data raised an error')
break
# Stop processing with any other handler.
@ -324,11 +359,20 @@ class Dispatcher(object):
group (:obj:`int`, optional): The group identifier. Default is 0.
"""
# Unfortunately due to circular imports this has to be here
from .conversationhandler import ConversationHandler
if not isinstance(handler, Handler):
raise TypeError('handler is not an instance of {0}'.format(Handler.__name__))
if not isinstance(group, int):
raise TypeError('group is not int')
if isinstance(handler, ConversationHandler) and handler.persistent:
if not self.persistence:
raise ValueError(
"Conversationhandler {} can not be persistent if dispatcher has no "
"persistence".format(handler.name))
handler.conversations = self.persistence.get_conversations(handler.name)
handler.persistence = self.persistence
if group not in self.handlers:
self.handlers[group] = list()

View file

@ -0,0 +1,235 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2018
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the PicklePersistence class."""
import pickle
from collections import defaultdict
from telegram.ext import BasePersistence
class PicklePersistence(BasePersistence):
"""Using python's builtin pickle for making you bot persistent.
Attributes:
filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file`
is false this will be used as a prefix.
store_user_data (:obj:`bool`): Optional. Whether user_data should be saved by this
persistence class.
store_chat_data (:obj:`bool`): Optional. Whether user_data should be saved by this
persistence class.
single_file (:obj:`bool`): Optional. When ``False`` will store 3 sperate files of
`filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is
``True``.
on_flush (:obj:`bool`): Optional. When ``True`` will only save to file when :meth:`flush`
is called and keep data in memory until that happens. When False will store data on any
transaction. Default is ``False``.
Args:
filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file`
is false this will be used as a prefix.
store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
store_chat_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
single_file (:obj:`bool`, optional): When ``False`` will store 3 sperate files of
`filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is
``True``.
on_flush (:obj:`bool`, optional): When ``True`` will only save to file when :meth:`flush`
is called and keep data in memory until that happens. When False will store data on any
transaction. Default is ``False``.
"""
def __init__(self, filename, store_user_data=True, store_chat_data=True, singe_file=True,
on_flush=False):
self.filename = filename
self.store_user_data = store_user_data
self.store_chat_data = store_chat_data
self.single_file = singe_file
self.on_flush = on_flush
self.user_data = None
self.chat_data = None
self.conversations = None
def load_singlefile(self):
try:
filename = self.filename
with open(self.filename, "rb") as f:
all = pickle.load(f)
self.user_data = defaultdict(dict, all['user_data'])
self.chat_data = defaultdict(dict, all['chat_data'])
self.conversations = all['conversations']
except IOError:
self.conversations = {}
self.user_data = defaultdict(dict)
self.chat_data = defaultdict(dict)
except pickle.UnpicklingError:
raise TypeError("File {} does not contain valid pickle data".format(filename))
except Exception:
raise TypeError("Something went wrong unpickling {}".format(filename))
def load_file(self, filename):
try:
with open(filename, "rb") as f:
return pickle.load(f)
except IOError:
return None
except pickle.UnpicklingError:
raise TypeError("File {} does not contain valid pickle data".format(filename))
except Exception:
raise TypeError("Something went wrong unpickling {}".format(filename))
def dump_singlefile(self):
with open(self.filename, "wb") as f:
all = {'conversations': self.conversations, 'user_data': self.user_data,
'chat_data': self.chat_data}
pickle.dump(all, f)
def dump_file(self, filename, data):
with open(filename, "wb") as f:
pickle.dump(data, f)
def get_user_data(self):
"""Returns the user_data from the pickle file if it exsists or an empty defaultdict.
Returns:
:obj:`defaultdict`: The restored user data.
"""
if self.user_data:
pass
elif not self.single_file:
filename = "{}_user_data".format(self.filename)
data = self.load_file(filename)
if not data:
data = defaultdict(dict)
else:
data = defaultdict(dict, data)
self.user_data = data
else:
self.load_singlefile()
return self.user_data.copy()
def get_chat_data(self):
"""Returns the chat_data from the pickle file if it exsists or an empty defaultdict.
Returns:
:obj:`defaultdict`: The restored chat data.
"""
if self.chat_data:
pass
elif not self.single_file:
filename = "{}_chat_data".format(self.filename)
data = self.load_file(filename)
if not data:
data = defaultdict(dict)
else:
data = defaultdict(dict, data)
self.chat_data = data
else:
self.load_singlefile()
return self.chat_data.copy()
def get_conversations(self, name):
"""Returns the conversations from the pickle file if it exsists or an empty defaultdict.
Args:
name (:obj:`str`): The handlers name.
Returns:
:obj:`dict`: The restored conversations for the handler.
"""
if self.conversations:
pass
elif not self.single_file:
filename = "{}_conversations".format(self.filename)
data = self.load_file(filename)
if not data:
data = {name: {}}
self.conversations = data
else:
self.load_singlefile()
return self.conversations.get(name, {}).copy()
def update_conversation(self, name, key, new_state):
"""Will update the conversations for the given handler and depending on :attr:`on_flush`
save the pickle file.
Args:
name (:obj:`str`): The handlers name.
key (:obj:`tuple`): The key the state is changed for.
new_state (:obj:`tuple` | :obj:`any`): The new state for the given key.
"""
if self.conversations.setdefault(name, {}).get(key) == new_state:
return
self.conversations[name][key] = new_state
if not self.on_flush:
if not self.single_file:
filename = "{}_conversations".format(self.filename)
self.dump_file(filename, self.conversations)
else:
self.dump_singlefile()
def update_user_data(self, user_id, data):
"""Will update the user_data (if changed) and depending on :attr:`on_flush` save the
pickle file.
Args:
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id].
"""
if self.user_data.get(user_id) == data:
return
self.user_data[user_id] = data
if not self.on_flush:
if not self.single_file:
filename = "{}_user_data".format(self.filename)
self.dump_file(filename, self.user_data)
else:
self.dump_singlefile()
def update_chat_data(self, chat_id, data):
"""Will update the chat_data (if changed) and depending on :attr:`on_flush` save the
pickle file.
Args:
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id].
"""
if self.chat_data.get(chat_id) == data:
return
self.chat_data[chat_id] = data
if not self.on_flush:
if not self.single_file:
filename = "{}_chat_data".format(self.filename)
self.dump_file(filename, self.chat_data)
else:
self.dump_singlefile()
def flush(self):
"""If :attr:`on_flush` is set to ``True``. Will save all data in memory to pickle file(s). If
it's ``False`` will just pass.
"""
if not self.on_flush:
pass
else:
if self.single_file:
self.dump_singlefile()
else:
self.dump_file("{}_user_data".format(self.filename), self.user_data)
self.dump_file("{}_chat_data".format(self.filename), self.chat_data)
self.dump_file("{}_conversations".format(self.filename), self.conversations)

View file

@ -55,6 +55,8 @@ class Updater(object):
dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that handles the updates and
dispatches them to the handlers.
running (:obj:`bool`): Indicates if the updater is running.
persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to
store data that should be persistent over restarts.
Args:
token (:obj:`str`, optional): The bot's token given by the @BotFather.
@ -73,6 +75,8 @@ class Updater(object):
`telegram.utils.request.Request` object (ignored if `bot` argument is used). The
request_kwargs are very useful for the advanced users who would like to control the
default timeouts and/or control the proxy used for http communication.
persistence (:class:`telegram.ext.BasePersistence`, optional): The persistence class to
store data that should be persistent over restarts.
Note:
You must supply either a :attr:`bot` or a :attr:`token` argument.
@ -92,7 +96,8 @@ class Updater(object):
private_key=None,
private_key_password=None,
user_sig_handler=None,
request_kwargs=None):
request_kwargs=None,
persistence=None):
if (token is None) and (bot is None):
raise ValueError('`token` or `bot` must be passed')
@ -129,12 +134,14 @@ class Updater(object):
self.update_queue = Queue()
self.job_queue = JobQueue(self.bot)
self.__exception_event = Event()
self.persistence = persistence
self.dispatcher = Dispatcher(
self.bot,
self.update_queue,
job_queue=self.job_queue,
workers=workers,
exception_event=self.__exception_event)
exception_event=self.__exception_event,
persistence=self.persistence)
self.last_update_id = 0
self.running = False
self.is_idle = False
@ -485,6 +492,8 @@ class Updater(object):
if self.running:
self.logger.info('Received signal {} ({}), stopping...'.format(
signum, get_signal_name(signum)))
if self.persistence:
self.persistence.flush()
self.stop()
if self.user_sig_handler:
self.user_sig_handler(signum, frame)

View file

@ -17,6 +17,12 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains helper functions."""
from collections import defaultdict
try:
import ujson as json
except ImportError:
import json
from html import escape
import re
@ -139,3 +145,65 @@ def effective_message_type(entity):
return i
return None
def enocde_conversations_to_json(conversations):
"""Helper method to encode a conversations dict (that uses tuples as keys) to a
JSON-serializable way. Use :attr:`_decode_conversations_from_json` to decode.
Args:
conversations (:obj:`dict`): The conversations dict to transofrm to JSON.
Returns:
:obj:`str`: The JSON-serialized conversations dict
"""
tmp = {}
for handler, states in conversations.items():
tmp[handler] = {}
for key, state in states.items():
tmp[handler][json.dumps(key)] = state
return json.dumps(tmp)
def decode_conversations_from_json(json_string):
"""Helper method to decode a conversations dict (that uses tuples as keys) from a
JSON-string created with :attr:`_encode_conversations_to_json`.
Args:
json_string (:obj:`str`): The conversations dict as JSON string.
Returns:
:obj:`dict`: The conversations dict after decoding
"""
tmp = json.loads(json_string)
conversations = {}
for handler, states in tmp.items():
conversations[handler] = {}
for key, state in states.items():
conversations[handler][tuple(json.loads(key))] = state
return conversations
def decode_user_chat_data_from_json(data):
"""Helper method to decode chat or user data (that uses ints as keys) from a
JSON-string.
Args:
data (:obj:`str`): The user/chat_data dict as JSON string.
Returns:
:obj:`dict`: The user/chat_data defaultdict after decoding
"""
tmp = defaultdict(dict)
decoded_data = json.loads(data)
for user, data in decoded_data.items():
user = int(user)
tmp[user] = {}
for key, value in data.items():
try:
key = int(key)
except ValueError:
pass
tmp[user][key] = value
return tmp

View file

@ -96,6 +96,12 @@ class TestConversationHandler(object):
ConversationHandler(self.entry_points, self.states, self.fallbacks,
per_chat=False, per_user=False, per_message=False)
def test_name_and_persistent(self, dp):
with pytest.raises(ValueError, match="when handler is unnamed"):
dp.add_handler(ConversationHandler([], {}, [], persistent=True))
c = ConversationHandler([], {}, [], name="handler", persistent=True)
assert c.name == "handler"
def test_conversation_handler(self, dp, bot, user1, user2):
handler = ConversationHandler(entry_points=self.entry_points, states=self.states,
fallbacks=self.fallbacks)

View file

@ -82,6 +82,16 @@ class TestDispatcher(object):
sleep(.1)
assert self.received is None
def test_construction_with_bad_persistence(self, caplog, bot):
class my_per:
def __init__(self):
self.store_user_data = False
self.store_chat_data = False
with pytest.raises(TypeError,
match='persistence should be based on telegram.ext.BasePersistence'):
Dispatcher(bot, None, persistence=my_per())
def test_error_handler_that_raises_errors(self, dp):
"""
Make sure that errors raised in error handlers don't break the main loop of the dispatcher

751
tests/test_persistence.py Normal file
View file

@ -0,0 +1,751 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2018
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
from telegram.utils.helpers import enocde_conversations_to_json
try:
import ujson as json
except ImportError:
import json
import logging
import os
import pickle
from collections import defaultdict
import pytest
from telegram import Update, Message, User, Chat
from telegram.ext import BasePersistence, Updater, ConversationHandler, MessageHandler, Filters, \
PicklePersistence, CommandHandler, DictPersistence
@pytest.fixture(scope="function")
def base_persistence():
return BasePersistence(store_chat_data=True, store_user_data=True)
@pytest.fixture(scope="function")
def chat_data():
return defaultdict(dict, {-12345: {'test1': 'test2'}, -67890: {3: 'test4'}})
@pytest.fixture(scope="function")
def user_data():
return defaultdict(dict, {12345: {'test1': 'test2'}, 67890: {3: 'test4'}})
@pytest.fixture(scope='function')
def conversations():
return {'name1': {(123, 123): 3, (456, 654): 4},
'name2': {(123, 321): 1, (890, 890): 2}}
@pytest.fixture(scope="function")
def updater(bot, base_persistence):
base_persistence.store_chat_data = False
base_persistence.store_user_data = False
u = Updater(bot=bot, persistence=base_persistence)
base_persistence.store_chat_data = True
base_persistence.store_user_data = True
return u
class TestBasePersistence(object):
def test_creation(self, base_persistence):
assert base_persistence.store_chat_data
assert base_persistence.store_user_data
with pytest.raises(NotImplementedError):
base_persistence.get_chat_data()
with pytest.raises(NotImplementedError):
base_persistence.get_user_data()
with pytest.raises(NotImplementedError):
base_persistence.get_conversations("test")
with pytest.raises(NotImplementedError):
base_persistence.update_chat_data(None, None)
with pytest.raises(NotImplementedError):
base_persistence.update_user_data(None, None)
with pytest.raises(NotImplementedError):
base_persistence.update_conversation(None, None, None)
def test_implementation(self, updater, base_persistence):
dp = updater.dispatcher
assert dp.persistence == base_persistence
def test_conversationhandler_addition(self, dp, base_persistence):
with pytest.raises(ValueError, match="when handler is unnamed"):
ConversationHandler([], [], [], persistent=True)
with pytest.raises(ValueError, match="if dispatcher has no persistence"):
dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler"))
dp.persistence = base_persistence
with pytest.raises(NotImplementedError):
dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler"))
def test_dispatcher_integration_init(self, bot, base_persistence, chat_data, user_data):
def get_user_data():
return "test"
def get_chat_data():
return "test"
base_persistence.get_user_data = get_user_data
base_persistence.get_chat_data = get_chat_data
with pytest.raises(ValueError, match="user_data must be of type defaultdict"):
u = Updater(bot=bot, persistence=base_persistence)
def get_user_data():
return user_data
base_persistence.get_user_data = get_user_data
with pytest.raises(ValueError, match="chat_data must be of type defaultdict"):
u = Updater(bot=bot, persistence=base_persistence)
def get_chat_data():
return chat_data
base_persistence.get_chat_data = get_chat_data
u = Updater(bot=bot, persistence=base_persistence)
assert u.dispatcher.chat_data == chat_data
assert u.dispatcher.user_data == user_data
u.dispatcher.chat_data[442233]['test5'] = 'test6'
assert u.dispatcher.chat_data[442233]['test5'] == 'test6'
def test_dispatcher_integration_handlers(self, caplog, bot, base_persistence,
chat_data, user_data):
def get_user_data():
return user_data
def get_chat_data():
return chat_data
base_persistence.get_user_data = get_user_data
base_persistence.get_chat_data = get_chat_data
# base_persistence.update_chat_data = lambda x: x
# base_persistence.update_user_data = lambda x: x
updater = Updater(bot=bot, persistence=base_persistence)
dp = updater.dispatcher
def callback_known_user(bot, update, user_data, chat_data):
if not user_data['test1'] == 'test2':
pytest.fail('user_data corrupt')
def callback_known_chat(bot, update, user_data, chat_data):
if not chat_data['test3'] == 'test4':
pytest.fail('chat_data corrupt')
def callback_unknown_user_or_chat(bot, update, user_data, chat_data):
if not user_data == {}:
pytest.fail('user_data corrupt')
if not chat_data == {}:
pytest.fail('chat_data corrupt')
user_data[1] = 'test7'
chat_data[2] = 'test8'
known_user = MessageHandler(Filters.user(user_id=12345), callback_known_user,
pass_chat_data=True, pass_user_data=True)
known_chat = MessageHandler(Filters.chat(chat_id=-67890), callback_known_chat,
pass_chat_data=True, pass_user_data=True)
unknown = MessageHandler(Filters.all, callback_unknown_user_or_chat, pass_chat_data=True,
pass_user_data=True)
dp.add_handler(known_user)
dp.add_handler(known_chat)
dp.add_handler(unknown)
user1 = User(id=12345, first_name='test user', is_bot=False)
user2 = User(id=54321, first_name='test user', is_bot=False)
chat1 = Chat(id=-67890, type='group')
chat2 = Chat(id=-987654, type='group')
m = Message(1, user1, None, chat2)
u = Update(0, m)
with caplog.at_level(logging.ERROR):
dp.process_update(u)
rec = caplog.records[-1]
assert rec.msg == 'Saving user data raised an error'
assert rec.levelname == 'ERROR'
rec = caplog.records[-2]
assert rec.msg == 'Saving chat data raised an error'
assert rec.levelname == 'ERROR'
m.from_user = user2
m.chat = chat1
u = Update(1, m)
dp.process_update(u)
m.chat = chat2
u = Update(2, m)
def save_chat_data(data):
if -987654 not in data:
pytest.fail()
def save_user_data(data):
if 54321 not in data:
pytest.fail()
base_persistence.update_chat_data = save_chat_data
base_persistence.update_user_data = save_user_data
dp.process_update(u)
assert dp.user_data[54321][1] == 'test7'
assert dp.chat_data[-987654][2] == 'test8'
@pytest.fixture(scope='function')
def pickle_persistence():
return PicklePersistence(filename='pickletest',
store_user_data=True,
store_chat_data=True,
singe_file=False,
on_flush=False)
@pytest.fixture(scope='function')
def bad_pickle_files():
for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_conversations',
'pickletest']:
with open(name, 'w') as f:
f.write('(())')
yield True
for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_conversations',
'pickletest']:
os.remove(name)
@pytest.fixture(scope='function')
def good_pickle_files(user_data, chat_data, conversations):
all = {'user_data': user_data, 'chat_data': chat_data, 'conversations': conversations}
with open('pickletest_user_data', 'wb') as f:
pickle.dump(user_data, f)
with open('pickletest_chat_data', 'wb') as f:
pickle.dump(chat_data, f)
with open('pickletest_conversations', 'wb') as f:
pickle.dump(conversations, f)
with open('pickletest', 'wb') as f:
pickle.dump(all, f)
yield True
for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_conversations',
'pickletest']:
os.remove(name)
@pytest.fixture(scope='function')
def update(bot):
user = User(id=321, first_name='test_user', is_bot=False)
chat = Chat(id=123, type='group')
message = Message(1, user, None, chat, text="Hi there", bot=bot)
return Update(0, message=message)
class TestPickelPersistence(object):
def test_no_files_present_multi_file(self, pickle_persistence):
assert pickle_persistence.get_user_data() == defaultdict(dict)
assert pickle_persistence.get_user_data() == defaultdict(dict)
assert pickle_persistence.get_chat_data() == defaultdict(dict)
assert pickle_persistence.get_chat_data() == defaultdict(dict)
assert pickle_persistence.get_conversations('noname') == {}
assert pickle_persistence.get_conversations('noname') == {}
def test_no_files_present_single_file(self, pickle_persistence):
pickle_persistence.single_file = True
assert pickle_persistence.get_user_data() == defaultdict(dict)
assert pickle_persistence.get_chat_data() == defaultdict(dict)
assert pickle_persistence.get_conversations('noname') == {}
def test_with_bad_multi_file(self, pickle_persistence, bad_pickle_files):
with pytest.raises(TypeError, match='pickletest_user_data'):
pickle_persistence.get_user_data()
with pytest.raises(TypeError, match='pickletest_chat_data'):
pickle_persistence.get_chat_data()
with pytest.raises(TypeError, match='pickletest_conversations'):
pickle_persistence.get_conversations('name')
def test_with_bad_single_file(self, pickle_persistence, bad_pickle_files):
pickle_persistence.single_file = True
with pytest.raises(TypeError, match='pickletest'):
pickle_persistence.get_user_data()
with pytest.raises(TypeError, match='pickletest'):
pickle_persistence.get_chat_data()
with pytest.raises(TypeError, match='pickletest'):
pickle_persistence.get_conversations('name')
def test_with_good_multi_file(self, pickle_persistence, good_pickle_files):
user_data = pickle_persistence.get_user_data()
assert isinstance(user_data, defaultdict)
assert user_data[12345]['test1'] == 'test2'
assert user_data[67890][3] == 'test4'
assert user_data[54321] == {}
chat_data = pickle_persistence.get_chat_data()
assert isinstance(chat_data, defaultdict)
assert chat_data[-12345]['test1'] == 'test2'
assert chat_data[-67890][3] == 'test4'
assert chat_data[-54321] == {}
conversation1 = pickle_persistence.get_conversations('name1')
assert isinstance(conversation1, dict)
assert conversation1[(123, 123)] == 3
assert conversation1[(456, 654)] == 4
with pytest.raises(KeyError):
conversation1[(890, 890)]
conversation2 = pickle_persistence.get_conversations('name2')
assert isinstance(conversation1, dict)
assert conversation2[(123, 321)] == 1
assert conversation2[(890, 890)] == 2
with pytest.raises(KeyError):
conversation2[(123, 123)]
def test_with_good_single_file(self, pickle_persistence, good_pickle_files):
pickle_persistence.single_file = True
user_data = pickle_persistence.get_user_data()
assert isinstance(user_data, defaultdict)
assert user_data[12345]['test1'] == 'test2'
assert user_data[67890][3] == 'test4'
assert user_data[54321] == {}
chat_data = pickle_persistence.get_chat_data()
assert isinstance(chat_data, defaultdict)
assert chat_data[-12345]['test1'] == 'test2'
assert chat_data[-67890][3] == 'test4'
assert chat_data[-54321] == {}
conversation1 = pickle_persistence.get_conversations('name1')
assert isinstance(conversation1, dict)
assert conversation1[(123, 123)] == 3
assert conversation1[(456, 654)] == 4
with pytest.raises(KeyError):
conversation1[(890, 890)]
conversation2 = pickle_persistence.get_conversations('name2')
assert isinstance(conversation1, dict)
assert conversation2[(123, 321)] == 1
assert conversation2[(890, 890)] == 2
with pytest.raises(KeyError):
conversation2[(123, 123)]
def test_updating_multi_file(self, pickle_persistence, good_pickle_files):
user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321])
assert pickle_persistence.user_data == user_data
with open('pickletest_user_data', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f))
assert user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321])
assert pickle_persistence.chat_data == chat_data
with open('pickletest_chat_data', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f))
assert chat_data_test == chat_data
conversation1 = pickle_persistence.get_conversations('name1')
conversation1[(123, 123)] = 5
assert not pickle_persistence.conversations['name1'] == conversation1
pickle_persistence.update_conversation('name1', (123, 123), 5)
assert pickle_persistence.conversations['name1'] == conversation1
with open('pickletest_conversations', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f))
assert conversations_test['name1'] == conversation1
def test_updating_single_file(self, pickle_persistence, good_pickle_files):
pickle_persistence.single_file = True
user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321])
assert pickle_persistence.user_data == user_data
with open('pickletest', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)['user_data'])
assert user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321])
assert pickle_persistence.chat_data == chat_data
with open('pickletest', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)['chat_data'])
assert chat_data_test == chat_data
conversation1 = pickle_persistence.get_conversations('name1')
conversation1[(123, 123)] = 5
assert not pickle_persistence.conversations['name1'] == conversation1
pickle_persistence.update_conversation('name1', (123, 123), 5)
assert pickle_persistence.conversations['name1'] == conversation1
with open('pickletest', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f)['conversations'])
assert conversations_test['name1'] == conversation1
def test_save_on_flush_multi_files(self, pickle_persistence, good_pickle_files):
# Should run without error
pickle_persistence.flush()
pickle_persistence.on_flush = True
user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321])
assert pickle_persistence.user_data == user_data
with open('pickletest_user_data', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f))
assert not user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321])
assert pickle_persistence.chat_data == chat_data
with open('pickletest_chat_data', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f))
assert not chat_data_test == chat_data
conversation1 = pickle_persistence.get_conversations('name1')
conversation1[(123, 123)] = 5
assert not pickle_persistence.conversations['name1'] == conversation1
pickle_persistence.update_conversation('name1', (123, 123), 5)
assert pickle_persistence.conversations['name1'] == conversation1
with open('pickletest_conversations', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f))
assert not conversations_test['name1'] == conversation1
pickle_persistence.flush()
with open('pickletest_user_data', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f))
assert user_data_test == user_data
with open('pickletest_chat_data', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f))
assert chat_data_test == chat_data
with open('pickletest_conversations', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f))
assert conversations_test['name1'] == conversation1
def test_save_on_flush_single_files(self, pickle_persistence, good_pickle_files):
# Should run without error
pickle_persistence.flush()
pickle_persistence.on_flush = True
pickle_persistence.single_file = True
user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321])
assert pickle_persistence.user_data == user_data
with open('pickletest', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)['user_data'])
assert not user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321])
assert pickle_persistence.chat_data == chat_data
with open('pickletest', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)['chat_data'])
assert not chat_data_test == chat_data
conversation1 = pickle_persistence.get_conversations('name1')
conversation1[(123, 123)] = 5
assert not pickle_persistence.conversations['name1'] == conversation1
pickle_persistence.update_conversation('name1', (123, 123), 5)
assert pickle_persistence.conversations['name1'] == conversation1
with open('pickletest', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f)['conversations'])
assert not conversations_test['name1'] == conversation1
pickle_persistence.flush()
with open('pickletest', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)['user_data'])
assert user_data_test == user_data
with open('pickletest', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)['chat_data'])
assert chat_data_test == chat_data
with open('pickletest', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f)['conversations'])
assert conversations_test['name1'] == conversation1
def test_with_handler(self, bot, update, pickle_persistence, good_pickle_files):
u = Updater(bot=bot, persistence=pickle_persistence)
dp = u.dispatcher
def first(bot, update, user_data, chat_data):
if not user_data == {}:
pytest.fail()
if not chat_data == {}:
pytest.fail()
user_data['test1'] = 'test2'
chat_data['test3'] = 'test4'
def second(bot, update, user_data, chat_data):
if not user_data['test1'] == 'test2':
pytest.fail()
if not chat_data['test3'] == 'test4':
pytest.fail()
h1 = MessageHandler(None, first, pass_user_data=True, pass_chat_data=True)
h2 = MessageHandler(None, second, pass_user_data=True, pass_chat_data=True)
dp.add_handler(h1)
dp.process_update(update)
del (dp)
del (u)
del (pickle_persistence)
pickle_persistence_2 = PicklePersistence(filename='pickletest',
store_user_data=True,
store_chat_data=True,
singe_file=False,
on_flush=False)
u = Updater(bot=bot, persistence=pickle_persistence_2)
dp = u.dispatcher
dp.add_handler(h2)
dp.process_update(update)
def test_with_conversationHandler(self, dp, update, good_pickle_files, pickle_persistence):
dp.persistence = pickle_persistence
NEXT, NEXT2 = range(2)
def start(bot, update):
return NEXT
start = CommandHandler('start', start)
def next(bot, update):
return NEXT2
next = MessageHandler(None, next)
def next2(bot, update):
return ConversationHandler.END
next2 = MessageHandler(None, next2)
ch = ConversationHandler([start], {NEXT: [next], NEXT2: [next2]}, [], name='name2',
persistent=True)
dp.add_handler(ch)
assert ch.conversations[ch._get_key(update)] == 1
dp.process_update(update)
assert ch._get_key(update) not in ch.conversations
update.message.text = '/start'
dp.process_update(update)
assert ch.conversations[ch._get_key(update)] == 0
assert ch.conversations == pickle_persistence.conversations['name2']
@classmethod
def teardown_class(cls):
try:
for name in ['pickletest_user_data', 'pickletest_chat_data',
'pickletest_conversations',
'pickletest']:
os.remove(name)
except Exception:
pass
@pytest.fixture(scope='function')
def user_data_json(user_data):
return json.dumps(user_data)
@pytest.fixture(scope='function')
def chat_data_json(chat_data):
return json.dumps(chat_data)
@pytest.fixture(scope='function')
def conversations_json(conversations):
return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2":
{"[123, 321]": 1, "[890, 890]": 2}}"""
class TestDictPersistence(object):
def test_no_json_given(self):
dict_persistence = DictPersistence()
assert dict_persistence.get_user_data() == defaultdict(dict)
assert dict_persistence.get_chat_data() == defaultdict(dict)
assert dict_persistence.get_conversations('noname') == {}
def test_bad_json_string_given(self):
bad_user_data = 'thisisnojson99900()))('
bad_chat_data = 'thisisnojson99900()))('
bad_conversations = 'thisisnojson99900()))('
with pytest.raises(TypeError, match='user_data'):
DictPersistence(user_data_json=bad_user_data)
with pytest.raises(TypeError, match='chat_data'):
DictPersistence(chat_data_json=bad_chat_data)
with pytest.raises(TypeError, match='conversations'):
DictPersistence(conversations_json=bad_conversations)
def test_invalid_json_string_given(self, pickle_persistence, bad_pickle_files):
bad_user_data = '["this", "is", "json"]'
bad_chat_data = '["this", "is", "json"]'
bad_conversations = '["this", "is", "json"]'
with pytest.raises(TypeError, match='user_data'):
DictPersistence(user_data_json=bad_user_data)
with pytest.raises(TypeError, match='chat_data'):
DictPersistence(chat_data_json=bad_chat_data)
with pytest.raises(TypeError, match='conversations'):
DictPersistence(conversations_json=bad_conversations)
def test_good_json_input(self, user_data_json, chat_data_json, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
user_data = dict_persistence.get_user_data()
assert isinstance(user_data, defaultdict)
assert user_data[12345]['test1'] == 'test2'
assert user_data[67890][3] == 'test4'
assert user_data[54321] == {}
chat_data = dict_persistence.get_chat_data()
assert isinstance(chat_data, defaultdict)
assert chat_data[-12345]['test1'] == 'test2'
assert chat_data[-67890][3] == 'test4'
assert chat_data[-54321] == {}
conversation1 = dict_persistence.get_conversations('name1')
assert isinstance(conversation1, dict)
assert conversation1[(123, 123)] == 3
assert conversation1[(456, 654)] == 4
with pytest.raises(KeyError):
conversation1[(890, 890)]
conversation2 = dict_persistence.get_conversations('name2')
assert isinstance(conversation1, dict)
assert conversation2[(123, 321)] == 1
assert conversation2[(890, 890)] == 2
with pytest.raises(KeyError):
conversation2[(123, 123)]
def test_dict_outputs(self, user_data, user_data_json, chat_data, chat_data_json,
conversations, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
assert dict_persistence.user_data == user_data
assert dict_persistence.chat_data == chat_data
assert dict_persistence.conversations == conversations
def test_json_outputs(self, user_data_json, chat_data_json, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
assert dict_persistence.user_data_json == user_data_json
assert dict_persistence.chat_data_json == chat_data_json
assert dict_persistence.conversations_json == conversations_json
def test_json_changes(self, user_data, user_data_json, chat_data, chat_data_json,
conversations, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
user_data_two = user_data.copy()
user_data_two.update({4: {5: 6}})
dict_persistence.update_user_data(4, {5: 6})
assert dict_persistence.user_data == user_data_two
assert dict_persistence.user_data_json != user_data_json
assert dict_persistence.user_data_json == json.dumps(user_data_two)
chat_data_two = chat_data.copy()
chat_data_two.update({7: {8: 9}})
dict_persistence.update_chat_data(7, {8: 9})
assert dict_persistence.chat_data == chat_data_two
assert dict_persistence.chat_data_json != chat_data_json
assert dict_persistence.chat_data_json == json.dumps(chat_data_two)
conversations_two = conversations.copy()
conversations_two.update({'name3': {(1, 2): 3}})
dict_persistence.update_conversation('name3', (1, 2), 3)
assert dict_persistence.conversations == conversations_two
assert dict_persistence.conversations_json != conversations_json
assert dict_persistence.conversations_json == enocde_conversations_to_json(
conversations_two)
def test_with_handler(self, bot, update):
dict_persistence = DictPersistence()
u = Updater(bot=bot, persistence=dict_persistence)
dp = u.dispatcher
def first(bot, update, user_data, chat_data):
if not user_data == {}:
pytest.fail()
if not chat_data == {}:
pytest.fail()
user_data['test1'] = 'test2'
chat_data[3] = 'test4'
def second(bot, update, user_data, chat_data):
if not user_data['test1'] == 'test2':
pytest.fail()
if not chat_data[3] == 'test4':
pytest.fail()
h1 = MessageHandler(None, first, pass_user_data=True, pass_chat_data=True)
h2 = MessageHandler(None, second, pass_user_data=True, pass_chat_data=True)
dp.add_handler(h1)
dp.process_update(update)
del (dp)
del (u)
user_data = dict_persistence.user_data_json
chat_data = dict_persistence.chat_data_json
del (dict_persistence)
dict_persistence_2 = DictPersistence(user_data_json=user_data,
chat_data_json=chat_data)
u = Updater(bot=bot, persistence=dict_persistence_2)
dp = u.dispatcher
dp.add_handler(h2)
dp.process_update(update)
def test_with_conversationHandler(self, dp, update, conversations_json):
dict_persistence = DictPersistence(conversations_json=conversations_json)
dp.persistence = dict_persistence
NEXT, NEXT2 = range(2)
def start(bot, update):
return NEXT
start = CommandHandler('start', start)
def next(bot, update):
return NEXT2
next = MessageHandler(None, next)
def next2(bot, update):
return ConversationHandler.END
next2 = MessageHandler(None, next2)
ch = ConversationHandler([start], {NEXT: [next], NEXT2: [next2]}, [], name='name2',
persistent=True)
dp.add_handler(ch)
assert ch.conversations[ch._get_key(update)] == 1
dp.process_update(update)
assert ch._get_key(update) not in ch.conversations
update.message.text = '/start'
dp.process_update(update)
assert ch.conversations[ch._get_key(update)] == 0
assert ch.conversations == dict_persistence.conversations['name2']