diff --git a/docs/source/telegram.ext.basepersistence.rst b/docs/source/telegram.ext.basepersistence.rst new file mode 100644 index 000000000..013c052d8 --- /dev/null +++ b/docs/source/telegram.ext.basepersistence.rst @@ -0,0 +1,6 @@ +telegram.ext.BasePersistence +============================ + +.. autoclass:: telegram.ext.BasePersistence + :members: + :show-inheritance: diff --git a/docs/source/telegram.ext.dictpersistence.rst b/docs/source/telegram.ext.dictpersistence.rst new file mode 100644 index 000000000..08b78de37 --- /dev/null +++ b/docs/source/telegram.ext.dictpersistence.rst @@ -0,0 +1,6 @@ +telegram.ext.DictPersistence +============================ + +.. autoclass:: telegram.ext.DictPersistence + :members: + :show-inheritance: diff --git a/docs/source/telegram.ext.picklepersistence.rst b/docs/source/telegram.ext.picklepersistence.rst new file mode 100644 index 000000000..f8691b181 --- /dev/null +++ b/docs/source/telegram.ext.picklepersistence.rst @@ -0,0 +1,6 @@ +telegram.ext.PicklePersistence +============================== + +.. autoclass:: telegram.ext.PicklePersistence + :members: + :show-inheritance: diff --git a/docs/source/telegram.ext.rst b/docs/source/telegram.ext.rst index d54e05425..9b469a200 100644 --- a/docs/source/telegram.ext.rst +++ b/docs/source/telegram.ext.rst @@ -29,3 +29,12 @@ Handlers telegram.ext.stringcommandhandler telegram.ext.stringregexhandler telegram.ext.typehandler + +Persistence +----------- + +.. toctree:: + + telegram.ext.basepersistence + telegram.ext.picklepersistence + telegram.ext.dictpersistence \ No newline at end of file diff --git a/examples/README.md b/examples/README.md index 81c51c7f7..de058567e 100644 --- a/examples/README.md +++ b/examples/README.md @@ -25,5 +25,8 @@ A basic example of an [inline bot](https://core.telegram.org/bots/inline). Don't ### [`paymentbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/paymentbot.py) A basic example of a bot that can accept payments. Don't forget to enable and configure payments with [@BotFather](https://telegram.me/BotFather). +### [`persistentconversationbot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/persistentconversationbot.py) +A basic example of a bot store conversation state and user_data over multiple restarts. + ## Pure API The [`echobot.py`](https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples/echobot.py) example uses only the pure, "bare-metal" API wrapper. diff --git a/examples/persistentconversationbot.py b/examples/persistentconversationbot.py new file mode 100644 index 000000000..5d9acc2f0 --- /dev/null +++ b/examples/persistentconversationbot.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Simple Bot to reply to Telegram messages +# This program is dedicated to the public domain under the CC0 license. +""" +This Bot uses the Updater class to handle the bot. + +First, a few callback functions are defined. Then, those functions are passed to +the Dispatcher and registered at their respective places. +Then, the bot is started and runs until we press Ctrl-C on the command line. + +Usage: +Example of a bot-user conversation using ConversationHandler. +Send /start to initiate the conversation. +Press Ctrl-C on the command line or send a signal to the process to stop the +bot. +""" + +from telegram import ReplyKeyboardMarkup +from telegram.ext import (Updater, CommandHandler, MessageHandler, Filters, RegexHandler, + ConversationHandler, PicklePersistence) + +import logging + +# Enable logging +logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + level=logging.INFO) + +logger = logging.getLogger(__name__) + +CHOOSING, TYPING_REPLY, TYPING_CHOICE = range(3) + +reply_keyboard = [['Age', 'Favourite colour'], + ['Number of siblings', 'Something else...'], + ['Done']] +markup = ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True) + + +def facts_to_str(user_data): + facts = list() + + for key, value in user_data.items(): + facts.append('{} - {}'.format(key, value)) + + return "\n".join(facts).join(['\n', '\n']) + + +def start(bot, update, user_data): + reply_text = "Hi! My name is Doctor Botter." + if user_data: + reply_text += " You already told me your {}. Why don't you tell me something more " \ + "about yourself? Or change enything I " \ + "already know.".format(", ".join(user_data.keys())) + else: + reply_text += " I will hold a more complex conversation with you. Why don't you tell me " \ + "something about yourself?" + update.message.reply_text(reply_text, reply_markup=markup) + + return CHOOSING + + +def regular_choice(bot, update, user_data): + text = update.message.text + user_data['choice'] = text + if user_data.get(text): + reply_text = 'Your {}, I already know the following ' \ + 'about that: {}'.format(text.lower(), user_data[text.lower()]) + else: + reply_text = 'Your {}? Yes, I would love to hear about that!'.format(text.lower()) + update.message.reply_text(reply_text) + + return TYPING_REPLY + + +def custom_choice(bot, update): + update.message.reply_text('Alright, please send me the category first, ' + 'for example "Most impressive skill"') + + return TYPING_CHOICE + + +def received_information(bot, update, user_data): + text = update.message.text + category = user_data['choice'] + user_data[category] = text.lower() + del user_data['choice'] + + update.message.reply_text("Neat! Just so you know, this is what you already told me:" + "{}" + "You can tell me more, or change your opinion on " + "something.".format(facts_to_str(user_data)), reply_markup=markup) + + return CHOOSING + + +def show_data(bot, update, user_data): + update.message.reply_text("This is what you already told me:" + "{}".format(facts_to_str(user_data))) + + +def done(bot, update, user_data): + if 'choice' in user_data: + del user_data['choice'] + + update.message.reply_text("I learned these facts about you:" + "{}" + "Until next time!".format(facts_to_str(user_data))) + return ConversationHandler.END + + +def error(bot, update, error): + """Log Errors caused by Updates.""" + logger.warning('Update "%s" caused error "%s"', update, error) + + +def main(): + # Create the Updater and pass it your bot's token. + pp = PicklePersistence(filename='conversationbot') + updater = Updater("TOKEN", persistence=pp) + + # Get the dispatcher to register handlers + dp = updater.dispatcher + + # Add conversation handler with the states CHOOSING, TYPING_CHOICE and TYPING_REPLY + conv_handler = ConversationHandler( + entry_points=[CommandHandler('start', start, pass_user_data=True)], + + states={ + CHOOSING: [RegexHandler('^(Age|Favourite colour|Number of siblings)$', + regular_choice, + pass_user_data=True), + RegexHandler('^Something else...$', + custom_choice), + ], + + TYPING_CHOICE: [MessageHandler(Filters.text, + regular_choice, + pass_user_data=True), + ], + + TYPING_REPLY: [MessageHandler(Filters.text, + received_information, + pass_user_data=True), + ], + }, + + fallbacks=[RegexHandler('^Done$', done, pass_user_data=True)], + name="my_conversation", + persistent=True + ) + + dp.add_handler(conv_handler) + + show_data_handler = CommandHandler('show_data', show_data, pass_user_data=True) + dp.add_handler(show_data_handler) + # log all errors + dp.add_error_handler(error) + + # Start the Bot + updater.start_polling() + + # Run the bot until you press Ctrl-C or the process receives SIGINT, + # SIGTERM or SIGABRT. This should be used most of the time, since + # start_polling() is non-blocking and will stop the bot gracefully. + updater.idle() + + +if __name__ == '__main__': + main() diff --git a/telegram/ext/__init__.py b/telegram/ext/__init__.py index 1c69d0b83..8b35726e9 100644 --- a/telegram/ext/__init__.py +++ b/telegram/ext/__init__.py @@ -18,6 +18,9 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """Extensions over the Telegram Bot API to facilitate bot making""" +from .basepersistence import BasePersistence +from .picklepersistence import PicklePersistence +from .dictpersistence import DictPersistence from .dispatcher import Dispatcher, DispatcherHandlerStop, run_async from .jobqueue import JobQueue, Job from .updater import Updater @@ -43,4 +46,5 @@ __all__ = ('Dispatcher', 'JobQueue', 'Job', 'Updater', 'CallbackQueryHandler', 'MessageHandler', 'BaseFilter', 'Filters', 'RegexHandler', 'StringCommandHandler', 'StringRegexHandler', 'TypeHandler', 'ConversationHandler', 'PreCheckoutQueryHandler', 'ShippingQueryHandler', 'MessageQueue', 'DelayQueue', - 'DispatcherHandlerStop', 'run_async') + 'DispatcherHandlerStop', 'run_async', 'BasePersistence', 'PicklePersistence', + 'DictPersistence') diff --git a/telegram/ext/basepersistence.py b/telegram/ext/basepersistence.py new file mode 100644 index 000000000..f62491ee9 --- /dev/null +++ b/telegram/ext/basepersistence.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2018 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains the BasePersistence class.""" + + +class BasePersistence(object): + """Interface class for adding persistence to your bot. + Subclass this object for different implementations of a persistent bot. + + All relevant methods must be overwritten. This means: + + * If :attr:`store_chat_data` is ``True`` you must overwrite :meth:`get_chat_data` and + :meth:`update_chat_data`. + * If :attr:`store_user_data` is ``True`` you must overwrite :meth:`get_user_data` and + :meth:`update_user_data`. + * If you want to store conversation data with :class:`telegram.ext.ConversationHandler`, you + must overwrite :meth:`get_conversations` and :meth:`update_conversation`. + * :meth:`flush` will be called when the bot is shutdown. + + Attributes: + store_user_data (:obj:`bool`): Optional, Whether user_data should be saved by this + persistence class. + store_chat_data (:obj:`bool`): Optional. Whether chat_data should be saved by this + persistence class. + + Args: + store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this + persistence class. Default is ``True``. + store_chat_data (:obj:`bool`, optional): Whether chat_data should be saved by this + persistence class. Default is ``True`` . + """ + + def __init__(self, store_user_data=True, store_chat_data=True): + self.store_user_data = store_user_data + self.store_chat_data = store_chat_data + + def get_user_data(self): + """"Will be called by :class:`telegram.ext.Dispatcher` upon creation with a + persistence object. It should return the user_data if stored, or an empty + ``defaultdict(dict)``. + + Returns: + :obj:`defaultdict`: The restored user data. + """ + raise NotImplementedError + + def get_chat_data(self): + """"Will be called by :class:`telegram.ext.Dispatcher` upon creation with a + persistence object. It should return the chat_data if stored, or an empty + ``defaultdict(dict)``. + + Returns: + :obj:`defaultdict`: The restored chat data. + """ + raise NotImplementedError + + def get_conversations(self, name): + """"Will be called by :class:`telegram.ext.Dispatcher` when a + :class:`telegram.ext.ConversationHandler` is added if + :attr:`telegram.ext.ConversationHandler.persistent` is ``True``. + It should return the conversations for the handler with `name` or an empty ``dict`` + + Args: + name (:obj:`str`): The handlers name. + + Returns: + :obj:`dict`: The restored conversations for the handler. + """ + raise NotImplementedError + + def update_conversation(self, name, key, new_state): + """Will be called when a :attr:`telegram.ext.ConversationHandler.update_state` + is called. this allows the storeage of the new state in the persistence. + + Args: + name (:obj:`str`): The handlers name. + key (:obj:`tuple`): The key the state is changed for. + new_state (:obj:`tuple` | :obj:`any`): The new state for the given key. + """ + raise NotImplementedError + + def update_user_data(self, user_id, data): + """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has + handled an update. + + Args: + user_id (:obj:`int`): The user the data might have been changed for. + data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id]. + """ + raise NotImplementedError + + def update_chat_data(self, chat_id, data): + """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has + handled an update. + + Args: + chat_id (:obj:`int`): The chat the data might have been changed for. + data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[user_id]. + """ + raise NotImplementedError + + def flush(self): + """Will be called by :class:`telegram.ext.Updater` upon receiving a stop signal. Gives the + persistence a chance to finish up saving or close a database connection gracefully. If this + is not of any importance just pass will be sufficient. + """ + pass diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index da55bf7e3..b1cf8857b 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -79,6 +79,10 @@ class ConversationHandler(Handler): conversation_timeout (:obj:`float`|:obj:`datetime.timedelta`): Optional. When this handler is inactive more than this timeout (in seconds), it will be automatically ended. If this value is 0 (default), there will be no timeout. + name (:obj:`str`): Optional. The name for this conversationhandler. Required for + persistence + persistent (:obj:`bool`): Optional. If the conversations dict for this handler should be + saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater` Args: entry_points (List[:class:`telegram.ext.Handler`]): A list of ``Handler`` objects that can @@ -113,6 +117,10 @@ class ConversationHandler(Handler): conversation_timeout (:obj:`float`|:obj:`datetime.timedelta`, optional): When this handler is inactive more than this timeout (in seconds), it will be automatically ended. If this value is 0 or None (default), there will be no timeout. + name (:obj:`str`, optional): The name for this conversationhandler. Required for + persistence + persistent (:obj:`bool`, optional): If the conversations dict for this handler should be + saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater` Raises: ValueError @@ -131,7 +139,9 @@ class ConversationHandler(Handler): per_chat=True, per_user=True, per_message=False, - conversation_timeout=None): + conversation_timeout=None, + name=None, + persistent=False): self.entry_points = entry_points self.states = states @@ -144,6 +154,13 @@ class ConversationHandler(Handler): self.per_chat = per_chat self.per_message = per_message self.conversation_timeout = conversation_timeout + self.name = name + if persistent and not self.name: + raise ValueError("Conversations can't be persistent when handler is unnamed.") + self.persistent = persistent + self.persistence = None + """:obj:`telegram.ext.BasePersistance`: The persistence used to store conversations. + Set by dispatcher""" self.timeout_jobs = dict() self.conversations = dict() @@ -318,14 +335,21 @@ class ConversationHandler(Handler): if new_state == self.END: if key in self.conversations: del self.conversations[key] + if self.persistent: + self.persistence.update_conversation(self.name, key, None) else: pass elif isinstance(new_state, Promise): self.conversations[key] = (self.conversations.get(key), new_state) + if self.persistent: + self.persistence.update_conversation(self.name, key, + (self.conversations.get(key), new_state)) elif new_state is not None: self.conversations[key] = new_state + if self.persistent: + self.persistence.update_conversation(self.name, key, new_state) def _trigger_timeout(self, bot, job): del self.timeout_jobs[job.context] diff --git a/telegram/ext/dictpersistence.py b/telegram/ext/dictpersistence.py new file mode 100644 index 000000000..28ca47273 --- /dev/null +++ b/telegram/ext/dictpersistence.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2018 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains the DictPersistence class.""" +from telegram.utils.helpers import decode_user_chat_data_from_json,\ + decode_conversations_from_json, enocde_conversations_to_json + +try: + import ujson as json +except ImportError: + import json +from collections import defaultdict +from telegram.ext import BasePersistence + + +class DictPersistence(BasePersistence): + """Using python's dicts and json for making you bot persistent. + + Attributes: + store_user_data (:obj:`bool`): Whether user_data should be saved by this + persistence class. + store_chat_data (:obj:`bool`): Whether chat_data should be saved by this + persistence class. + + Args: + store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this + persistence class. Default is ``True``. + store_chat_data (:obj:`bool`, optional): Whether user_data should be saved by this + persistence class. Default is ``True``. + user_data_json (:obj:`str`, optional): Json string that will be used to reconstruct + user_data on creating this persistence. Default is ``""``. + chat_data_json (:obj:`str`, optional): Json string that will be used to reconstruct + chat_data on creating this persistence. Default is ``""``. + conversations_json (:obj:`str`, optional): Json string that will be used to reconstruct + conversation on creating this persistence. Default is ``""``. + """ + + def __init__(self, store_user_data=True, store_chat_data=True, user_data_json='', + chat_data_json='', conversations_json=''): + self.store_user_data = store_user_data + self.store_chat_data = store_chat_data + self._user_data = None + self._chat_data = None + self._conversations = None + self._user_data_json = None + self._chat_data_json = None + self._conversations_json = None + if user_data_json: + try: + self._user_data = decode_user_chat_data_from_json(user_data_json) + self._user_data_json = user_data_json + except (ValueError, AttributeError): + raise TypeError("Unable to deserialize user_data_json. Not valid JSON") + if chat_data_json: + try: + self._chat_data = decode_user_chat_data_from_json(chat_data_json) + self._chat_data_json = chat_data_json + except (ValueError, AttributeError): + raise TypeError("Unable to deserialize chat_data_json. Not valid JSON") + + if conversations_json: + try: + self._conversations = decode_conversations_from_json(conversations_json) + self._conversations_json = conversations_json + except (ValueError, AttributeError): + raise TypeError("Unable to deserialize conversations_json. Not valid JSON") + + @property + def user_data(self): + """:obj:`dict`: The user_data as a dict""" + return self._user_data + + @property + def user_data_json(self): + """:obj:`str`: The user_data serialized as a JSON-string.""" + if self._user_data_json: + return self._user_data_json + else: + return json.dumps(self.user_data) + + @property + def chat_data(self): + """:obj:`dict`: The chat_data as a dict""" + return self._chat_data + + @property + def chat_data_json(self): + """:obj:`str`: The chat_data serialized as a JSON-string.""" + if self._chat_data_json: + return self._chat_data_json + else: + return json.dumps(self.chat_data) + + @property + def conversations(self): + """:obj:`dict`: The conversations as a dict""" + return self._conversations + + @property + def conversations_json(self): + """:obj:`str`: The conversations serialized as a JSON-string.""" + if self._conversations_json: + return self._conversations_json + else: + return enocde_conversations_to_json(self.conversations) + + def get_user_data(self): + """Returns the user_data created from the ``user_data_json`` or an empty defaultdict. + + Returns: + :obj:`defaultdict`: The restored user data. + """ + if self.user_data: + pass + else: + self._user_data = defaultdict(dict) + return self.user_data.copy() + + def get_chat_data(self): + """Returns the chat_data created from the ``chat_data_json`` or an empty defaultdict. + + Returns: + :obj:`defaultdict`: The restored user data. + """ + if self.chat_data: + pass + else: + self._chat_data = defaultdict(dict) + return self.chat_data.copy() + + def get_conversations(self, name): + """Returns the conversations created from the ``conversations_json`` or an empty + defaultdict. + + Returns: + :obj:`defaultdict`: The restored user data. + """ + if self.conversations: + pass + else: + self._conversations = {} + return self.conversations.get(name, {}).copy() + + def update_conversation(self, name, key, new_state): + """Will update the conversations for the given handler. + + Args: + name (:obj:`str`): The handlers name. + key (:obj:`tuple`): The key the state is changed for. + new_state (:obj:`tuple` | :obj:`any`): The new state for the given key. + """ + if self._conversations.setdefault(name, {}).get(key) == new_state: + return + self._conversations[name][key] = new_state + self._conversations_json = None + + def update_user_data(self, user_id, data): + """Will update the user_data (if changed). + + Args: + user_id (:obj:`int`): The user the data might have been changed for. + data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id]. + """ + if self._user_data.get(user_id) == data: + return + self._user_data[user_id] = data + self._user_data_json = None + + def update_chat_data(self, chat_id, data): + """Will update the chat_data (if changed). + + Args: + chat_id (:obj:`int`): The chat the data might have been changed for. + data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id]. + """ + if self._chat_data.get(chat_id) == data: + return + self._chat_data[chat_id] = data + self._chat_data_json = None diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py index 0eeead9e9..642095381 100644 --- a/telegram/ext/dispatcher.py +++ b/telegram/ext/dispatcher.py @@ -33,6 +33,7 @@ from future.builtins import range from telegram import TelegramError from telegram.ext.handler import Handler from telegram.utils.promise import Promise +from telegram.ext import BasePersistence logging.getLogger(__name__).addHandler(logging.NullHandler()) DEFAULT_GROUP = 0 @@ -48,6 +49,7 @@ def run_async(func): Note: Use this decorator to run handlers asynchronously. """ + @wraps(func) def async_func(*args, **kwargs): return Dispatcher.get_instance().run_async(func, *args, **kwargs) @@ -70,6 +72,10 @@ class Dispatcher(object): instance to pass onto handler callbacks. workers (:obj:`int`): Number of maximum concurrent worker threads for the ``@run_async`` decorator. + user_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the user. + chat_data (:obj:`defaultdict`): A dictionary handlers can use to store data for the chat. + persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to + store data that should be persistent over restarts Args: bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers. @@ -78,6 +84,8 @@ class Dispatcher(object): instance to pass onto handler callbacks. workers (:obj:`int`, optional): Number of maximum concurrent worker threads for the ``@run_async`` decorator. defaults to 4. + persistence (:class:`telegram.ext.BasePersistence`, optional): The persistence class to + store data that should be persistent over restarts """ @@ -86,16 +94,30 @@ class Dispatcher(object): __singleton = None logger = logging.getLogger(__name__) - def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None): + def __init__(self, bot, update_queue, workers=4, exception_event=None, job_queue=None, + persistence=None): self.bot = bot self.update_queue = update_queue - self.job_queue = job_queue self.workers = workers - self.user_data = defaultdict(dict) - """:obj:`dict`: A dictionary handlers can use to store data for the user.""" self.chat_data = defaultdict(dict) - """:obj:`dict`: A dictionary handlers can use to store data for the chat.""" + if persistence: + if not isinstance(persistence, BasePersistence): + raise TypeError("persistence should be based on telegram.ext.BasePersistence") + self.persistence = persistence + if self.persistence.store_user_data: + self.user_data = self.persistence.get_user_data() + if not isinstance(self.user_data, defaultdict): + raise ValueError("user_data must be of type defaultdict") + if self.persistence.store_chat_data: + self.chat_data = self.persistence.get_chat_data() + if not isinstance(self.chat_data, defaultdict): + raise ValueError("chat_data must be of type defaultdict") + else: + self.persistence = None + + self.job_queue = job_queue + self.handlers = {} """Dict[:obj:`int`, List[:class:`telegram.ext.Handler`]]: Holds the handlers per group.""" self.groups = [] @@ -277,6 +299,19 @@ class Dispatcher(object): try: for handler in (x for x in self.handlers[group] if x.check_update(update)): handler.handle_update(update, self) + if self.persistence: + if self.persistence.store_chat_data and update.effective_chat.id: + chat_id = update.effective_chat.id + try: + self.persistence.update_chat_data(chat_id, self.chat_data[chat_id]) + except Exception: + self.logger.exception('Saving chat data raised an error') + if self.persistence.store_user_data and update.effective_user.id: + user_id = update.effective_user.id + try: + self.persistence.update_user_data(user_id, self.user_data[user_id]) + except Exception: + self.logger.exception('Saving user data raised an error') break # Stop processing with any other handler. @@ -324,11 +359,20 @@ class Dispatcher(object): group (:obj:`int`, optional): The group identifier. Default is 0. """ + # Unfortunately due to circular imports this has to be here + from .conversationhandler import ConversationHandler if not isinstance(handler, Handler): raise TypeError('handler is not an instance of {0}'.format(Handler.__name__)) if not isinstance(group, int): raise TypeError('group is not int') + if isinstance(handler, ConversationHandler) and handler.persistent: + if not self.persistence: + raise ValueError( + "Conversationhandler {} can not be persistent if dispatcher has no " + "persistence".format(handler.name)) + handler.conversations = self.persistence.get_conversations(handler.name) + handler.persistence = self.persistence if group not in self.handlers: self.handlers[group] = list() diff --git a/telegram/ext/picklepersistence.py b/telegram/ext/picklepersistence.py new file mode 100644 index 000000000..ed3c06cf9 --- /dev/null +++ b/telegram/ext/picklepersistence.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2018 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains the PicklePersistence class.""" +import pickle +from collections import defaultdict + +from telegram.ext import BasePersistence + + +class PicklePersistence(BasePersistence): + """Using python's builtin pickle for making you bot persistent. + + Attributes: + filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file` + is false this will be used as a prefix. + store_user_data (:obj:`bool`): Optional. Whether user_data should be saved by this + persistence class. + store_chat_data (:obj:`bool`): Optional. Whether user_data should be saved by this + persistence class. + single_file (:obj:`bool`): Optional. When ``False`` will store 3 sperate files of + `filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is + ``True``. + on_flush (:obj:`bool`): Optional. When ``True`` will only save to file when :meth:`flush` + is called and keep data in memory until that happens. When False will store data on any + transaction. Default is ``False``. + + Args: + filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file` + is false this will be used as a prefix. + store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this + persistence class. Default is ``True``. + store_chat_data (:obj:`bool`, optional): Whether user_data should be saved by this + persistence class. Default is ``True``. + single_file (:obj:`bool`, optional): When ``False`` will store 3 sperate files of + `filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is + ``True``. + on_flush (:obj:`bool`, optional): When ``True`` will only save to file when :meth:`flush` + is called and keep data in memory until that happens. When False will store data on any + transaction. Default is ``False``. + """ + + def __init__(self, filename, store_user_data=True, store_chat_data=True, singe_file=True, + on_flush=False): + self.filename = filename + self.store_user_data = store_user_data + self.store_chat_data = store_chat_data + self.single_file = singe_file + self.on_flush = on_flush + self.user_data = None + self.chat_data = None + self.conversations = None + + def load_singlefile(self): + try: + filename = self.filename + with open(self.filename, "rb") as f: + all = pickle.load(f) + self.user_data = defaultdict(dict, all['user_data']) + self.chat_data = defaultdict(dict, all['chat_data']) + self.conversations = all['conversations'] + except IOError: + self.conversations = {} + self.user_data = defaultdict(dict) + self.chat_data = defaultdict(dict) + except pickle.UnpicklingError: + raise TypeError("File {} does not contain valid pickle data".format(filename)) + except Exception: + raise TypeError("Something went wrong unpickling {}".format(filename)) + + def load_file(self, filename): + try: + with open(filename, "rb") as f: + return pickle.load(f) + except IOError: + return None + except pickle.UnpicklingError: + raise TypeError("File {} does not contain valid pickle data".format(filename)) + except Exception: + raise TypeError("Something went wrong unpickling {}".format(filename)) + + def dump_singlefile(self): + with open(self.filename, "wb") as f: + all = {'conversations': self.conversations, 'user_data': self.user_data, + 'chat_data': self.chat_data} + pickle.dump(all, f) + + def dump_file(self, filename, data): + with open(filename, "wb") as f: + pickle.dump(data, f) + + def get_user_data(self): + """Returns the user_data from the pickle file if it exsists or an empty defaultdict. + + Returns: + :obj:`defaultdict`: The restored user data. + """ + if self.user_data: + pass + elif not self.single_file: + filename = "{}_user_data".format(self.filename) + data = self.load_file(filename) + if not data: + data = defaultdict(dict) + else: + data = defaultdict(dict, data) + self.user_data = data + else: + self.load_singlefile() + return self.user_data.copy() + + def get_chat_data(self): + """Returns the chat_data from the pickle file if it exsists or an empty defaultdict. + + Returns: + :obj:`defaultdict`: The restored chat data. + """ + if self.chat_data: + pass + elif not self.single_file: + filename = "{}_chat_data".format(self.filename) + data = self.load_file(filename) + if not data: + data = defaultdict(dict) + else: + data = defaultdict(dict, data) + self.chat_data = data + else: + self.load_singlefile() + return self.chat_data.copy() + + def get_conversations(self, name): + """Returns the conversations from the pickle file if it exsists or an empty defaultdict. + + Args: + name (:obj:`str`): The handlers name. + + Returns: + :obj:`dict`: The restored conversations for the handler. + """ + if self.conversations: + pass + elif not self.single_file: + filename = "{}_conversations".format(self.filename) + data = self.load_file(filename) + if not data: + data = {name: {}} + self.conversations = data + else: + self.load_singlefile() + return self.conversations.get(name, {}).copy() + + def update_conversation(self, name, key, new_state): + """Will update the conversations for the given handler and depending on :attr:`on_flush` + save the pickle file. + + Args: + name (:obj:`str`): The handlers name. + key (:obj:`tuple`): The key the state is changed for. + new_state (:obj:`tuple` | :obj:`any`): The new state for the given key. + """ + if self.conversations.setdefault(name, {}).get(key) == new_state: + return + self.conversations[name][key] = new_state + if not self.on_flush: + if not self.single_file: + filename = "{}_conversations".format(self.filename) + self.dump_file(filename, self.conversations) + else: + self.dump_singlefile() + + def update_user_data(self, user_id, data): + """Will update the user_data (if changed) and depending on :attr:`on_flush` save the + pickle file. + + Args: + user_id (:obj:`int`): The user the data might have been changed for. + data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id]. + """ + if self.user_data.get(user_id) == data: + return + self.user_data[user_id] = data + if not self.on_flush: + if not self.single_file: + filename = "{}_user_data".format(self.filename) + self.dump_file(filename, self.user_data) + else: + self.dump_singlefile() + + def update_chat_data(self, chat_id, data): + """Will update the chat_data (if changed) and depending on :attr:`on_flush` save the + pickle file. + + Args: + chat_id (:obj:`int`): The chat the data might have been changed for. + data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id]. + """ + if self.chat_data.get(chat_id) == data: + return + self.chat_data[chat_id] = data + if not self.on_flush: + if not self.single_file: + filename = "{}_chat_data".format(self.filename) + self.dump_file(filename, self.chat_data) + else: + self.dump_singlefile() + + def flush(self): + """If :attr:`on_flush` is set to ``True``. Will save all data in memory to pickle file(s). If + it's ``False`` will just pass. + """ + if not self.on_flush: + pass + else: + if self.single_file: + self.dump_singlefile() + else: + self.dump_file("{}_user_data".format(self.filename), self.user_data) + self.dump_file("{}_chat_data".format(self.filename), self.chat_data) + self.dump_file("{}_conversations".format(self.filename), self.conversations) diff --git a/telegram/ext/updater.py b/telegram/ext/updater.py index 96dfdfa06..3f050e042 100644 --- a/telegram/ext/updater.py +++ b/telegram/ext/updater.py @@ -55,6 +55,8 @@ class Updater(object): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that handles the updates and dispatches them to the handlers. running (:obj:`bool`): Indicates if the updater is running. + persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to + store data that should be persistent over restarts. Args: token (:obj:`str`, optional): The bot's token given by the @BotFather. @@ -73,6 +75,8 @@ class Updater(object): `telegram.utils.request.Request` object (ignored if `bot` argument is used). The request_kwargs are very useful for the advanced users who would like to control the default timeouts and/or control the proxy used for http communication. + persistence (:class:`telegram.ext.BasePersistence`, optional): The persistence class to + store data that should be persistent over restarts. Note: You must supply either a :attr:`bot` or a :attr:`token` argument. @@ -92,7 +96,8 @@ class Updater(object): private_key=None, private_key_password=None, user_sig_handler=None, - request_kwargs=None): + request_kwargs=None, + persistence=None): if (token is None) and (bot is None): raise ValueError('`token` or `bot` must be passed') @@ -129,12 +134,14 @@ class Updater(object): self.update_queue = Queue() self.job_queue = JobQueue(self.bot) self.__exception_event = Event() + self.persistence = persistence self.dispatcher = Dispatcher( self.bot, self.update_queue, job_queue=self.job_queue, workers=workers, - exception_event=self.__exception_event) + exception_event=self.__exception_event, + persistence=self.persistence) self.last_update_id = 0 self.running = False self.is_idle = False @@ -485,6 +492,8 @@ class Updater(object): if self.running: self.logger.info('Received signal {} ({}), stopping...'.format( signum, get_signal_name(signum))) + if self.persistence: + self.persistence.flush() self.stop() if self.user_sig_handler: self.user_sig_handler(signum, frame) diff --git a/telegram/utils/helpers.py b/telegram/utils/helpers.py index 029057951..740c19ff1 100644 --- a/telegram/utils/helpers.py +++ b/telegram/utils/helpers.py @@ -17,6 +17,12 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains helper functions.""" +from collections import defaultdict + +try: + import ujson as json +except ImportError: + import json from html import escape import re @@ -139,3 +145,65 @@ def effective_message_type(entity): return i return None + + +def enocde_conversations_to_json(conversations): + """Helper method to encode a conversations dict (that uses tuples as keys) to a + JSON-serializable way. Use :attr:`_decode_conversations_from_json` to decode. + + Args: + conversations (:obj:`dict`): The conversations dict to transofrm to JSON. + + Returns: + :obj:`str`: The JSON-serialized conversations dict + """ + tmp = {} + for handler, states in conversations.items(): + tmp[handler] = {} + for key, state in states.items(): + tmp[handler][json.dumps(key)] = state + return json.dumps(tmp) + + +def decode_conversations_from_json(json_string): + """Helper method to decode a conversations dict (that uses tuples as keys) from a + JSON-string created with :attr:`_encode_conversations_to_json`. + + Args: + json_string (:obj:`str`): The conversations dict as JSON string. + + Returns: + :obj:`dict`: The conversations dict after decoding + """ + tmp = json.loads(json_string) + conversations = {} + for handler, states in tmp.items(): + conversations[handler] = {} + for key, state in states.items(): + conversations[handler][tuple(json.loads(key))] = state + return conversations + + +def decode_user_chat_data_from_json(data): + """Helper method to decode chat or user data (that uses ints as keys) from a + JSON-string. + + Args: + data (:obj:`str`): The user/chat_data dict as JSON string. + + Returns: + :obj:`dict`: The user/chat_data defaultdict after decoding + """ + + tmp = defaultdict(dict) + decoded_data = json.loads(data) + for user, data in decoded_data.items(): + user = int(user) + tmp[user] = {} + for key, value in data.items(): + try: + key = int(key) + except ValueError: + pass + tmp[user][key] = value + return tmp diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py index 7f531ff36..a415d4d05 100644 --- a/tests/test_conversationhandler.py +++ b/tests/test_conversationhandler.py @@ -96,6 +96,12 @@ class TestConversationHandler(object): ConversationHandler(self.entry_points, self.states, self.fallbacks, per_chat=False, per_user=False, per_message=False) + def test_name_and_persistent(self, dp): + with pytest.raises(ValueError, match="when handler is unnamed"): + dp.add_handler(ConversationHandler([], {}, [], persistent=True)) + c = ConversationHandler([], {}, [], name="handler", persistent=True) + assert c.name == "handler" + def test_conversation_handler(self, dp, bot, user1, user2): handler = ConversationHandler(entry_points=self.entry_points, states=self.states, fallbacks=self.fallbacks) diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 6dc6b236f..b184558f0 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -82,6 +82,16 @@ class TestDispatcher(object): sleep(.1) assert self.received is None + def test_construction_with_bad_persistence(self, caplog, bot): + class my_per: + def __init__(self): + self.store_user_data = False + self.store_chat_data = False + + with pytest.raises(TypeError, + match='persistence should be based on telegram.ext.BasePersistence'): + Dispatcher(bot, None, persistence=my_per()) + def test_error_handler_that_raises_errors(self, dp): """ Make sure that errors raised in error handlers don't break the main loop of the dispatcher diff --git a/tests/test_persistence.py b/tests/test_persistence.py new file mode 100644 index 000000000..97d0ccaec --- /dev/null +++ b/tests/test_persistence.py @@ -0,0 +1,751 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2018 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +from telegram.utils.helpers import enocde_conversations_to_json + +try: + import ujson as json +except ImportError: + import json +import logging +import os +import pickle +from collections import defaultdict + +import pytest + +from telegram import Update, Message, User, Chat +from telegram.ext import BasePersistence, Updater, ConversationHandler, MessageHandler, Filters, \ + PicklePersistence, CommandHandler, DictPersistence + + +@pytest.fixture(scope="function") +def base_persistence(): + return BasePersistence(store_chat_data=True, store_user_data=True) + + +@pytest.fixture(scope="function") +def chat_data(): + return defaultdict(dict, {-12345: {'test1': 'test2'}, -67890: {3: 'test4'}}) + + +@pytest.fixture(scope="function") +def user_data(): + return defaultdict(dict, {12345: {'test1': 'test2'}, 67890: {3: 'test4'}}) + + +@pytest.fixture(scope='function') +def conversations(): + return {'name1': {(123, 123): 3, (456, 654): 4}, + 'name2': {(123, 321): 1, (890, 890): 2}} + + +@pytest.fixture(scope="function") +def updater(bot, base_persistence): + base_persistence.store_chat_data = False + base_persistence.store_user_data = False + u = Updater(bot=bot, persistence=base_persistence) + base_persistence.store_chat_data = True + base_persistence.store_user_data = True + return u + + +class TestBasePersistence(object): + + def test_creation(self, base_persistence): + assert base_persistence.store_chat_data + assert base_persistence.store_user_data + with pytest.raises(NotImplementedError): + base_persistence.get_chat_data() + with pytest.raises(NotImplementedError): + base_persistence.get_user_data() + with pytest.raises(NotImplementedError): + base_persistence.get_conversations("test") + with pytest.raises(NotImplementedError): + base_persistence.update_chat_data(None, None) + with pytest.raises(NotImplementedError): + base_persistence.update_user_data(None, None) + with pytest.raises(NotImplementedError): + base_persistence.update_conversation(None, None, None) + + def test_implementation(self, updater, base_persistence): + dp = updater.dispatcher + assert dp.persistence == base_persistence + + def test_conversationhandler_addition(self, dp, base_persistence): + with pytest.raises(ValueError, match="when handler is unnamed"): + ConversationHandler([], [], [], persistent=True) + with pytest.raises(ValueError, match="if dispatcher has no persistence"): + dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler")) + dp.persistence = base_persistence + with pytest.raises(NotImplementedError): + dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler")) + + def test_dispatcher_integration_init(self, bot, base_persistence, chat_data, user_data): + def get_user_data(): + return "test" + + def get_chat_data(): + return "test" + + base_persistence.get_user_data = get_user_data + base_persistence.get_chat_data = get_chat_data + with pytest.raises(ValueError, match="user_data must be of type defaultdict"): + u = Updater(bot=bot, persistence=base_persistence) + + def get_user_data(): + return user_data + + base_persistence.get_user_data = get_user_data + with pytest.raises(ValueError, match="chat_data must be of type defaultdict"): + u = Updater(bot=bot, persistence=base_persistence) + + def get_chat_data(): + return chat_data + + base_persistence.get_chat_data = get_chat_data + u = Updater(bot=bot, persistence=base_persistence) + assert u.dispatcher.chat_data == chat_data + assert u.dispatcher.user_data == user_data + u.dispatcher.chat_data[442233]['test5'] = 'test6' + assert u.dispatcher.chat_data[442233]['test5'] == 'test6' + + def test_dispatcher_integration_handlers(self, caplog, bot, base_persistence, + chat_data, user_data): + def get_user_data(): + return user_data + + def get_chat_data(): + return chat_data + + base_persistence.get_user_data = get_user_data + base_persistence.get_chat_data = get_chat_data + # base_persistence.update_chat_data = lambda x: x + # base_persistence.update_user_data = lambda x: x + updater = Updater(bot=bot, persistence=base_persistence) + dp = updater.dispatcher + + def callback_known_user(bot, update, user_data, chat_data): + if not user_data['test1'] == 'test2': + pytest.fail('user_data corrupt') + + def callback_known_chat(bot, update, user_data, chat_data): + if not chat_data['test3'] == 'test4': + pytest.fail('chat_data corrupt') + + def callback_unknown_user_or_chat(bot, update, user_data, chat_data): + if not user_data == {}: + pytest.fail('user_data corrupt') + if not chat_data == {}: + pytest.fail('chat_data corrupt') + user_data[1] = 'test7' + chat_data[2] = 'test8' + + known_user = MessageHandler(Filters.user(user_id=12345), callback_known_user, + pass_chat_data=True, pass_user_data=True) + known_chat = MessageHandler(Filters.chat(chat_id=-67890), callback_known_chat, + pass_chat_data=True, pass_user_data=True) + unknown = MessageHandler(Filters.all, callback_unknown_user_or_chat, pass_chat_data=True, + pass_user_data=True) + dp.add_handler(known_user) + dp.add_handler(known_chat) + dp.add_handler(unknown) + user1 = User(id=12345, first_name='test user', is_bot=False) + user2 = User(id=54321, first_name='test user', is_bot=False) + chat1 = Chat(id=-67890, type='group') + chat2 = Chat(id=-987654, type='group') + m = Message(1, user1, None, chat2) + u = Update(0, m) + with caplog.at_level(logging.ERROR): + dp.process_update(u) + rec = caplog.records[-1] + assert rec.msg == 'Saving user data raised an error' + assert rec.levelname == 'ERROR' + rec = caplog.records[-2] + assert rec.msg == 'Saving chat data raised an error' + assert rec.levelname == 'ERROR' + m.from_user = user2 + m.chat = chat1 + u = Update(1, m) + dp.process_update(u) + m.chat = chat2 + u = Update(2, m) + + def save_chat_data(data): + if -987654 not in data: + pytest.fail() + + def save_user_data(data): + if 54321 not in data: + pytest.fail() + + base_persistence.update_chat_data = save_chat_data + base_persistence.update_user_data = save_user_data + dp.process_update(u) + + assert dp.user_data[54321][1] == 'test7' + assert dp.chat_data[-987654][2] == 'test8' + + +@pytest.fixture(scope='function') +def pickle_persistence(): + return PicklePersistence(filename='pickletest', + store_user_data=True, + store_chat_data=True, + singe_file=False, + on_flush=False) + + +@pytest.fixture(scope='function') +def bad_pickle_files(): + for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_conversations', + 'pickletest']: + with open(name, 'w') as f: + f.write('(())') + yield True + for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_conversations', + 'pickletest']: + os.remove(name) + + +@pytest.fixture(scope='function') +def good_pickle_files(user_data, chat_data, conversations): + all = {'user_data': user_data, 'chat_data': chat_data, 'conversations': conversations} + with open('pickletest_user_data', 'wb') as f: + pickle.dump(user_data, f) + with open('pickletest_chat_data', 'wb') as f: + pickle.dump(chat_data, f) + with open('pickletest_conversations', 'wb') as f: + pickle.dump(conversations, f) + with open('pickletest', 'wb') as f: + pickle.dump(all, f) + yield True + for name in ['pickletest_user_data', 'pickletest_chat_data', 'pickletest_conversations', + 'pickletest']: + os.remove(name) + + +@pytest.fixture(scope='function') +def update(bot): + user = User(id=321, first_name='test_user', is_bot=False) + chat = Chat(id=123, type='group') + message = Message(1, user, None, chat, text="Hi there", bot=bot) + return Update(0, message=message) + + +class TestPickelPersistence(object): + def test_no_files_present_multi_file(self, pickle_persistence): + assert pickle_persistence.get_user_data() == defaultdict(dict) + assert pickle_persistence.get_user_data() == defaultdict(dict) + assert pickle_persistence.get_chat_data() == defaultdict(dict) + assert pickle_persistence.get_chat_data() == defaultdict(dict) + assert pickle_persistence.get_conversations('noname') == {} + assert pickle_persistence.get_conversations('noname') == {} + + def test_no_files_present_single_file(self, pickle_persistence): + pickle_persistence.single_file = True + assert pickle_persistence.get_user_data() == defaultdict(dict) + assert pickle_persistence.get_chat_data() == defaultdict(dict) + assert pickle_persistence.get_conversations('noname') == {} + + def test_with_bad_multi_file(self, pickle_persistence, bad_pickle_files): + with pytest.raises(TypeError, match='pickletest_user_data'): + pickle_persistence.get_user_data() + with pytest.raises(TypeError, match='pickletest_chat_data'): + pickle_persistence.get_chat_data() + with pytest.raises(TypeError, match='pickletest_conversations'): + pickle_persistence.get_conversations('name') + + def test_with_bad_single_file(self, pickle_persistence, bad_pickle_files): + pickle_persistence.single_file = True + with pytest.raises(TypeError, match='pickletest'): + pickle_persistence.get_user_data() + with pytest.raises(TypeError, match='pickletest'): + pickle_persistence.get_chat_data() + with pytest.raises(TypeError, match='pickletest'): + pickle_persistence.get_conversations('name') + + def test_with_good_multi_file(self, pickle_persistence, good_pickle_files): + user_data = pickle_persistence.get_user_data() + assert isinstance(user_data, defaultdict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + assert user_data[54321] == {} + + chat_data = pickle_persistence.get_chat_data() + assert isinstance(chat_data, defaultdict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + assert chat_data[-54321] == {} + + conversation1 = pickle_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = pickle_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + def test_with_good_single_file(self, pickle_persistence, good_pickle_files): + pickle_persistence.single_file = True + user_data = pickle_persistence.get_user_data() + assert isinstance(user_data, defaultdict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + assert user_data[54321] == {} + + chat_data = pickle_persistence.get_chat_data() + assert isinstance(chat_data, defaultdict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + assert chat_data[-54321] == {} + + conversation1 = pickle_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = pickle_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + def test_updating_multi_file(self, pickle_persistence, good_pickle_files): + user_data = pickle_persistence.get_user_data() + user_data[54321]['test9'] = 'test 10' + assert not pickle_persistence.user_data == user_data + pickle_persistence.update_user_data(54321, user_data[54321]) + assert pickle_persistence.user_data == user_data + with open('pickletest_user_data', 'rb') as f: + user_data_test = defaultdict(dict, pickle.load(f)) + assert user_data_test == user_data + + chat_data = pickle_persistence.get_chat_data() + chat_data[54321]['test9'] = 'test 10' + assert not pickle_persistence.chat_data == chat_data + pickle_persistence.update_chat_data(54321, chat_data[54321]) + assert pickle_persistence.chat_data == chat_data + with open('pickletest_chat_data', 'rb') as f: + chat_data_test = defaultdict(dict, pickle.load(f)) + assert chat_data_test == chat_data + + conversation1 = pickle_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not pickle_persistence.conversations['name1'] == conversation1 + pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == conversation1 + with open('pickletest_conversations', 'rb') as f: + conversations_test = defaultdict(dict, pickle.load(f)) + assert conversations_test['name1'] == conversation1 + + def test_updating_single_file(self, pickle_persistence, good_pickle_files): + pickle_persistence.single_file = True + + user_data = pickle_persistence.get_user_data() + user_data[54321]['test9'] = 'test 10' + assert not pickle_persistence.user_data == user_data + pickle_persistence.update_user_data(54321, user_data[54321]) + assert pickle_persistence.user_data == user_data + with open('pickletest', 'rb') as f: + user_data_test = defaultdict(dict, pickle.load(f)['user_data']) + assert user_data_test == user_data + + chat_data = pickle_persistence.get_chat_data() + chat_data[54321]['test9'] = 'test 10' + assert not pickle_persistence.chat_data == chat_data + pickle_persistence.update_chat_data(54321, chat_data[54321]) + assert pickle_persistence.chat_data == chat_data + with open('pickletest', 'rb') as f: + chat_data_test = defaultdict(dict, pickle.load(f)['chat_data']) + assert chat_data_test == chat_data + + conversation1 = pickle_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not pickle_persistence.conversations['name1'] == conversation1 + pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == conversation1 + with open('pickletest', 'rb') as f: + conversations_test = defaultdict(dict, pickle.load(f)['conversations']) + assert conversations_test['name1'] == conversation1 + + def test_save_on_flush_multi_files(self, pickle_persistence, good_pickle_files): + # Should run without error + pickle_persistence.flush() + pickle_persistence.on_flush = True + + user_data = pickle_persistence.get_user_data() + user_data[54321]['test9'] = 'test 10' + assert not pickle_persistence.user_data == user_data + + pickle_persistence.update_user_data(54321, user_data[54321]) + assert pickle_persistence.user_data == user_data + + with open('pickletest_user_data', 'rb') as f: + user_data_test = defaultdict(dict, pickle.load(f)) + assert not user_data_test == user_data + + chat_data = pickle_persistence.get_chat_data() + chat_data[54321]['test9'] = 'test 10' + assert not pickle_persistence.chat_data == chat_data + + pickle_persistence.update_chat_data(54321, chat_data[54321]) + assert pickle_persistence.chat_data == chat_data + + with open('pickletest_chat_data', 'rb') as f: + chat_data_test = defaultdict(dict, pickle.load(f)) + assert not chat_data_test == chat_data + + conversation1 = pickle_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not pickle_persistence.conversations['name1'] == conversation1 + + pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == conversation1 + + with open('pickletest_conversations', 'rb') as f: + conversations_test = defaultdict(dict, pickle.load(f)) + assert not conversations_test['name1'] == conversation1 + + pickle_persistence.flush() + with open('pickletest_user_data', 'rb') as f: + user_data_test = defaultdict(dict, pickle.load(f)) + assert user_data_test == user_data + + with open('pickletest_chat_data', 'rb') as f: + chat_data_test = defaultdict(dict, pickle.load(f)) + assert chat_data_test == chat_data + + with open('pickletest_conversations', 'rb') as f: + conversations_test = defaultdict(dict, pickle.load(f)) + assert conversations_test['name1'] == conversation1 + + def test_save_on_flush_single_files(self, pickle_persistence, good_pickle_files): + # Should run without error + pickle_persistence.flush() + + pickle_persistence.on_flush = True + pickle_persistence.single_file = True + + user_data = pickle_persistence.get_user_data() + user_data[54321]['test9'] = 'test 10' + assert not pickle_persistence.user_data == user_data + pickle_persistence.update_user_data(54321, user_data[54321]) + assert pickle_persistence.user_data == user_data + with open('pickletest', 'rb') as f: + user_data_test = defaultdict(dict, pickle.load(f)['user_data']) + assert not user_data_test == user_data + + chat_data = pickle_persistence.get_chat_data() + chat_data[54321]['test9'] = 'test 10' + assert not pickle_persistence.chat_data == chat_data + pickle_persistence.update_chat_data(54321, chat_data[54321]) + assert pickle_persistence.chat_data == chat_data + with open('pickletest', 'rb') as f: + chat_data_test = defaultdict(dict, pickle.load(f)['chat_data']) + assert not chat_data_test == chat_data + + conversation1 = pickle_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not pickle_persistence.conversations['name1'] == conversation1 + pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == conversation1 + with open('pickletest', 'rb') as f: + conversations_test = defaultdict(dict, pickle.load(f)['conversations']) + assert not conversations_test['name1'] == conversation1 + + pickle_persistence.flush() + with open('pickletest', 'rb') as f: + user_data_test = defaultdict(dict, pickle.load(f)['user_data']) + assert user_data_test == user_data + + with open('pickletest', 'rb') as f: + chat_data_test = defaultdict(dict, pickle.load(f)['chat_data']) + assert chat_data_test == chat_data + + with open('pickletest', 'rb') as f: + conversations_test = defaultdict(dict, pickle.load(f)['conversations']) + assert conversations_test['name1'] == conversation1 + + def test_with_handler(self, bot, update, pickle_persistence, good_pickle_files): + u = Updater(bot=bot, persistence=pickle_persistence) + dp = u.dispatcher + + def first(bot, update, user_data, chat_data): + if not user_data == {}: + pytest.fail() + if not chat_data == {}: + pytest.fail() + user_data['test1'] = 'test2' + chat_data['test3'] = 'test4' + + def second(bot, update, user_data, chat_data): + if not user_data['test1'] == 'test2': + pytest.fail() + if not chat_data['test3'] == 'test4': + pytest.fail() + + h1 = MessageHandler(None, first, pass_user_data=True, pass_chat_data=True) + h2 = MessageHandler(None, second, pass_user_data=True, pass_chat_data=True) + dp.add_handler(h1) + dp.process_update(update) + del (dp) + del (u) + del (pickle_persistence) + pickle_persistence_2 = PicklePersistence(filename='pickletest', + store_user_data=True, + store_chat_data=True, + singe_file=False, + on_flush=False) + u = Updater(bot=bot, persistence=pickle_persistence_2) + dp = u.dispatcher + dp.add_handler(h2) + dp.process_update(update) + + def test_with_conversationHandler(self, dp, update, good_pickle_files, pickle_persistence): + dp.persistence = pickle_persistence + NEXT, NEXT2 = range(2) + + def start(bot, update): + return NEXT + + start = CommandHandler('start', start) + + def next(bot, update): + return NEXT2 + + next = MessageHandler(None, next) + + def next2(bot, update): + return ConversationHandler.END + + next2 = MessageHandler(None, next2) + + ch = ConversationHandler([start], {NEXT: [next], NEXT2: [next2]}, [], name='name2', + persistent=True) + dp.add_handler(ch) + assert ch.conversations[ch._get_key(update)] == 1 + dp.process_update(update) + assert ch._get_key(update) not in ch.conversations + update.message.text = '/start' + dp.process_update(update) + assert ch.conversations[ch._get_key(update)] == 0 + assert ch.conversations == pickle_persistence.conversations['name2'] + + @classmethod + def teardown_class(cls): + try: + for name in ['pickletest_user_data', 'pickletest_chat_data', + 'pickletest_conversations', + 'pickletest']: + os.remove(name) + except Exception: + pass + + +@pytest.fixture(scope='function') +def user_data_json(user_data): + return json.dumps(user_data) + + +@pytest.fixture(scope='function') +def chat_data_json(chat_data): + return json.dumps(chat_data) + + +@pytest.fixture(scope='function') +def conversations_json(conversations): + return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2": + {"[123, 321]": 1, "[890, 890]": 2}}""" + + +class TestDictPersistence(object): + def test_no_json_given(self): + dict_persistence = DictPersistence() + assert dict_persistence.get_user_data() == defaultdict(dict) + assert dict_persistence.get_chat_data() == defaultdict(dict) + assert dict_persistence.get_conversations('noname') == {} + + def test_bad_json_string_given(self): + bad_user_data = 'thisisnojson99900()))(' + bad_chat_data = 'thisisnojson99900()))(' + bad_conversations = 'thisisnojson99900()))(' + with pytest.raises(TypeError, match='user_data'): + DictPersistence(user_data_json=bad_user_data) + with pytest.raises(TypeError, match='chat_data'): + DictPersistence(chat_data_json=bad_chat_data) + with pytest.raises(TypeError, match='conversations'): + DictPersistence(conversations_json=bad_conversations) + + def test_invalid_json_string_given(self, pickle_persistence, bad_pickle_files): + bad_user_data = '["this", "is", "json"]' + bad_chat_data = '["this", "is", "json"]' + bad_conversations = '["this", "is", "json"]' + with pytest.raises(TypeError, match='user_data'): + DictPersistence(user_data_json=bad_user_data) + with pytest.raises(TypeError, match='chat_data'): + DictPersistence(chat_data_json=bad_chat_data) + with pytest.raises(TypeError, match='conversations'): + DictPersistence(conversations_json=bad_conversations) + + def test_good_json_input(self, user_data_json, chat_data_json, conversations_json): + dict_persistence = DictPersistence(user_data_json=user_data_json, + chat_data_json=chat_data_json, + conversations_json=conversations_json) + user_data = dict_persistence.get_user_data() + assert isinstance(user_data, defaultdict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + assert user_data[54321] == {} + + chat_data = dict_persistence.get_chat_data() + assert isinstance(chat_data, defaultdict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + assert chat_data[-54321] == {} + + conversation1 = dict_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = dict_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + def test_dict_outputs(self, user_data, user_data_json, chat_data, chat_data_json, + conversations, conversations_json): + dict_persistence = DictPersistence(user_data_json=user_data_json, + chat_data_json=chat_data_json, + conversations_json=conversations_json) + assert dict_persistence.user_data == user_data + assert dict_persistence.chat_data == chat_data + assert dict_persistence.conversations == conversations + + def test_json_outputs(self, user_data_json, chat_data_json, conversations_json): + dict_persistence = DictPersistence(user_data_json=user_data_json, + chat_data_json=chat_data_json, + conversations_json=conversations_json) + assert dict_persistence.user_data_json == user_data_json + assert dict_persistence.chat_data_json == chat_data_json + assert dict_persistence.conversations_json == conversations_json + + def test_json_changes(self, user_data, user_data_json, chat_data, chat_data_json, + conversations, conversations_json): + dict_persistence = DictPersistence(user_data_json=user_data_json, + chat_data_json=chat_data_json, + conversations_json=conversations_json) + user_data_two = user_data.copy() + user_data_two.update({4: {5: 6}}) + dict_persistence.update_user_data(4, {5: 6}) + assert dict_persistence.user_data == user_data_two + assert dict_persistence.user_data_json != user_data_json + assert dict_persistence.user_data_json == json.dumps(user_data_two) + + chat_data_two = chat_data.copy() + chat_data_two.update({7: {8: 9}}) + dict_persistence.update_chat_data(7, {8: 9}) + assert dict_persistence.chat_data == chat_data_two + assert dict_persistence.chat_data_json != chat_data_json + assert dict_persistence.chat_data_json == json.dumps(chat_data_two) + + conversations_two = conversations.copy() + conversations_two.update({'name3': {(1, 2): 3}}) + dict_persistence.update_conversation('name3', (1, 2), 3) + assert dict_persistence.conversations == conversations_two + assert dict_persistence.conversations_json != conversations_json + assert dict_persistence.conversations_json == enocde_conversations_to_json( + conversations_two) + + def test_with_handler(self, bot, update): + dict_persistence = DictPersistence() + u = Updater(bot=bot, persistence=dict_persistence) + dp = u.dispatcher + + def first(bot, update, user_data, chat_data): + if not user_data == {}: + pytest.fail() + if not chat_data == {}: + pytest.fail() + user_data['test1'] = 'test2' + chat_data[3] = 'test4' + + def second(bot, update, user_data, chat_data): + if not user_data['test1'] == 'test2': + pytest.fail() + if not chat_data[3] == 'test4': + pytest.fail() + + h1 = MessageHandler(None, first, pass_user_data=True, pass_chat_data=True) + h2 = MessageHandler(None, second, pass_user_data=True, pass_chat_data=True) + dp.add_handler(h1) + dp.process_update(update) + del (dp) + del (u) + user_data = dict_persistence.user_data_json + chat_data = dict_persistence.chat_data_json + del (dict_persistence) + dict_persistence_2 = DictPersistence(user_data_json=user_data, + chat_data_json=chat_data) + + u = Updater(bot=bot, persistence=dict_persistence_2) + dp = u.dispatcher + dp.add_handler(h2) + dp.process_update(update) + + def test_with_conversationHandler(self, dp, update, conversations_json): + dict_persistence = DictPersistence(conversations_json=conversations_json) + dp.persistence = dict_persistence + NEXT, NEXT2 = range(2) + + def start(bot, update): + return NEXT + + start = CommandHandler('start', start) + + def next(bot, update): + return NEXT2 + + next = MessageHandler(None, next) + + def next2(bot, update): + return ConversationHandler.END + + next2 = MessageHandler(None, next2) + + ch = ConversationHandler([start], {NEXT: [next], NEXT2: [next2]}, [], name='name2', + persistent=True) + dp.add_handler(ch) + assert ch.conversations[ch._get_key(update)] == 1 + dp.process_update(update) + assert ch._get_key(update) not in ch.conversations + update.message.text = '/start' + dp.process_update(update) + assert ch.conversations[ch._get_key(update)] == 0 + assert ch.conversations == dict_persistence.conversations['name2']