diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index da1017999..409f5baf2 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -178,14 +178,14 @@ class ConversationHandler(Handler): if persistent and not self.name: raise ValueError("Conversations can't be persistent when handler is unnamed.") self.persistent = persistent - self.persistence = None + self._persistence = None """:obj:`telegram.ext.BasePersistance`: The persistence used to store conversations. Set by dispatcher""" self.map_to_parent = map_to_parent self.timeout_jobs = dict() self._timeout_jobs_lock = Lock() - self.conversations = dict() + self._conversations = dict() self._conversations_lock = Lock() self.logger = logging.getLogger(__name__) @@ -225,6 +225,32 @@ class ConversationHandler(Handler): "since inline queries have no chat context.") break + @property + def persistence(self): + return self._persistence + + @persistence.setter + def persistence(self, persistence): + self._persistence = persistence + # Set persistence for nested conversations + for handlers in self.states.values(): + for handler in handlers: + if isinstance(handler, ConversationHandler): + handler.persistence = self.persistence + + @property + def conversations(self): + return self._conversations + + @conversations.setter + def conversations(self, value): + self._conversations = value + # Set conversations for nested conversations + for handlers in self.states.values(): + for handler in handlers: + if isinstance(handler, ConversationHandler): + handler.conversations = self.persistence.get_conversations(handler.name) + def _get_key(self, update): chat = update.effective_chat user = update.effective_user diff --git a/telegram/ext/dispatcher.py b/telegram/ext/dispatcher.py index 54da82a33..ccde98a61 100644 --- a/telegram/ext/dispatcher.py +++ b/telegram/ext/dispatcher.py @@ -447,8 +447,8 @@ class Dispatcher(object): raise ValueError( "Conversationhandler {} can not be persistent if dispatcher has no " "persistence".format(handler.name)) - handler.conversations = self.persistence.get_conversations(handler.name) handler.persistence = self.persistence + handler.conversations = self.persistence.get_conversations(handler.name) if group not in self.handlers: self.handlers[group] = list() diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 7bdbf55d6..9610c9ef1 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -59,7 +59,8 @@ def user_data(): @pytest.fixture(scope='function') def conversations(): return {'name1': {(123, 123): 3, (456, 654): 4}, - 'name2': {(123, 321): 1, (890, 890): 2}} + 'name2': {(123, 321): 1, (890, 890): 2}, + 'name3': {(123, 321): 1, (890, 890): 2}} @pytest.fixture(scope="function") @@ -806,6 +807,56 @@ class TestPickelPersistence(object): assert ch.conversations[ch._get_key(update)] == 0 assert ch.conversations == pickle_persistence.conversations['name2'] + def test_with_nested_conversationHandler(self, dp, update, good_pickle_files, + pickle_persistence): + dp.persistence = pickle_persistence + dp.use_context = True + NEXT2, NEXT3 = range(1, 3) + + def start(update, context): + return NEXT2 + + start = CommandHandler('start', start) + + def next(update, context): + return NEXT2 + + next = MessageHandler(None, next) + + def next2(update, context): + return ConversationHandler.END + + next2 = MessageHandler(None, next2) + + nested_ch = ConversationHandler( + [next], + {NEXT2: [next2]}, + [], + name='name3', + persistent=True, + map_to_parent={ConversationHandler.END: ConversationHandler.END}, + ) + + ch = ConversationHandler([start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2', + persistent=True) + dp.add_handler(ch) + assert ch.conversations[ch._get_key(update)] == 1 + assert nested_ch.conversations[nested_ch._get_key(update)] == 1 + dp.process_update(update) + assert ch._get_key(update) not in ch.conversations + assert nested_ch._get_key(update) not in nested_ch.conversations + update.message.text = '/start' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + dp.process_update(update) + assert ch.conversations[ch._get_key(update)] == 1 + assert ch.conversations == pickle_persistence.conversations['name2'] + assert nested_ch._get_key(update) not in nested_ch.conversations + dp.process_update(update) + assert ch.conversations[ch._get_key(update)] == 1 + assert ch.conversations == pickle_persistence.conversations['name2'] + assert nested_ch.conversations[nested_ch._get_key(update)] == 1 + assert nested_ch.conversations == pickle_persistence.conversations['name3'] + @classmethod def teardown_class(cls): try: @@ -836,6 +887,7 @@ def bot_data_json(bot_data): @pytest.fixture(scope='function') def conversations_json(conversations): return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2": + {"[123, 321]": 1, "[890, 890]": 2}, "name3": {"[123, 321]": 1, "[890, 890]": 2}}""" @@ -964,8 +1016,8 @@ class TestDictPersistence(object): assert dict_persistence.bot_data_json == json.dumps(bot_data_two) conversations_two = conversations.copy() - conversations_two.update({'name3': {(1, 2): 3}}) - dict_persistence.update_conversation('name3', (1, 2), 3) + conversations_two.update({'name4': {(1, 2): 3}}) + dict_persistence.update_conversation('name4', (1, 2), 3) assert dict_persistence.conversations == conversations_two assert dict_persistence.conversations_json != conversations_json assert dict_persistence.conversations_json == encode_conversations_to_json( @@ -1046,3 +1098,53 @@ class TestDictPersistence(object): dp.process_update(update) assert ch.conversations[ch._get_key(update)] == 0 assert ch.conversations == dict_persistence.conversations['name2'] + + def test_with_nested_conversationHandler(self, dp, update, conversations_json): + dict_persistence = DictPersistence(conversations_json=conversations_json) + dp.persistence = dict_persistence + dp.use_context = True + NEXT2, NEXT3 = range(1, 3) + + def start(update, context): + return NEXT2 + + start = CommandHandler('start', start) + + def next(update, context): + return NEXT2 + + next = MessageHandler(None, next) + + def next2(update, context): + return ConversationHandler.END + + next2 = MessageHandler(None, next2) + + nested_ch = ConversationHandler( + [next], + {NEXT2: [next2]}, + [], + name='name3', + persistent=True, + map_to_parent={ConversationHandler.END: ConversationHandler.END}, + ) + + ch = ConversationHandler([start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2', + persistent=True) + dp.add_handler(ch) + assert ch.conversations[ch._get_key(update)] == 1 + assert nested_ch.conversations[nested_ch._get_key(update)] == 1 + dp.process_update(update) + assert ch._get_key(update) not in ch.conversations + assert nested_ch._get_key(update) not in nested_ch.conversations + update.message.text = '/start' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + dp.process_update(update) + assert ch.conversations[ch._get_key(update)] == 1 + assert ch.conversations == dict_persistence.conversations['name2'] + assert nested_ch._get_key(update) not in nested_ch.conversations + dp.process_update(update) + assert ch.conversations[ch._get_key(update)] == 1 + assert ch.conversations == dict_persistence.conversations['name2'] + assert nested_ch.conversations[nested_ch._get_key(update)] == 1 + assert nested_ch.conversations == dict_persistence.conversations['name3']