Persistence (#1017)

* BasePersistence

* basic construct

* Keep working

* Continue work

Add tests for Basepersistence

* Finish up BasePersistence and implementation

* PickelPersistence and start tests

* Finishing up

* Oops, left in some typings

* Compatibilty issues regarding py2 solved

For Py2 compatibility

* increasing coverage

* Small changes due to CR

* All persistence tests in one file

* add DictPersistence

* Last changes per CR

* forgot change

* changes per CR

* call update_* only with relevant data

As discussed with @jsmnbom

* Add conversationbot Example

* should not have committed API-key
This commit is contained in:
Eldinnie 2018-09-20 22:50:40 +02:00 committed by Jasmin Bom
parent b9f56ca479
commit 439790375e
17 changed files with 1677 additions and 9 deletions

View file

@ -0,0 +1,6 @@
telegram.ext.BasePersistence
============================
.. autoclass:: telegram.ext.BasePersistence
:members:
:show-inheritance:

View file

@ -0,0 +1,6 @@
telegram.ext.DictPersistence
============================
.. autoclass:: telegram.ext.DictPersistence
:members:
:show-inheritance:

View file

@ -0,0 +1,6 @@
telegram.ext.PicklePersistence
==============================
.. autoclass:: telegram.ext.PicklePersistence
:members:
:show-inheritance:

View file

@ -29,3 +29,12 @@ Handlers
telegram.ext.stringcommandhandler telegram.ext.stringcommandhandler
telegram.ext.stringregexhandler telegram.ext.stringregexhandler
telegram.ext.typehandler telegram.ext.typehandler
Persistence
-----------
.. toctree::
telegram.ext.basepersistence
telegram.ext.picklepersistence
telegram.ext.dictpersistence

View file

@ -25,5 +25,8 @@ A basic example of an [inline bot](https://core.telegram.org/bots/inline). Don't
### [`paymentbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/paymentbot.py) ### [`paymentbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/paymentbot.py)
A basic example of a bot that can accept payments. Don't forget to enable and configure payments with [@BotFather](https://telegram.me/BotFather). A basic example of a bot that can accept payments. Don't forget to enable and configure payments with [@BotFather](https://telegram.me/BotFather).
### [`persistentconversationbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/persistentconversationbot.py)
A basic example of a bot store conversation state and user_data over multiple restarts.
## Pure API ## Pure API
The [`echobot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/echobot.py) example uses only the pure, "bare-metal" API wrapper. The [`echobot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/echobot.py) example uses only the pure, "bare-metal" API wrapper.

View file

@ -0,0 +1,170 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Simple Bot to reply to Telegram messages
# This program is dedicated to the public domain under the CC0 license.
"""
This Bot uses the Updater class to handle the bot.
First, a few callback functions are defined. Then, those functions are passed to
the Dispatcher and registered at their respective places.
Then, the bot is started and runs until we press Ctrl-C on the command line.
Usage:
Example of a bot-user conversation using ConversationHandler.
Send /start to initiate the conversation.
Press Ctrl-C on the command line or send a signal to the process to stop the
bot.
"""
from telegram import ReplyKeyboardMarkup
from telegram.ext import (Updater, CommandHandler, MessageHandler, Filters, RegexHandler,
ConversationHandler, PicklePersistence)
import logging
# Enable logging
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
level=logging.INFO)
logger = logging.getLogger(__name__)
CHOOSING, TYPING_REPLY, TYPING_CHOICE = range(3)
reply_keyboard = [['Age', 'Favourite colour'],
['Number of siblings', 'Something else...'],
['Done']]
markup = ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True)
def facts_to_str(user_data):
facts = list()
for key, value in user_data.items():
facts.append('{} - {}'.format(key, value))
return "\n".join(facts).join(['\n', '\n'])
def start(bot, update, user_data):
reply_text = "Hi! My name is Doctor Botter."
if user_data:
reply_text += " You already told me your {}. Why don't you tell me something more " \
"about yourself? Or change enything I " \
"already know.".format(", ".join(user_data.keys()))
else:
reply_text += " I will hold a more complex conversation with you. Why don't you tell me " \
"something about yourself?"
update.message.reply_text(reply_text, reply_markup=markup)
return CHOOSING
def regular_choice(bot, update, user_data):
text = update.message.text
user_data['choice'] = text
if user_data.get(text):
reply_text = 'Your {}, I already know the following ' \
'about that: {}'.format(text.lower(), user_data[text.lower()])
else:
reply_text = 'Your {}? Yes, I would love to hear about that!'.format(text.lower())
update.message.reply_text(reply_text)
return TYPING_REPLY
def custom_choice(bot, update):
update.message.reply_text('Alright, please send me the category first, '
'for example "Most impressive skill"')
return TYPING_CHOICE
def received_information(bot, update, user_data):
text = update.message.text
category = user_data['choice']
user_data[category] = text.lower()
del user_data['choice']
update.message.reply_text("Neat! Just so you know, this is what you already told me:"
"{}"
"You can tell me more, or change your opinion on "
"something.".format(facts_to_str(user_data)), reply_markup=markup)
return CHOOSING
def show_data(bot, update, user_data):
update.message.reply_text("This is what you already told me:"
"{}".format(facts_to_str(user_data)))
def done(bot, update, user_data):
if 'choice' in user_data:
del user_data['choice']
update.message.reply_text("I learned these facts about you:"
"{}"
"Until next time!".format(facts_to_str(user_data)))
return ConversationHandler.END
def error(bot, update, error):
"""Log Errors caused by Updates."""
logger.warning('Update "%s" caused error "%s"', update, error)
def main():
# Create the Updater and pass it your bot's token.
pp = PicklePersistence(filename='conversationbot')
updater = Updater("TOKEN", persistence=pp)
# Get the dispatcher to register handlers
dp = updater.dispatcher
# Add conversation handler with the states CHOOSING, TYPING_CHOICE and TYPING_REPLY
conv_handler = ConversationHandler(
entry_points=[CommandHandler('start', start, pass_user_data=True)],
states={
CHOOSING: [RegexHandler('^(Age|Favourite colour|Number of siblings)$',
regular_choice,
pass_user_data=True),
RegexHandler('^Something else...$',
custom_choice),
],
TYPING_CHOICE: [MessageHandler(Filters.text,
regular_choice,
pass_user_data=True),
],
TYPING_REPLY: [MessageHandler(Filters.text,
received_information,
pass_user_data=True),
],
},
fallbacks=[RegexHandler('^Done$', done, pass_user_data=True)],
name="my_conversation",
persistent=True
)
dp.add_handler(conv_handler)
show_data_handler = CommandHandler('show_data', show_data, pass_user_data=True)
dp.add_handler(show_data_handler)
# log all errors
dp.add_error_handler(error)
# Start the Bot
updater.start_polling()
# Run the bot until you press Ctrl-C or the process receives SIGINT,
# SIGTERM or SIGABRT. This should be used most of the time, since
# start_polling() is non-blocking and will stop the bot gracefully.
updater.idle()
if __name__ == '__main__':
main()

View file

@ -18,6 +18,9 @@
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
"""Extensions over the Telegram Bot API to facilitate bot making""" """Extensions over the Telegram Bot API to facilitate bot making"""
from .basepersistence import BasePersistence
from .picklepersistence import PicklePersistence
from .dictpersistence import DictPersistence
from .dispatcher import Dispatcher, DispatcherHandlerStop, run_async from .dispatcher import Dispatcher, DispatcherHandlerStop, run_async
from .jobqueue import JobQueue, Job from .jobqueue import JobQueue, Job
from .updater import Updater from .updater import Updater
@ -43,4 +46,5 @@ __all__ = ('Dispatcher', 'JobQueue', 'Job', 'Updater', 'CallbackQueryHandler',
'MessageHandler', 'BaseFilter', 'Filters', 'RegexHandler', 'StringCommandHandler', 'MessageHandler', 'BaseFilter', 'Filters', 'RegexHandler', 'StringCommandHandler',
'StringRegexHandler', 'TypeHandler', 'ConversationHandler', 'StringRegexHandler', 'TypeHandler', 'ConversationHandler',
'PreCheckoutQueryHandler', 'ShippingQueryHandler', 'MessageQueue', 'DelayQueue', 'PreCheckoutQueryHandler', 'ShippingQueryHandler', 'MessageQueue', 'DelayQueue',
'DispatcherHandlerStop', 'run_async') 'DispatcherHandlerStop', 'run_async', 'BasePersistence', 'PicklePersistence',
'DictPersistence')

View file

@ -0,0 +1,123 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2018
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the BasePersistence class."""
class BasePersistence(object):
"""Interface class for adding persistence to your bot.
Subclass this object for different implementations of a persistent bot.
All relevant methods must be overwritten. This means:
* If :attr:`store_chat_data` is ``True`` you must overwrite :meth:`get_chat_data` and
:meth:`update_chat_data`.
* If :attr:`store_user_data` is ``True`` you must overwrite :meth:`get_user_data` and
:meth:`update_user_data`.
* If you want to store conversation data with :class:`telegram.ext.ConversationHandler`, you
must overwrite :meth:`get_conversations` and :meth:`update_conversation`.
* :meth:`flush` will be called when the bot is shutdown.
Attributes:
store_user_data (:obj:`bool`): Optional, Whether user_data should be saved by this
persistence class.
store_chat_data (:obj:`bool`): Optional. Whether chat_data should be saved by this
persistence class.
Args:
store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
store_chat_data (:obj:`bool`, optional): Whether chat_data should be saved by this
persistence class. Default is ``True`` .
"""
def __init__(self, store_user_data=True, store_chat_data=True):
self.store_user_data = store_user_data
self.store_chat_data = store_chat_data
def get_user_data(self):
""""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the user_data if stored, or an empty
``defaultdict(dict)``.
Returns:
:obj:`defaultdict`: The restored user data.
"""
raise NotImplementedError
def get_chat_data(self):
""""Will be called by :class:`telegram.ext.Dispatcher` upon creation with a
persistence object. It should return the chat_data if stored, or an empty
``defaultdict(dict)``.
Returns:
:obj:`defaultdict`: The restored chat data.
"""
raise NotImplementedError
def get_conversations(self, name):
""""Will be called by :class:`telegram.ext.Dispatcher` when a
:class:`telegram.ext.ConversationHandler` is added if
:attr:`telegram.ext.ConversationHandler.persistent` is ``True``.
It should return the conversations for the handler with `name` or an empty ``dict``
Args:
name (:obj:`str`): The handlers name.
Returns:
:obj:`dict`: The restored conversations for the handler.
"""
raise NotImplementedError
def update_conversation(self, name, key, new_state):
"""Will be called when a :attr:`telegram.ext.ConversationHandler.update_state`
is called. this allows the storeage of the new state in the persistence.
Args:
name (:obj:`str`): The handlers name.
key (:obj:`tuple`): The key the state is changed for.
new_state (:obj:`tuple` | :obj:`any`): The new state for the given key.
"""
raise NotImplementedError
def update_user_data(self, user_id, data):
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
Args:
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id].
"""
raise NotImplementedError
def update_chat_data(self, chat_id, data):
"""Will be called by the :class:`telegram.ext.Dispatcher` after a handler has
handled an update.
Args:
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[user_id].
"""
raise NotImplementedError
def flush(self):
"""Will be called by :class:`telegram.ext.Updater` upon receiving a stop signal. Gives the
persistence a chance to finish up saving or close a database connection gracefully. If this
is not of any importance just pass will be sufficient.
"""
pass

View file

@ -79,6 +79,10 @@ class ConversationHandler(Handler):
conversation_timeout (:obj:`float`|:obj:`datetime.timedelta`): Optional. When this handler conversation_timeout (:obj:`float`|:obj:`datetime.timedelta`): Optional. When this handler
is inactive more than this timeout (in seconds), it will be automatically ended. If is inactive more than this timeout (in seconds), it will be automatically ended. If
this value is 0 (default), there will be no timeout. this value is 0 (default), there will be no timeout.
name (:obj:`str`): Optional. The name for this conversationhandler. Required for
persistence
persistent (:obj:`bool`): Optional. If the conversations dict for this handler should be
saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater`
Args: Args:
entry_points (List[:class:`telegram.ext.Handler`]): A list of ``Handler`` objects that can entry_points (List[:class:`telegram.ext.Handler`]): A list of ``Handler`` objects that can
@ -113,6 +117,10 @@ class ConversationHandler(Handler):
conversation_timeout (:obj:`float`|:obj:`datetime.timedelta`, optional): When this handler conversation_timeout (:obj:`float`|:obj:`datetime.timedelta`, optional): When this handler
is inactive more than this timeout (in seconds), it will be automatically ended. If is inactive more than this timeout (in seconds), it will be automatically ended. If
this value is 0 or None (default), there will be no timeout. this value is 0 or None (default), there will be no timeout.
name (:obj:`str`, optional): The name for this conversationhandler. Required for
persistence
persistent (:obj:`bool`, optional): If the conversations dict for this handler should be
saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater`
Raises: Raises:
ValueError ValueError
@ -131,7 +139,9 @@ class ConversationHandler(Handler):
per_chat=True, per_chat=True,
per_user=True, per_user=True,
per_message=False, per_message=False,
conversation_timeout=None): conversation_timeout=None,
name=None,
persistent=False):
self.entry_points = entry_points self.entry_points = entry_points
self.states = states self.states = states
@ -144,6 +154,13 @@ class ConversationHandler(Handler):
self.per_chat = per_chat self.per_chat = per_chat
self.per_message = per_message self.per_message = per_message
self.conversation_timeout = conversation_timeout self.conversation_timeout = conversation_timeout
self.name = name
if persistent and not self.name:
raise ValueError("Conversations can't be persistent when handler is unnamed.")
self.persistent = persistent
self.persistence = None
""":obj:`telegram.ext.BasePersistance`: The persistence used to store conversations.
Set by dispatcher"""
self.timeout_jobs = dict() self.timeout_jobs = dict()
self.conversations = dict() self.conversations = dict()
@ -318,14 +335,21 @@ class ConversationHandler(Handler):
if new_state == self.END: if new_state == self.END:
if key in self.conversations: if key in self.conversations:
del self.conversations[key] del self.conversations[key]
if self.persistent:
self.persistence.update_conversation(self.name, key, None)
else: else:
pass pass
elif isinstance(new_state, Promise): elif isinstance(new_state, Promise):
self.conversations[key] = (self.conversations.get(key), new_state) self.conversations[key] = (self.conversations.get(key), new_state)
if self.persistent:
self.persistence.update_conversation(self.name, key,
(self.conversations.get(key), new_state))
elif new_state is not None: elif new_state is not None:
self.conversations[key] = new_state self.conversations[key] = new_state
if self.persistent:
self.persistence.update_conversation(self.name, key, new_state)
def _trigger_timeout(self, bot, job): def _trigger_timeout(self, bot, job):
del self.timeout_jobs[job.context] del self.timeout_jobs[job.context]

View file

@ -0,0 +1,194 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2018
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the DictPersistence class."""
from telegram.utils.helpers import decode_user_chat_data_from_json,\
decode_conversations_from_json, enocde_conversations_to_json
try:
import ujson as json
except ImportError:
import json
from collections import defaultdict
from telegram.ext import BasePersistence
class DictPersistence(BasePersistence):
"""Using python's dicts and json for making you bot persistent.
Attributes:
store_user_data (:obj:`bool`): Whether user_data should be saved by this
persistence class.
store_chat_data (:obj:`bool`): Whether chat_data should be saved by this
persistence class.
Args:
store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
store_chat_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
user_data_json (:obj:`str`, optional): Json string that will be used to reconstruct
user_data on creating this persistence. Default is ``""``.
chat_data_json (:obj:`str`, optional): Json string that will be used to reconstruct
chat_data on creating this persistence. Default is ``""``.
conversations_json (:obj:`str`, optional): Json string that will be used to reconstruct
conversation on creating this persistence. Default is ``""``.
"""
def __init__(self, store_user_data=True, store_chat_data=True, user_data_json='',
chat_data_json='', conversations_json=''):
self.store_user_data = store_user_data
self.store_chat_data = store_chat_data
self._user_data = None
self._chat_data = None
self._conversations = None
self._user_data_json = None
self._chat_data_json = None
self._conversations_json = None
if user_data_json:
try:
self._user_data = decode_user_chat_data_from_json(user_data_json)
self._user_data_json = user_data_json
except (ValueError, AttributeError):
raise TypeError("Unable to deserialize user_data_json. Not valid JSON")
if chat_data_json:
try:
self._chat_data = decode_user_chat_data_from_json(chat_data_json)
self._chat_data_json = chat_data_json
except (ValueError, AttributeError):
raise TypeError("Unable to deserialize chat_data_json. Not valid JSON")
if conversations_json:
try:
self._conversations = decode_conversations_from_json(conversations_json)
self._conversations_json = conversations_json
except (ValueError, AttributeError):
raise TypeError("Unable to deserialize conversations_json. Not valid JSON")
@property
def user_data(self):
""":obj:`dict`: The user_data as a dict"""
return self._user_data
@property
def user_data_json(self):
""":obj:`str`: The user_data serialized as a JSON-string."""
if self._user_data_json:
return self._user_data_json
else:
return json.dumps(self.user_data)
@property
def chat_data(self):
""":obj:`dict`: The chat_data as a dict"""
return self._chat_data
@property
def chat_data_json(self):
""":obj:`str`: The chat_data serialized as a JSON-string."""
if self._chat_data_json:
return self._chat_data_json
else:
return json.dumps(self.chat_data)
@property
def conversations(self):
""":obj:`dict`: The conversations as a dict"""
return self._conversations
@property
def conversations_json(self):
""":obj:`str`: The conversations serialized as a JSON-string."""
if self._conversations_json:
return self._conversations_json
else:
return enocde_conversations_to_json(self.conversations)
def get_user_data(self):
"""Returns the user_data created from the ``user_data_json`` or an empty defaultdict.
Returns:
:obj:`defaultdict`: The restored user data.
"""
if self.user_data:
pass
else:
self._user_data = defaultdict(dict)
return self.user_data.copy()
def get_chat_data(self):
"""Returns the chat_data created from the ``chat_data_json`` or an empty defaultdict.
Returns:
:obj:`defaultdict`: The restored user data.
"""
if self.chat_data:
pass
else:
self._chat_data = defaultdict(dict)
return self.chat_data.copy()
def get_conversations(self, name):
"""Returns the conversations created from the ``conversations_json`` or an empty
defaultdict.
Returns:
:obj:`defaultdict`: The restored user data.
"""
if self.conversations:
pass
else:
self._conversations = {}
return self.conversations.get(name, {}).copy()
def update_conversation(self, name, key, new_state):
"""Will update the conversations for the given handler.
Args:
name (:obj:`str`): The handlers name.
key (:obj:`tuple`): The key the state is changed for.
new_state (:obj:`tuple` | :obj:`any`): The new state for the given key.
"""
if self._conversations.setdefault(name, {}).get(key) == new_state:
return
self._conversations[name][key] = new_state
self._conversations_json = None
def update_user_data(self, user_id, data):
"""Will update the user_data (if changed).
Args:
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id].
"""
if self._user_data.get(user_id) == data:
return
self._user_data[user_id] = data
self._user_data_json = None
def update_chat_data(self, chat_id, data):
"""Will update the chat_data (if changed).
Args:
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id].
"""
if self._chat_data.get(chat_id) == data:
return
self._chat_data[chat_id] = data
self._chat_data_json = None

View file

@ -33,6 +33,7 @@ from future.builtins import range
from telegram import TelegramError from telegram import TelegramError
from telegram.ext.handler import Handler from telegram.ext.handler import Handler
from telegram.utils.promise import Promise from telegram.utils.promise import Promise
from telegram.ext import BasePersistence
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())
DEFAULT_GROUP = 0 DEFAULT_GROUP = 0
@ -48,6 +49,7 @@ def run_async(func):
Note: Use this decorator to run handlers asynchronously. Note: Use this decorator to run handlers asynchronously.
""" """
@wraps(func) @wraps(func)
def async_func(*args, **kwargs): def async_func(*args, **kwargs):
return Dispatcher.get_instance().run_async(func, *args, **kwargs) return Dispatcher.get_instance().run_async(func, *args, **kwargs)
@ -70,6 +72,10 @@ class Dispatcher(object):
instance to pass onto handler callbacks. instance to pass onto handler callbacks.
workers (:obj:`int`): Number of maximum concurrent worker threads for the ``@run_async`` workers (:obj:`int`): Number of maximum concurrent worker threads for the ``@run_async``
decorator. decorator.
user_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the user.
chat_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the chat.
persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to
store data that should be persistent over restarts
Args: Args:
bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers. bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers.
@ -78,6 +84,8 @@ class Dispatcher(object):
instance to pass onto handler callbacks. instance to pass onto handler callbacks.
workers (:obj:`int`, optional): Number of maximum concurrent worker threads for the workers (:obj:`int`, optional): Number of maximum concurrent worker threads for the
``@run_async`` decorator. defaults to 4. ``@run_async`` decorator. defaults to 4.
persistence (:class:`telegram.ext.BasePersistence`, optional): The persistence class to
store data that should be persistent over restarts
""" """
@ -86,16 +94,30 @@ class Dispatcher(object):
__singleton = None __singleton = None
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None): def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None,
persistence=None):
self.bot = bot self.bot = bot
self.update_queue = update_queue self.update_queue = update_queue
self.job_queue = job_queue
self.workers = workers self.workers = workers
self.user_data = defaultdict(dict) self.user_data = defaultdict(dict)
""":obj:`dict`: A dictionary handlers can use to store data for the user."""
self.chat_data = defaultdict(dict) self.chat_data = defaultdict(dict)
""":obj:`dict`: A dictionary handlers can use to store data for the chat.""" if persistence:
if not isinstance(persistence, BasePersistence):
raise TypeError("persistence should be based on telegram.ext.BasePersistence")
self.persistence = persistence
if self.persistence.store_user_data:
self.user_data = self.persistence.get_user_data()
if not isinstance(self.user_data, defaultdict):
raise ValueError("user_data must be of type defaultdict")
if self.persistence.store_chat_data:
self.chat_data = self.persistence.get_chat_data()
if not isinstance(self.chat_data, defaultdict):
raise ValueError("chat_data must be of type defaultdict")
else:
self.persistence = None
self.job_queue = job_queue
self.handlers = {} self.handlers = {}
"""Dict[:obj:`int`, List[:class:`telegram.ext.Handler`]]: Holds the handlers per group.""" """Dict[:obj:`int`, List[:class:`telegram.ext.Handler`]]: Holds the handlers per group."""
self.groups = [] self.groups = []
@ -277,6 +299,19 @@ class Dispatcher(object):
try: try:
for handler in (x for x in self.handlers[group] if x.check_update(update)): for handler in (x for x in self.handlers[group] if x.check_update(update)):
handler.handle_update(update, self) handler.handle_update(update, self)
if self.persistence:
if self.persistence.store_chat_data and update.effective_chat.id:
chat_id = update.effective_chat.id
try:
self.persistence.update_chat_data(chat_id, self.chat_data[chat_id])
except Exception:
self.logger.exception('Saving chat data raised an error')
if self.persistence.store_user_data and update.effective_user.id:
user_id = update.effective_user.id
try:
self.persistence.update_user_data(user_id, self.user_data[user_id])
except Exception:
self.logger.exception('Saving user data raised an error')
break break
# Stop processing with any other handler. # Stop processing with any other handler.
@ -324,11 +359,20 @@ class Dispatcher(object):
group (:obj:`int`, optional): The group identifier. Default is 0. group (:obj:`int`, optional): The group identifier. Default is 0.
""" """
# Unfortunately due to circular imports this has to be here
from .conversationhandler import ConversationHandler
if not isinstance(handler, Handler): if not isinstance(handler, Handler):
raise TypeError('handler is not an instance of {0}'.format(Handler.__name__)) raise TypeError('handler is not an instance of {0}'.format(Handler.__name__))
if not isinstance(group, int): if not isinstance(group, int):
raise TypeError('group is not int') raise TypeError('group is not int')
if isinstance(handler, ConversationHandler) and handler.persistent:
if not self.persistence:
raise ValueError(
"Conversationhandler {} can not be persistent if dispatcher has no "
"persistence".format(handler.name))
handler.conversations = self.persistence.get_conversations(handler.name)
handler.persistence = self.persistence
if group not in self.handlers: if group not in self.handlers:
self.handlers[group] = list() self.handlers[group] = list()

View file

@ -0,0 +1,235 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2018
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the PicklePersistence class."""
import pickle
from collections import defaultdict
from telegram.ext import BasePersistence
class PicklePersistence(BasePersistence):
"""Using python's builtin pickle for making you bot persistent.
Attributes:
filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file`
is false this will be used as a prefix.
store_user_data (:obj:`bool`): Optional. Whether user_data should be saved by this
persistence class.
store_chat_data (:obj:`bool`): Optional. Whether user_data should be saved by this
persistence class.
single_file (:obj:`bool`): Optional. When ``False`` will store 3 sperate files of
`filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is
``True``.
on_flush (:obj:`bool`): Optional. When ``True`` will only save to file when :meth:`flush`
is called and keep data in memory until that happens. When False will store data on any
transaction. Default is ``False``.
Args:
filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file`
is false this will be used as a prefix.
store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
store_chat_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
single_file (:obj:`bool`, optional): When ``False`` will store 3 sperate files of
`filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is
``True``.
on_flush (:obj:`bool`, optional): When ``True`` will only save to file when :meth:`flush`
is called and keep data in memory until that happens. When False will store data on any
transaction. Default is ``False``.
"""
def __init__(self, filename, store_user_data=True, store_chat_data=True, singe_file=True,
on_flush=False):
self.filename = filename
self.store_user_data = store_user_data
self.store_chat_data = store_chat_data
self.single_file = singe_file
self.on_flush = on_flush
self.user_data = None
self.chat_data = None
self.conversations = None
def load_singlefile(self):
try:
filename = self.filename
with open(self.filename, "rb") as f:
all = pickle.load(f)
self.user_data = defaultdict(dict, all['user_data'])
self.chat_data = defaultdict(dict, all['chat_data'])
self.conversations = all['conversations']
except IOError:
self.conversations = {}
self.user_data = defaultdict(dict)
self.chat_data = defaultdict(dict)
except pickle.UnpicklingError:
raise TypeError("File {} does not contain valid pickle data".format(filename))
except Exception:
raise TypeError("Something went wrong unpickling {}".format(filename))
def load_file(self, filename):
try:
with open(filename, "rb") as f:
return pickle.load(f)
except IOError:
return None
except pickle.UnpicklingError:
raise TypeError("File {} does not contain valid pickle data".format(filename))
except Exception:
raise TypeError("Something went wrong unpickling {}".format(filename))
def dump_singlefile(self):
with open(self.filename, "wb") as f:
all = {'conversations': self.conversations, 'user_data': self.user_data,
'chat_data': self.chat_data}
pickle.dump(all, f)
def dump_file(self, filename, data):
with open(filename, "wb") as f:
pickle.dump(data, f)
def get_user_data(self):
"""Returns the user_data from the pickle file if it exsists or an empty defaultdict.
Returns:
:obj:`defaultdict`: The restored user data.
"""
if self.user_data:
pass
elif not self.single_file:
filename = "{}_user_data".format(self.filename)
data = self.load_file(filename)
if not data:
data = defaultdict(dict)
else:
data = defaultdict(dict, data)
self.user_data = data
else:
self.load_singlefile()
return self.user_data.copy()
def get_chat_data(self):
"""Returns the chat_data from the pickle file if it exsists or an empty defaultdict.
Returns:
:obj:`defaultdict`: The restored chat data.
"""
if self.chat_data:
pass
elif not self.single_file:
filename = "{}_chat_data".format(self.filename)
data = self.load_file(filename)
if not data:
data = defaultdict(dict)
else:
data = defaultdict(dict, data)
self.chat_data = data
else:
self.load_singlefile()
return self.chat_data.copy()
def get_conversations(self, name):
"""Returns the conversations from the pickle file if it exsists or an empty defaultdict.
Args:
name (:obj:`str`): The handlers name.
Returns:
:obj:`dict`: The restored conversations for the handler.
"""
if self.conversations:
pass
elif not self.single_file:
filename = "{}_conversations".format(self.filename)
data = self.load_file(filename)
if not data:
data = {name: {}}
self.conversations = data
else:
self.load_singlefile()
return self.conversations.get(name, {}).copy()
def update_conversation(self, name, key, new_state):
"""Will update the conversations for the given handler and depending on :attr:`on_flush`
save the pickle file.
Args:
name (:obj:`str`): The handlers name.
key (:obj:`tuple`): The key the state is changed for.
new_state (:obj:`tuple` | :obj:`any`): The new state for the given key.
"""
if self.conversations.setdefault(name, {}).get(key) == new_state:
return
self.conversations[name][key] = new_state
if not self.on_flush:
if not self.single_file:
filename = "{}_conversations".format(self.filename)
self.dump_file(filename, self.conversations)
else:
self.dump_singlefile()
def update_user_data(self, user_id, data):
"""Will update the user_data (if changed) and depending on :attr:`on_flush` save the
pickle file.
Args:
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id].
"""
if self.user_data.get(user_id) == data:
return
self.user_data[user_id] = data
if not self.on_flush:
if not self.single_file:
filename = "{}_user_data".format(self.filename)
self.dump_file(filename, self.user_data)
else:
self.dump_singlefile()
def update_chat_data(self, chat_id, data):
"""Will update the chat_data (if changed) and depending on :attr:`on_flush` save the
pickle file.
Args:
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id].
"""
if self.chat_data.get(chat_id) == data:
return
self.chat_data[chat_id] = data
if not self.on_flush:
if not self.single_file:
filename = "{}_chat_data".format(self.filename)
self.dump_file(filename, self.chat_data)
else:
self.dump_singlefile()
def flush(self):
"""If :attr:`on_flush` is set to ``True``. Will save all data in memory to pickle file(s). If
it's ``False`` will just pass.
"""
if not self.on_flush:
pass
else:
if self.single_file:
self.dump_singlefile()
else:
self.dump_file("{}_user_data".format(self.filename), self.user_data)
self.dump_file("{}_chat_data".format(self.filename), self.chat_data)
self.dump_file("{}_conversations".format(self.filename), self.conversations)

View file

@ -55,6 +55,8 @@ class Updater(object):
dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that handles the updates and dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that handles the updates and
dispatches them to the handlers. dispatches them to the handlers.
running (:obj:`bool`): Indicates if the updater is running. running (:obj:`bool`): Indicates if the updater is running.
persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to
store data that should be persistent over restarts.
Args: Args:
token (:obj:`str`, optional): The bot's token given by the @BotFather. token (:obj:`str`, optional): The bot's token given by the @BotFather.
@ -73,6 +75,8 @@ class Updater(object):
`telegram.utils.request.Request` object (ignored if `bot` argument is used). The `telegram.utils.request.Request` object (ignored if `bot` argument is used). The
request_kwargs are very useful for the advanced users who would like to control the request_kwargs are very useful for the advanced users who would like to control the
default timeouts and/or control the proxy used for http communication. default timeouts and/or control the proxy used for http communication.
persistence (:class:`telegram.ext.BasePersistence`, optional): The persistence class to
store data that should be persistent over restarts.
Note: Note:
You must supply either a :attr:`bot` or a :attr:`token` argument. You must supply either a :attr:`bot` or a :attr:`token` argument.
@ -92,7 +96,8 @@ class Updater(object):
private_key=None, private_key=None,
private_key_password=None, private_key_password=None,
user_sig_handler=None, user_sig_handler=None,
request_kwargs=None): request_kwargs=None,
persistence=None):
if (token is None) and (bot is None): if (token is None) and (bot is None):
raise ValueError('`token` or `bot` must be passed') raise ValueError('`token` or `bot` must be passed')
@ -129,12 +134,14 @@ class Updater(object):
self.update_queue = Queue() self.update_queue = Queue()
self.job_queue = JobQueue(self.bot) self.job_queue = JobQueue(self.bot)
self.__exception_event = Event() self.__exception_event = Event()
self.persistence = persistence
self.dispatcher = Dispatcher( self.dispatcher = Dispatcher(
self.bot, self.bot,
self.update_queue, self.update_queue,
job_queue=self.job_queue, job_queue=self.job_queue,
workers=workers, workers=workers,
exception_event=self.__exception_event) exception_event=self.__exception_event,
persistence=self.persistence)
self.last_update_id = 0 self.last_update_id = 0
self.running = False self.running = False
self.is_idle = False self.is_idle = False
@ -485,6 +492,8 @@ class Updater(object):
if self.running: if self.running:
self.logger.info('Received signal {} ({}), stopping...'.format( self.logger.info('Received signal {} ({}), stopping...'.format(
signum, get_signal_name(signum))) signum, get_signal_name(signum)))
if self.persistence:
self.persistence.flush()
self.stop() self.stop()
if self.user_sig_handler: if self.user_sig_handler:
self.user_sig_handler(signum, frame) self.user_sig_handler(signum, frame)

View file

@ -17,6 +17,12 @@
# You should have received a copy of the GNU Lesser Public License # You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains helper functions.""" """This module contains helper functions."""
from collections import defaultdict
try:
import ujson as json
except ImportError:
import json
from html import escape from html import escape
import re import re
@ -139,3 +145,65 @@ def effective_message_type(entity):
return i return i
return None return None
def enocde_conversations_to_json(conversations):
"""Helper method to encode a conversations dict (that uses tuples as keys) to a
JSON-serializable way. Use :attr:`_decode_conversations_from_json` to decode.
Args:
conversations (:obj:`dict`): The conversations dict to transofrm to JSON.
Returns:
:obj:`str`: The JSON-serialized conversations dict
"""
tmp = {}
for handler, states in conversations.items():
tmp[handler] = {}
for key, state in states.items():
tmp[handler][json.dumps(key)] = state
return json.dumps(tmp)
def decode_conversations_from_json(json_string):
"""Helper method to decode a conversations dict (that uses tuples as keys) from a
JSON-string created with :attr:`_encode_conversations_to_json`.
Args:
json_string (:obj:`str`): The conversations dict as JSON string.
Returns:
:obj:`dict`: The conversations dict after decoding
"""
tmp = json.loads(json_string)
conversations = {}
for handler, states in tmp.items():
conversations[handler] = {}
for key, state in states.items():
conversations[handler][tuple(json.loads(key))] = state
return conversations
def decode_user_chat_data_from_json(data):
"""Helper method to decode chat or user data (that uses ints as keys) from a
JSON-string.
Args:
data (:obj:`str`): The user/chat_data dict as JSON string.
Returns:
:obj:`dict`: The user/chat_data defaultdict after decoding
"""
tmp = defaultdict(dict)
decoded_data = json.loads(data)
for user, data in decoded_data.items():
user = int(user)
tmp[user] = {}
for key, value in data.items():
try:
key = int(key)
except ValueError:
pass
tmp[user][key] = value
return tmp

View file

@ -96,6 +96,12 @@ class TestConversationHandler(object):
ConversationHandler(self.entry_points, self.states, self.fallbacks, ConversationHandler(self.entry_points, self.states, self.fallbacks,
per_chat=False, per_user=False, per_message=False) per_chat=False, per_user=False, per_message=False)
def test_name_and_persistent(self, dp):
with pytest.raises(ValueError, match="when handler is unnamed"):
dp.add_handler(ConversationHandler([], {}, [], persistent=True))
c = ConversationHandler([], {}, [], name="handler", persistent=True)
assert c.name == "handler"
def test_conversation_handler(self, dp, bot, user1, user2): def test_conversation_handler(self, dp, bot, user1, user2):
handler = ConversationHandler(entry_points=self.entry_points, states=self.states, handler = ConversationHandler(entry_points=self.entry_points, states=self.states,
fallbacks=self.fallbacks) fallbacks=self.fallbacks)

View file

@ -82,6 +82,16 @@ class TestDispatcher(object):
sleep(.1) sleep(.1)
assert self.received is None assert self.received is None
def test_construction_with_bad_persistence(self, caplog, bot):
class my_per:
def __init__(self):
self.store_user_data = False
self.store_chat_data = False
with pytest.raises(TypeError,
match='persistence should be based on telegram.ext.BasePersistence'):
Dispatcher(bot, None, persistence=my_per())
def test_error_handler_that_raises_errors(self, dp): def test_error_handler_that_raises_errors(self, dp):
""" """
Make sure that errors raised in error handlers don't break the main loop of the dispatcher Make sure that errors raised in error handlers don't break the main loop of the dispatcher

751
tests/test_persistence.py Normal file
View file

@ -0,0 +1,751 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2018
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
from telegram.utils.helpers import enocde_conversations_to_json
try:
import ujson as json
except ImportError:
import json
import logging
import os
import pickle
from collections import defaultdict
import pytest
from telegram import Update, Message, User, Chat
from telegram.ext import BasePersistence, Updater, ConversationHandler, MessageHandler, Filters, \
PicklePersistence, CommandHandler, DictPersistence
@pytest.fixture(scope="function")
def base_persistence():
return BasePersistence(store_chat_data=True, store_user_data=True)
@pytest.fixture(scope="function")
def chat_data():
return defaultdict(dict, {-12345: {'test1': 'test2'}, -67890: {3: 'test4'}})
@pytest.fixture(scope="function")
def user_data():
return defaultdict(dict, {12345: {'test1': 'test2'}, 67890: {3: 'test4'}})
@pytest.fixture(scope='function')
def conversations():
return {'name1': {(123, 123): 3, (456, 654): 4},
'name2': {(123, 321): 1, (890, 890): 2}}
@pytest.fixture(scope="function")
def updater(bot, base_persistence):
base_persistence.store_chat_data = False
base_persistence.store_user_data = False
u = Updater(bot=bot, persistence=base_persistence)
base_persistence.store_chat_data = True
base_persistence.store_user_data = True
return u
class TestBasePersistence(object):
def test_creation(self, base_persistence):
assert base_persistence.store_chat_data
assert base_persistence.store_user_data
with pytest.raises(NotImplementedError):
base_persistence.get_chat_data()
with pytest.raises(NotImplementedError):
base_persistence.get_user_data()
with pytest.raises(NotImplementedError):
base_persistence.get_conversations("test")
with pytest.raises(NotImplementedError):
base_persistence.update_chat_data(None, None)
with pytest.raises(NotImplementedError):
base_persistence.update_user_data(None, None)
with pytest.raises(NotImplementedError):
base_persistence.update_conversation(None, None, None)
def test_implementation(self, updater, base_persistence):
dp = updater.dispatcher
assert dp.persistence == base_persistence
def test_conversationhandler_addition(self, dp, base_persistence):
with pytest.raises(ValueError, match="when handler is unnamed"):
ConversationHandler([], [], [], persistent=True)
with pytest.raises(ValueError, match="if dispatcher has no persistence"):
dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler"))
dp.persistence = base_persistence
with pytest.raises(NotImplementedError):
dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler"))
def test_dispatcher_integration_init(self, bot, base_persistence, chat_data, user_data):
def get_user_data():
return "test"
def get_chat_data():
return "test"
base_persistence.get_user_data = get_user_data
base_persistence.get_chat_data = get_chat_data
with pytest.raises(ValueError, match="user_data must be of type defaultdict"):
u = Updater(bot=bot, persistence=base_persistence)
def get_user_data():
return user_data
base_persistence.get_user_data = get_user_data
with pytest.raises(ValueError, match="chat_data must be of type defaultdict"):
u = Updater(bot=bot, persistence=base_persistence)
def get_chat_data():
return chat_data
base_persistence.get_chat_data = get_chat_data
u = Updater(bot=bot, persistence=base_persistence)
assert u.dispatcher.chat_data == chat_data
assert u.dispatcher.user_data == user_data
u.dispatcher.chat_data[442233]['test5'] = 'test6'
assert u.dispatcher.chat_data[442233]['test5'] == 'test6'
def test_dispatcher_integration_handlers(self, caplog, bot, base_persistence,
chat_data, user_data):
def get_user_data():
return user_data
def get_chat_data():
return chat_data
base_persistence.get_user_data = get_user_data
base_persistence.get_chat_data = get_chat_data
# base_persistence.update_chat_data = lambda x: x
# base_persistence.update_user_data = lambda x: x
updater = Updater(bot=bot, persistence=base_persistence)
dp = updater.dispatcher
def callback_known_user(bot, update, user_data, chat_data):
if not user_data['test1'] == 'test2':
pytest.fail('user_data corrupt')
def callback_known_chat(bot, update, user_data, chat_data):
if not chat_data['test3'] == 'test4':
pytest.fail('chat_data corrupt')
def callback_unknown_user_or_chat(bot, update, user_data, chat_data):
if not user_data == {}:
pytest.fail('user_data corrupt')
if not chat_data == {}:
pytest.fail('chat_data corrupt')
user_data[1] = 'test7'
chat_data[2] = 'test8'
known_user = MessageHandler(Filters.user(user_id=12345), callback_known_user,
pass_chat_data=True, pass_user_data=True)
known_chat = MessageHandler(Filters.chat(chat_id=-67890), callback_known_chat,
pass_chat_data=True, pass_user_data=True)
unknown = MessageHandler(Filters.all, callback_unknown_user_or_chat, pass_chat_data=True,
pass_user_data=True)
dp.add_handler(known_user)
dp.add_handler(known_chat)
dp.add_handler(unknown)
user1 = User(id=12345, first_name='test user', is_bot=False)
user2 = User(id=54321, first_name='test user', is_bot=False)
chat1 = Chat(id=-67890, type='group')
chat2 = Chat(id=-987654, type='group')
m = Message(1, user1, None, chat2)
u = Update(0, m)
with caplog.at_level(logging.ERROR):
dp.process_update(u)
rec = caplog.records[-1]
assert rec.msg == 'Saving user data raised an error'
assert rec.levelname == 'ERROR'
rec = caplog.records[-2]
assert rec.msg == 'Saving chat data raised an error'
assert rec.levelname == 'ERROR'
m.from_user = user2
m.chat = chat1
u = Update(1, m)
dp.process_update(u)
m.chat = chat2
u = Update(2, m)
def save_chat_data(data):
if -987654 not in data:
pytest.fail()
def save_user_data(data):
if 54321 not in data:
pytest.fail()
base_persistence.update_chat_data = save_chat_data
base_persistence.update_user_data = save_user_data
dp.process_update(u)
assert dp.user_data[54321][1] == 'test7'
assert dp.chat_data[-987654][2] == 'test8'
@pytest.fixture(scope='function')
def pickle_persistence():
return PicklePersistence(filename='pickletest',
store_user_data=True,
store_chat_data=True,
singe_file=False,
on_flush=False)
@pytest.fixture(scope='function')
def bad_pickle_files():
for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_conversations',
'pickletest']:
with open(name, 'w') as f:
f.write('(())')
yield True
for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_conversations',
'pickletest']:
os.remove(name)
@pytest.fixture(scope='function')
def good_pickle_files(user_data, chat_data, conversations):
all = {'user_data': user_data, 'chat_data': chat_data, 'conversations': conversations}
with open('pickletest_user_data', 'wb') as f:
pickle.dump(user_data, f)
with open('pickletest_chat_data', 'wb') as f:
pickle.dump(chat_data, f)
with open('pickletest_conversations', 'wb') as f:
pickle.dump(conversations, f)
with open('pickletest', 'wb') as f:
pickle.dump(all, f)
yield True
for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_conversations',
'pickletest']:
os.remove(name)
@pytest.fixture(scope='function')
def update(bot):
user = User(id=321, first_name='test_user', is_bot=False)
chat = Chat(id=123, type='group')
message = Message(1, user, None, chat, text="Hi there", bot=bot)
return Update(0, message=message)
class TestPickelPersistence(object):
def test_no_files_present_multi_file(self, pickle_persistence):
assert pickle_persistence.get_user_data() == defaultdict(dict)
assert pickle_persistence.get_user_data() == defaultdict(dict)
assert pickle_persistence.get_chat_data() == defaultdict(dict)
assert pickle_persistence.get_chat_data() == defaultdict(dict)
assert pickle_persistence.get_conversations('noname') == {}
assert pickle_persistence.get_conversations('noname') == {}
def test_no_files_present_single_file(self, pickle_persistence):
pickle_persistence.single_file = True
assert pickle_persistence.get_user_data() == defaultdict(dict)
assert pickle_persistence.get_chat_data() == defaultdict(dict)
assert pickle_persistence.get_conversations('noname') == {}
def test_with_bad_multi_file(self, pickle_persistence, bad_pickle_files):
with pytest.raises(TypeError, match='pickletest_user_data'):
pickle_persistence.get_user_data()
with pytest.raises(TypeError, match='pickletest_chat_data'):
pickle_persistence.get_chat_data()
with pytest.raises(TypeError, match='pickletest_conversations'):
pickle_persistence.get_conversations('name')
def test_with_bad_single_file(self, pickle_persistence, bad_pickle_files):
pickle_persistence.single_file = True
with pytest.raises(TypeError, match='pickletest'):
pickle_persistence.get_user_data()
with pytest.raises(TypeError, match='pickletest'):
pickle_persistence.get_chat_data()
with pytest.raises(TypeError, match='pickletest'):
pickle_persistence.get_conversations('name')
def test_with_good_multi_file(self, pickle_persistence, good_pickle_files):
user_data = pickle_persistence.get_user_data()
assert isinstance(user_data, defaultdict)
assert user_data[12345]['test1'] == 'test2'
assert user_data[67890][3] == 'test4'
assert user_data[54321] == {}
chat_data = pickle_persistence.get_chat_data()
assert isinstance(chat_data, defaultdict)
assert chat_data[-12345]['test1'] == 'test2'
assert chat_data[-67890][3] == 'test4'
assert chat_data[-54321] == {}
conversation1 = pickle_persistence.get_conversations('name1')
assert isinstance(conversation1, dict)
assert conversation1[(123, 123)] == 3
assert conversation1[(456, 654)] == 4
with pytest.raises(KeyError):
conversation1[(890, 890)]
conversation2 = pickle_persistence.get_conversations('name2')
assert isinstance(conversation1, dict)
assert conversation2[(123, 321)] == 1
assert conversation2[(890, 890)] == 2
with pytest.raises(KeyError):
conversation2[(123, 123)]
def test_with_good_single_file(self, pickle_persistence, good_pickle_files):
pickle_persistence.single_file = True
user_data = pickle_persistence.get_user_data()
assert isinstance(user_data, defaultdict)
assert user_data[12345]['test1'] == 'test2'
assert user_data[67890][3] == 'test4'
assert user_data[54321] == {}
chat_data = pickle_persistence.get_chat_data()
assert isinstance(chat_data, defaultdict)
assert chat_data[-12345]['test1'] == 'test2'
assert chat_data[-67890][3] == 'test4'
assert chat_data[-54321] == {}
conversation1 = pickle_persistence.get_conversations('name1')
assert isinstance(conversation1, dict)
assert conversation1[(123, 123)] == 3
assert conversation1[(456, 654)] == 4
with pytest.raises(KeyError):
conversation1[(890, 890)]
conversation2 = pickle_persistence.get_conversations('name2')
assert isinstance(conversation1, dict)
assert conversation2[(123, 321)] == 1
assert conversation2[(890, 890)] == 2
with pytest.raises(KeyError):
conversation2[(123, 123)]
def test_updating_multi_file(self, pickle_persistence, good_pickle_files):
user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321])
assert pickle_persistence.user_data == user_data
with open('pickletest_user_data', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f))
assert user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321])
assert pickle_persistence.chat_data == chat_data
with open('pickletest_chat_data', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f))
assert chat_data_test == chat_data
conversation1 = pickle_persistence.get_conversations('name1')
conversation1[(123, 123)] = 5
assert not pickle_persistence.conversations['name1'] == conversation1
pickle_persistence.update_conversation('name1', (123, 123), 5)
assert pickle_persistence.conversations['name1'] == conversation1
with open('pickletest_conversations', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f))
assert conversations_test['name1'] == conversation1
def test_updating_single_file(self, pickle_persistence, good_pickle_files):
pickle_persistence.single_file = True
user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321])
assert pickle_persistence.user_data == user_data
with open('pickletest', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)['user_data'])
assert user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321])
assert pickle_persistence.chat_data == chat_data
with open('pickletest', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)['chat_data'])
assert chat_data_test == chat_data
conversation1 = pickle_persistence.get_conversations('name1')
conversation1[(123, 123)] = 5
assert not pickle_persistence.conversations['name1'] == conversation1
pickle_persistence.update_conversation('name1', (123, 123), 5)
assert pickle_persistence.conversations['name1'] == conversation1
with open('pickletest', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f)['conversations'])
assert conversations_test['name1'] == conversation1
def test_save_on_flush_multi_files(self, pickle_persistence, good_pickle_files):
# Should run without error
pickle_persistence.flush()
pickle_persistence.on_flush = True
user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321])
assert pickle_persistence.user_data == user_data
with open('pickletest_user_data', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f))
assert not user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321])
assert pickle_persistence.chat_data == chat_data
with open('pickletest_chat_data', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f))
assert not chat_data_test == chat_data
conversation1 = pickle_persistence.get_conversations('name1')
conversation1[(123, 123)] = 5
assert not pickle_persistence.conversations['name1'] == conversation1
pickle_persistence.update_conversation('name1', (123, 123), 5)
assert pickle_persistence.conversations['name1'] == conversation1
with open('pickletest_conversations', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f))
assert not conversations_test['name1'] == conversation1
pickle_persistence.flush()
with open('pickletest_user_data', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f))
assert user_data_test == user_data
with open('pickletest_chat_data', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f))
assert chat_data_test == chat_data
with open('pickletest_conversations', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f))
assert conversations_test['name1'] == conversation1
def test_save_on_flush_single_files(self, pickle_persistence, good_pickle_files):
# Should run without error
pickle_persistence.flush()
pickle_persistence.on_flush = True
pickle_persistence.single_file = True
user_data = pickle_persistence.get_user_data()
user_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.user_data == user_data
pickle_persistence.update_user_data(54321, user_data[54321])
assert pickle_persistence.user_data == user_data
with open('pickletest', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)['user_data'])
assert not user_data_test == user_data
chat_data = pickle_persistence.get_chat_data()
chat_data[54321]['test9'] = 'test 10'
assert not pickle_persistence.chat_data == chat_data
pickle_persistence.update_chat_data(54321, chat_data[54321])
assert pickle_persistence.chat_data == chat_data
with open('pickletest', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)['chat_data'])
assert not chat_data_test == chat_data
conversation1 = pickle_persistence.get_conversations('name1')
conversation1[(123, 123)] = 5
assert not pickle_persistence.conversations['name1'] == conversation1
pickle_persistence.update_conversation('name1', (123, 123), 5)
assert pickle_persistence.conversations['name1'] == conversation1
with open('pickletest', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f)['conversations'])
assert not conversations_test['name1'] == conversation1
pickle_persistence.flush()
with open('pickletest', 'rb') as f:
user_data_test = defaultdict(dict, pickle.load(f)['user_data'])
assert user_data_test == user_data
with open('pickletest', 'rb') as f:
chat_data_test = defaultdict(dict, pickle.load(f)['chat_data'])
assert chat_data_test == chat_data
with open('pickletest', 'rb') as f:
conversations_test = defaultdict(dict, pickle.load(f)['conversations'])
assert conversations_test['name1'] == conversation1
def test_with_handler(self, bot, update, pickle_persistence, good_pickle_files):
u = Updater(bot=bot, persistence=pickle_persistence)
dp = u.dispatcher
def first(bot, update, user_data, chat_data):
if not user_data == {}:
pytest.fail()
if not chat_data == {}:
pytest.fail()
user_data['test1'] = 'test2'
chat_data['test3'] = 'test4'
def second(bot, update, user_data, chat_data):
if not user_data['test1'] == 'test2':
pytest.fail()
if not chat_data['test3'] == 'test4':
pytest.fail()
h1 = MessageHandler(None, first, pass_user_data=True, pass_chat_data=True)
h2 = MessageHandler(None, second, pass_user_data=True, pass_chat_data=True)
dp.add_handler(h1)
dp.process_update(update)
del (dp)
del (u)
del (pickle_persistence)
pickle_persistence_2 = PicklePersistence(filename='pickletest',
store_user_data=True,
store_chat_data=True,
singe_file=False,
on_flush=False)
u = Updater(bot=bot, persistence=pickle_persistence_2)
dp = u.dispatcher
dp.add_handler(h2)
dp.process_update(update)
def test_with_conversationHandler(self, dp, update, good_pickle_files, pickle_persistence):
dp.persistence = pickle_persistence
NEXT, NEXT2 = range(2)
def start(bot, update):
return NEXT
start = CommandHandler('start', start)
def next(bot, update):
return NEXT2
next = MessageHandler(None, next)
def next2(bot, update):
return ConversationHandler.END
next2 = MessageHandler(None, next2)
ch = ConversationHandler([start], {NEXT: [next], NEXT2: [next2]}, [], name='name2',
persistent=True)
dp.add_handler(ch)
assert ch.conversations[ch._get_key(update)] == 1
dp.process_update(update)
assert ch._get_key(update) not in ch.conversations
update.message.text = '/start'
dp.process_update(update)
assert ch.conversations[ch._get_key(update)] == 0
assert ch.conversations == pickle_persistence.conversations['name2']
@classmethod
def teardown_class(cls):
try:
for name in ['pickletest_user_data', 'pickletest_chat_data',
'pickletest_conversations',
'pickletest']:
os.remove(name)
except Exception:
pass
@pytest.fixture(scope='function')
def user_data_json(user_data):
return json.dumps(user_data)
@pytest.fixture(scope='function')
def chat_data_json(chat_data):
return json.dumps(chat_data)
@pytest.fixture(scope='function')
def conversations_json(conversations):
return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2":
{"[123, 321]": 1, "[890, 890]": 2}}"""
class TestDictPersistence(object):
def test_no_json_given(self):
dict_persistence = DictPersistence()
assert dict_persistence.get_user_data() == defaultdict(dict)
assert dict_persistence.get_chat_data() == defaultdict(dict)
assert dict_persistence.get_conversations('noname') == {}
def test_bad_json_string_given(self):
bad_user_data = 'thisisnojson99900()))('
bad_chat_data = 'thisisnojson99900()))('
bad_conversations = 'thisisnojson99900()))('
with pytest.raises(TypeError, match='user_data'):
DictPersistence(user_data_json=bad_user_data)
with pytest.raises(TypeError, match='chat_data'):
DictPersistence(chat_data_json=bad_chat_data)
with pytest.raises(TypeError, match='conversations'):
DictPersistence(conversations_json=bad_conversations)
def test_invalid_json_string_given(self, pickle_persistence, bad_pickle_files):
bad_user_data = '["this", "is", "json"]'
bad_chat_data = '["this", "is", "json"]'
bad_conversations = '["this", "is", "json"]'
with pytest.raises(TypeError, match='user_data'):
DictPersistence(user_data_json=bad_user_data)
with pytest.raises(TypeError, match='chat_data'):
DictPersistence(chat_data_json=bad_chat_data)
with pytest.raises(TypeError, match='conversations'):
DictPersistence(conversations_json=bad_conversations)
def test_good_json_input(self, user_data_json, chat_data_json, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
user_data = dict_persistence.get_user_data()
assert isinstance(user_data, defaultdict)
assert user_data[12345]['test1'] == 'test2'
assert user_data[67890][3] == 'test4'
assert user_data[54321] == {}
chat_data = dict_persistence.get_chat_data()
assert isinstance(chat_data, defaultdict)
assert chat_data[-12345]['test1'] == 'test2'
assert chat_data[-67890][3] == 'test4'
assert chat_data[-54321] == {}
conversation1 = dict_persistence.get_conversations('name1')
assert isinstance(conversation1, dict)
assert conversation1[(123, 123)] == 3
assert conversation1[(456, 654)] == 4
with pytest.raises(KeyError):
conversation1[(890, 890)]
conversation2 = dict_persistence.get_conversations('name2')
assert isinstance(conversation1, dict)
assert conversation2[(123, 321)] == 1
assert conversation2[(890, 890)] == 2
with pytest.raises(KeyError):
conversation2[(123, 123)]
def test_dict_outputs(self, user_data, user_data_json, chat_data, chat_data_json,
conversations, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
assert dict_persistence.user_data == user_data
assert dict_persistence.chat_data == chat_data
assert dict_persistence.conversations == conversations
def test_json_outputs(self, user_data_json, chat_data_json, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
assert dict_persistence.user_data_json == user_data_json
assert dict_persistence.chat_data_json == chat_data_json
assert dict_persistence.conversations_json == conversations_json
def test_json_changes(self, user_data, user_data_json, chat_data, chat_data_json,
conversations, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
user_data_two = user_data.copy()
user_data_two.update({4: {5: 6}})
dict_persistence.update_user_data(4, {5: 6})
assert dict_persistence.user_data == user_data_two
assert dict_persistence.user_data_json != user_data_json
assert dict_persistence.user_data_json == json.dumps(user_data_two)
chat_data_two = chat_data.copy()
chat_data_two.update({7: {8: 9}})
dict_persistence.update_chat_data(7, {8: 9})
assert dict_persistence.chat_data == chat_data_two
assert dict_persistence.chat_data_json != chat_data_json
assert dict_persistence.chat_data_json == json.dumps(chat_data_two)
conversations_two = conversations.copy()
conversations_two.update({'name3': {(1, 2): 3}})
dict_persistence.update_conversation('name3', (1, 2), 3)
assert dict_persistence.conversations == conversations_two
assert dict_persistence.conversations_json != conversations_json
assert dict_persistence.conversations_json == enocde_conversations_to_json(
conversations_two)
def test_with_handler(self, bot, update):
dict_persistence = DictPersistence()
u = Updater(bot=bot, persistence=dict_persistence)
dp = u.dispatcher
def first(bot, update, user_data, chat_data):
if not user_data == {}:
pytest.fail()
if not chat_data == {}:
pytest.fail()
user_data['test1'] = 'test2'
chat_data[3] = 'test4'
def second(bot, update, user_data, chat_data):
if not user_data['test1'] == 'test2':
pytest.fail()
if not chat_data[3] == 'test4':
pytest.fail()
h1 = MessageHandler(None, first, pass_user_data=True, pass_chat_data=True)
h2 = MessageHandler(None, second, pass_user_data=True, pass_chat_data=True)
dp.add_handler(h1)
dp.process_update(update)
del (dp)
del (u)
user_data = dict_persistence.user_data_json
chat_data = dict_persistence.chat_data_json
del (dict_persistence)
dict_persistence_2 = DictPersistence(user_data_json=user_data,
chat_data_json=chat_data)
u = Updater(bot=bot, persistence=dict_persistence_2)
dp = u.dispatcher
dp.add_handler(h2)
dp.process_update(update)
def test_with_conversationHandler(self, dp, update, conversations_json):
dict_persistence = DictPersistence(conversations_json=conversations_json)
dp.persistence = dict_persistence
NEXT, NEXT2 = range(2)
def start(bot, update):
return NEXT
start = CommandHandler('start', start)
def next(bot, update):
return NEXT2
next = MessageHandler(None, next)
def next2(bot, update):
return ConversationHandler.END
next2 = MessageHandler(None, next2)
ch = ConversationHandler([start], {NEXT: [next], NEXT2: [next2]}, [], name='name2',
persistent=True)
dp.add_handler(ch)
assert ch.conversations[ch._get_key(update)] == 1
dp.process_update(update)
assert ch._get_key(update) not in ch.conversations
update.message.text = '/start'
dp.process_update(update)
assert ch.conversations[ch._get_key(update)] == 0
assert ch.conversations == dict_persistence.conversations['name2']