Handler persistence for nested ConversationHandlers (#1711)

* Handler persistence for nested ConversationHandlers

* Add tests for persistence w/ nested CHs
This commit is contained in:
Bibo-Joshi 2020-02-02 22:31:56 +01:00 committed by GitHub
parent f6b663f175
commit 6d9d11b8bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 134 additions and 6 deletions

View file

@ -178,14 +178,14 @@ class ConversationHandler(Handler):
if persistent and not self.name: if persistent and not self.name:
raise ValueError("Conversations can't be persistent when handler is unnamed.") raise ValueError("Conversations can't be persistent when handler is unnamed.")
self.persistent = persistent self.persistent = persistent
self.persistence = None self._persistence = None
""":obj:`telegram.ext.BasePersistance`: The persistence used to store conversations. """:obj:`telegram.ext.BasePersistance`: The persistence used to store conversations.
Set by dispatcher""" Set by dispatcher"""
self.map_to_parent = map_to_parent self.map_to_parent = map_to_parent
self.timeout_jobs = dict() self.timeout_jobs = dict()
self._timeout_jobs_lock = Lock() self._timeout_jobs_lock = Lock()
self.conversations = dict() self._conversations = dict()
self._conversations_lock = Lock() self._conversations_lock = Lock()
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
@ -225,6 +225,32 @@ class ConversationHandler(Handler):
"since inline queries have no chat context.") "since inline queries have no chat context.")
break break
@property
def persistence(self):
return self._persistence
@persistence.setter
def persistence(self, persistence):
self._persistence = persistence
# Set persistence for nested conversations
for handlers in self.states.values():
for handler in handlers:
if isinstance(handler, ConversationHandler):
handler.persistence = self.persistence
@property
def conversations(self):
return self._conversations
@conversations.setter
def conversations(self, value):
self._conversations = value
# Set conversations for nested conversations
for handlers in self.states.values():
for handler in handlers:
if isinstance(handler, ConversationHandler):
handler.conversations = self.persistence.get_conversations(handler.name)
def _get_key(self, update): def _get_key(self, update):
chat = update.effective_chat chat = update.effective_chat
user = update.effective_user user = update.effective_user

View file

@ -447,8 +447,8 @@ class Dispatcher(object):
raise ValueError( raise ValueError(
"Conversationhandler {} can not be persistent if dispatcher has no " "Conversationhandler {} can not be persistent if dispatcher has no "
"persistence".format(handler.name)) "persistence".format(handler.name))
handler.conversations = self.persistence.get_conversations(handler.name)
handler.persistence = self.persistence handler.persistence = self.persistence
handler.conversations = self.persistence.get_conversations(handler.name)
if group not in self.handlers: if group not in self.handlers:
self.handlers[group] = list() self.handlers[group] = list()

View file

@ -59,7 +59,8 @@ def user_data():
@pytest.fixture(scope='function') @pytest.fixture(scope='function')
def conversations(): def conversations():
return {'name1': {(123, 123): 3, (456, 654): 4}, return {'name1': {(123, 123): 3, (456, 654): 4},
'name2': {(123, 321): 1, (890, 890): 2}} 'name2': {(123, 321): 1, (890, 890): 2},
'name3': {(123, 321): 1, (890, 890): 2}}
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
@ -806,6 +807,56 @@ class TestPickelPersistence(object):
assert ch.conversations[ch._get_key(update)] == 0 assert ch.conversations[ch._get_key(update)] == 0
assert ch.conversations == pickle_persistence.conversations['name2'] assert ch.conversations == pickle_persistence.conversations['name2']
def test_with_nested_conversationHandler(self, dp, update, good_pickle_files,
pickle_persistence):
dp.persistence = pickle_persistence
dp.use_context = True
NEXT2, NEXT3 = range(1, 3)
def start(update, context):
return NEXT2
start = CommandHandler('start', start)
def next(update, context):
return NEXT2
next = MessageHandler(None, next)
def next2(update, context):
return ConversationHandler.END
next2 = MessageHandler(None, next2)
nested_ch = ConversationHandler(
[next],
{NEXT2: [next2]},
[],
name='name3',
persistent=True,
map_to_parent={ConversationHandler.END: ConversationHandler.END},
)
ch = ConversationHandler([start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2',
persistent=True)
dp.add_handler(ch)
assert ch.conversations[ch._get_key(update)] == 1
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
dp.process_update(update)
assert ch._get_key(update) not in ch.conversations
assert nested_ch._get_key(update) not in nested_ch.conversations
update.message.text = '/start'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
dp.process_update(update)
assert ch.conversations[ch._get_key(update)] == 1
assert ch.conversations == pickle_persistence.conversations['name2']
assert nested_ch._get_key(update) not in nested_ch.conversations
dp.process_update(update)
assert ch.conversations[ch._get_key(update)] == 1
assert ch.conversations == pickle_persistence.conversations['name2']
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
assert nested_ch.conversations == pickle_persistence.conversations['name3']
@classmethod @classmethod
def teardown_class(cls): def teardown_class(cls):
try: try:
@ -836,6 +887,7 @@ def bot_data_json(bot_data):
@pytest.fixture(scope='function') @pytest.fixture(scope='function')
def conversations_json(conversations): def conversations_json(conversations):
return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2": return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2":
{"[123, 321]": 1, "[890, 890]": 2}, "name3":
{"[123, 321]": 1, "[890, 890]": 2}}""" {"[123, 321]": 1, "[890, 890]": 2}}"""
@ -964,8 +1016,8 @@ class TestDictPersistence(object):
assert dict_persistence.bot_data_json == json.dumps(bot_data_two) assert dict_persistence.bot_data_json == json.dumps(bot_data_two)
conversations_two = conversations.copy() conversations_two = conversations.copy()
conversations_two.update({'name3': {(1, 2): 3}}) conversations_two.update({'name4': {(1, 2): 3}})
dict_persistence.update_conversation('name3', (1, 2), 3) dict_persistence.update_conversation('name4', (1, 2), 3)
assert dict_persistence.conversations == conversations_two assert dict_persistence.conversations == conversations_two
assert dict_persistence.conversations_json != conversations_json assert dict_persistence.conversations_json != conversations_json
assert dict_persistence.conversations_json == encode_conversations_to_json( assert dict_persistence.conversations_json == encode_conversations_to_json(
@ -1046,3 +1098,53 @@ class TestDictPersistence(object):
dp.process_update(update) dp.process_update(update)
assert ch.conversations[ch._get_key(update)] == 0 assert ch.conversations[ch._get_key(update)] == 0
assert ch.conversations == dict_persistence.conversations['name2'] assert ch.conversations == dict_persistence.conversations['name2']
def test_with_nested_conversationHandler(self, dp, update, conversations_json):
dict_persistence = DictPersistence(conversations_json=conversations_json)
dp.persistence = dict_persistence
dp.use_context = True
NEXT2, NEXT3 = range(1, 3)
def start(update, context):
return NEXT2
start = CommandHandler('start', start)
def next(update, context):
return NEXT2
next = MessageHandler(None, next)
def next2(update, context):
return ConversationHandler.END
next2 = MessageHandler(None, next2)
nested_ch = ConversationHandler(
[next],
{NEXT2: [next2]},
[],
name='name3',
persistent=True,
map_to_parent={ConversationHandler.END: ConversationHandler.END},
)
ch = ConversationHandler([start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2',
persistent=True)
dp.add_handler(ch)
assert ch.conversations[ch._get_key(update)] == 1
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
dp.process_update(update)
assert ch._get_key(update) not in ch.conversations
assert nested_ch._get_key(update) not in nested_ch.conversations
update.message.text = '/start'
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
dp.process_update(update)
assert ch.conversations[ch._get_key(update)] == 1
assert ch.conversations == dict_persistence.conversations['name2']
assert nested_ch._get_key(update) not in nested_ch.conversations
dp.process_update(update)
assert ch.conversations[ch._get_key(update)] == 1
assert ch.conversations == dict_persistence.conversations['name2']
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
assert nested_ch.conversations == dict_persistence.conversations['name3']