mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-01-16 14:33:12 +01:00
Handler persistence for nested ConversationHandlers (#1711)
* Handler persistence for nested ConversationHandlers * Add tests for persistence w/ nested CHs
This commit is contained in:
parent
f6b663f175
commit
6d9d11b8bd
3 changed files with 134 additions and 6 deletions
|
@ -178,14 +178,14 @@ class ConversationHandler(Handler):
|
|||
if persistent and not self.name:
|
||||
raise ValueError("Conversations can't be persistent when handler is unnamed.")
|
||||
self.persistent = persistent
|
||||
self.persistence = None
|
||||
self._persistence = None
|
||||
""":obj:`telegram.ext.BasePersistance`: The persistence used to store conversations.
|
||||
Set by dispatcher"""
|
||||
self.map_to_parent = map_to_parent
|
||||
|
||||
self.timeout_jobs = dict()
|
||||
self._timeout_jobs_lock = Lock()
|
||||
self.conversations = dict()
|
||||
self._conversations = dict()
|
||||
self._conversations_lock = Lock()
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
@ -225,6 +225,32 @@ class ConversationHandler(Handler):
|
|||
"since inline queries have no chat context.")
|
||||
break
|
||||
|
||||
@property
|
||||
def persistence(self):
|
||||
return self._persistence
|
||||
|
||||
@persistence.setter
|
||||
def persistence(self, persistence):
|
||||
self._persistence = persistence
|
||||
# Set persistence for nested conversations
|
||||
for handlers in self.states.values():
|
||||
for handler in handlers:
|
||||
if isinstance(handler, ConversationHandler):
|
||||
handler.persistence = self.persistence
|
||||
|
||||
@property
|
||||
def conversations(self):
|
||||
return self._conversations
|
||||
|
||||
@conversations.setter
|
||||
def conversations(self, value):
|
||||
self._conversations = value
|
||||
# Set conversations for nested conversations
|
||||
for handlers in self.states.values():
|
||||
for handler in handlers:
|
||||
if isinstance(handler, ConversationHandler):
|
||||
handler.conversations = self.persistence.get_conversations(handler.name)
|
||||
|
||||
def _get_key(self, update):
|
||||
chat = update.effective_chat
|
||||
user = update.effective_user
|
||||
|
|
|
@ -447,8 +447,8 @@ class Dispatcher(object):
|
|||
raise ValueError(
|
||||
"Conversationhandler {} can not be persistent if dispatcher has no "
|
||||
"persistence".format(handler.name))
|
||||
handler.conversations = self.persistence.get_conversations(handler.name)
|
||||
handler.persistence = self.persistence
|
||||
handler.conversations = self.persistence.get_conversations(handler.name)
|
||||
|
||||
if group not in self.handlers:
|
||||
self.handlers[group] = list()
|
||||
|
|
|
@ -59,7 +59,8 @@ def user_data():
|
|||
@pytest.fixture(scope='function')
|
||||
def conversations():
|
||||
return {'name1': {(123, 123): 3, (456, 654): 4},
|
||||
'name2': {(123, 321): 1, (890, 890): 2}}
|
||||
'name2': {(123, 321): 1, (890, 890): 2},
|
||||
'name3': {(123, 321): 1, (890, 890): 2}}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
@ -806,6 +807,56 @@ class TestPickelPersistence(object):
|
|||
assert ch.conversations[ch._get_key(update)] == 0
|
||||
assert ch.conversations == pickle_persistence.conversations['name2']
|
||||
|
||||
def test_with_nested_conversationHandler(self, dp, update, good_pickle_files,
|
||||
pickle_persistence):
|
||||
dp.persistence = pickle_persistence
|
||||
dp.use_context = True
|
||||
NEXT2, NEXT3 = range(1, 3)
|
||||
|
||||
def start(update, context):
|
||||
return NEXT2
|
||||
|
||||
start = CommandHandler('start', start)
|
||||
|
||||
def next(update, context):
|
||||
return NEXT2
|
||||
|
||||
next = MessageHandler(None, next)
|
||||
|
||||
def next2(update, context):
|
||||
return ConversationHandler.END
|
||||
|
||||
next2 = MessageHandler(None, next2)
|
||||
|
||||
nested_ch = ConversationHandler(
|
||||
[next],
|
||||
{NEXT2: [next2]},
|
||||
[],
|
||||
name='name3',
|
||||
persistent=True,
|
||||
map_to_parent={ConversationHandler.END: ConversationHandler.END},
|
||||
)
|
||||
|
||||
ch = ConversationHandler([start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2',
|
||||
persistent=True)
|
||||
dp.add_handler(ch)
|
||||
assert ch.conversations[ch._get_key(update)] == 1
|
||||
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||
dp.process_update(update)
|
||||
assert ch._get_key(update) not in ch.conversations
|
||||
assert nested_ch._get_key(update) not in nested_ch.conversations
|
||||
update.message.text = '/start'
|
||||
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
|
||||
dp.process_update(update)
|
||||
assert ch.conversations[ch._get_key(update)] == 1
|
||||
assert ch.conversations == pickle_persistence.conversations['name2']
|
||||
assert nested_ch._get_key(update) not in nested_ch.conversations
|
||||
dp.process_update(update)
|
||||
assert ch.conversations[ch._get_key(update)] == 1
|
||||
assert ch.conversations == pickle_persistence.conversations['name2']
|
||||
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||
assert nested_ch.conversations == pickle_persistence.conversations['name3']
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
|
@ -836,6 +887,7 @@ def bot_data_json(bot_data):
|
|||
@pytest.fixture(scope='function')
|
||||
def conversations_json(conversations):
|
||||
return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2":
|
||||
{"[123, 321]": 1, "[890, 890]": 2}, "name3":
|
||||
{"[123, 321]": 1, "[890, 890]": 2}}"""
|
||||
|
||||
|
||||
|
@ -964,8 +1016,8 @@ class TestDictPersistence(object):
|
|||
assert dict_persistence.bot_data_json == json.dumps(bot_data_two)
|
||||
|
||||
conversations_two = conversations.copy()
|
||||
conversations_two.update({'name3': {(1, 2): 3}})
|
||||
dict_persistence.update_conversation('name3', (1, 2), 3)
|
||||
conversations_two.update({'name4': {(1, 2): 3}})
|
||||
dict_persistence.update_conversation('name4', (1, 2), 3)
|
||||
assert dict_persistence.conversations == conversations_two
|
||||
assert dict_persistence.conversations_json != conversations_json
|
||||
assert dict_persistence.conversations_json == encode_conversations_to_json(
|
||||
|
@ -1046,3 +1098,53 @@ class TestDictPersistence(object):
|
|||
dp.process_update(update)
|
||||
assert ch.conversations[ch._get_key(update)] == 0
|
||||
assert ch.conversations == dict_persistence.conversations['name2']
|
||||
|
||||
def test_with_nested_conversationHandler(self, dp, update, conversations_json):
|
||||
dict_persistence = DictPersistence(conversations_json=conversations_json)
|
||||
dp.persistence = dict_persistence
|
||||
dp.use_context = True
|
||||
NEXT2, NEXT3 = range(1, 3)
|
||||
|
||||
def start(update, context):
|
||||
return NEXT2
|
||||
|
||||
start = CommandHandler('start', start)
|
||||
|
||||
def next(update, context):
|
||||
return NEXT2
|
||||
|
||||
next = MessageHandler(None, next)
|
||||
|
||||
def next2(update, context):
|
||||
return ConversationHandler.END
|
||||
|
||||
next2 = MessageHandler(None, next2)
|
||||
|
||||
nested_ch = ConversationHandler(
|
||||
[next],
|
||||
{NEXT2: [next2]},
|
||||
[],
|
||||
name='name3',
|
||||
persistent=True,
|
||||
map_to_parent={ConversationHandler.END: ConversationHandler.END},
|
||||
)
|
||||
|
||||
ch = ConversationHandler([start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2',
|
||||
persistent=True)
|
||||
dp.add_handler(ch)
|
||||
assert ch.conversations[ch._get_key(update)] == 1
|
||||
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||
dp.process_update(update)
|
||||
assert ch._get_key(update) not in ch.conversations
|
||||
assert nested_ch._get_key(update) not in nested_ch.conversations
|
||||
update.message.text = '/start'
|
||||
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
|
||||
dp.process_update(update)
|
||||
assert ch.conversations[ch._get_key(update)] == 1
|
||||
assert ch.conversations == dict_persistence.conversations['name2']
|
||||
assert nested_ch._get_key(update) not in nested_ch.conversations
|
||||
dp.process_update(update)
|
||||
assert ch.conversations[ch._get_key(update)] == 1
|
||||
assert ch.conversations == dict_persistence.conversations['name2']
|
||||
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||
assert nested_ch.conversations == dict_persistence.conversations['name3']
|
||||
|
|
Loading…
Reference in a new issue