mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-01-16 14:33:12 +01:00
Handler persistence for nested ConversationHandlers (#1711)
* Handler persistence for nested ConversationHandlers * Add tests for persistence w/ nested CHs
This commit is contained in:
parent
f6b663f175
commit
6d9d11b8bd
3 changed files with 134 additions and 6 deletions
|
@ -178,14 +178,14 @@ class ConversationHandler(Handler):
|
||||||
if persistent and not self.name:
|
if persistent and not self.name:
|
||||||
raise ValueError("Conversations can't be persistent when handler is unnamed.")
|
raise ValueError("Conversations can't be persistent when handler is unnamed.")
|
||||||
self.persistent = persistent
|
self.persistent = persistent
|
||||||
self.persistence = None
|
self._persistence = None
|
||||||
""":obj:`telegram.ext.BasePersistance`: The persistence used to store conversations.
|
""":obj:`telegram.ext.BasePersistance`: The persistence used to store conversations.
|
||||||
Set by dispatcher"""
|
Set by dispatcher"""
|
||||||
self.map_to_parent = map_to_parent
|
self.map_to_parent = map_to_parent
|
||||||
|
|
||||||
self.timeout_jobs = dict()
|
self.timeout_jobs = dict()
|
||||||
self._timeout_jobs_lock = Lock()
|
self._timeout_jobs_lock = Lock()
|
||||||
self.conversations = dict()
|
self._conversations = dict()
|
||||||
self._conversations_lock = Lock()
|
self._conversations_lock = Lock()
|
||||||
|
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
@ -225,6 +225,32 @@ class ConversationHandler(Handler):
|
||||||
"since inline queries have no chat context.")
|
"since inline queries have no chat context.")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@property
|
||||||
|
def persistence(self):
|
||||||
|
return self._persistence
|
||||||
|
|
||||||
|
@persistence.setter
|
||||||
|
def persistence(self, persistence):
|
||||||
|
self._persistence = persistence
|
||||||
|
# Set persistence for nested conversations
|
||||||
|
for handlers in self.states.values():
|
||||||
|
for handler in handlers:
|
||||||
|
if isinstance(handler, ConversationHandler):
|
||||||
|
handler.persistence = self.persistence
|
||||||
|
|
||||||
|
@property
|
||||||
|
def conversations(self):
|
||||||
|
return self._conversations
|
||||||
|
|
||||||
|
@conversations.setter
|
||||||
|
def conversations(self, value):
|
||||||
|
self._conversations = value
|
||||||
|
# Set conversations for nested conversations
|
||||||
|
for handlers in self.states.values():
|
||||||
|
for handler in handlers:
|
||||||
|
if isinstance(handler, ConversationHandler):
|
||||||
|
handler.conversations = self.persistence.get_conversations(handler.name)
|
||||||
|
|
||||||
def _get_key(self, update):
|
def _get_key(self, update):
|
||||||
chat = update.effective_chat
|
chat = update.effective_chat
|
||||||
user = update.effective_user
|
user = update.effective_user
|
||||||
|
|
|
@ -447,8 +447,8 @@ class Dispatcher(object):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Conversationhandler {} can not be persistent if dispatcher has no "
|
"Conversationhandler {} can not be persistent if dispatcher has no "
|
||||||
"persistence".format(handler.name))
|
"persistence".format(handler.name))
|
||||||
handler.conversations = self.persistence.get_conversations(handler.name)
|
|
||||||
handler.persistence = self.persistence
|
handler.persistence = self.persistence
|
||||||
|
handler.conversations = self.persistence.get_conversations(handler.name)
|
||||||
|
|
||||||
if group not in self.handlers:
|
if group not in self.handlers:
|
||||||
self.handlers[group] = list()
|
self.handlers[group] = list()
|
||||||
|
|
|
@ -59,7 +59,8 @@ def user_data():
|
||||||
@pytest.fixture(scope='function')
|
@pytest.fixture(scope='function')
|
||||||
def conversations():
|
def conversations():
|
||||||
return {'name1': {(123, 123): 3, (456, 654): 4},
|
return {'name1': {(123, 123): 3, (456, 654): 4},
|
||||||
'name2': {(123, 321): 1, (890, 890): 2}}
|
'name2': {(123, 321): 1, (890, 890): 2},
|
||||||
|
'name3': {(123, 321): 1, (890, 890): 2}}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
|
@ -806,6 +807,56 @@ class TestPickelPersistence(object):
|
||||||
assert ch.conversations[ch._get_key(update)] == 0
|
assert ch.conversations[ch._get_key(update)] == 0
|
||||||
assert ch.conversations == pickle_persistence.conversations['name2']
|
assert ch.conversations == pickle_persistence.conversations['name2']
|
||||||
|
|
||||||
|
def test_with_nested_conversationHandler(self, dp, update, good_pickle_files,
|
||||||
|
pickle_persistence):
|
||||||
|
dp.persistence = pickle_persistence
|
||||||
|
dp.use_context = True
|
||||||
|
NEXT2, NEXT3 = range(1, 3)
|
||||||
|
|
||||||
|
def start(update, context):
|
||||||
|
return NEXT2
|
||||||
|
|
||||||
|
start = CommandHandler('start', start)
|
||||||
|
|
||||||
|
def next(update, context):
|
||||||
|
return NEXT2
|
||||||
|
|
||||||
|
next = MessageHandler(None, next)
|
||||||
|
|
||||||
|
def next2(update, context):
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
next2 = MessageHandler(None, next2)
|
||||||
|
|
||||||
|
nested_ch = ConversationHandler(
|
||||||
|
[next],
|
||||||
|
{NEXT2: [next2]},
|
||||||
|
[],
|
||||||
|
name='name3',
|
||||||
|
persistent=True,
|
||||||
|
map_to_parent={ConversationHandler.END: ConversationHandler.END},
|
||||||
|
)
|
||||||
|
|
||||||
|
ch = ConversationHandler([start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2',
|
||||||
|
persistent=True)
|
||||||
|
dp.add_handler(ch)
|
||||||
|
assert ch.conversations[ch._get_key(update)] == 1
|
||||||
|
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||||
|
dp.process_update(update)
|
||||||
|
assert ch._get_key(update) not in ch.conversations
|
||||||
|
assert nested_ch._get_key(update) not in nested_ch.conversations
|
||||||
|
update.message.text = '/start'
|
||||||
|
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
|
||||||
|
dp.process_update(update)
|
||||||
|
assert ch.conversations[ch._get_key(update)] == 1
|
||||||
|
assert ch.conversations == pickle_persistence.conversations['name2']
|
||||||
|
assert nested_ch._get_key(update) not in nested_ch.conversations
|
||||||
|
dp.process_update(update)
|
||||||
|
assert ch.conversations[ch._get_key(update)] == 1
|
||||||
|
assert ch.conversations == pickle_persistence.conversations['name2']
|
||||||
|
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||||
|
assert nested_ch.conversations == pickle_persistence.conversations['name3']
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def teardown_class(cls):
|
def teardown_class(cls):
|
||||||
try:
|
try:
|
||||||
|
@ -836,6 +887,7 @@ def bot_data_json(bot_data):
|
||||||
@pytest.fixture(scope='function')
|
@pytest.fixture(scope='function')
|
||||||
def conversations_json(conversations):
|
def conversations_json(conversations):
|
||||||
return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2":
|
return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2":
|
||||||
|
{"[123, 321]": 1, "[890, 890]": 2}, "name3":
|
||||||
{"[123, 321]": 1, "[890, 890]": 2}}"""
|
{"[123, 321]": 1, "[890, 890]": 2}}"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -964,8 +1016,8 @@ class TestDictPersistence(object):
|
||||||
assert dict_persistence.bot_data_json == json.dumps(bot_data_two)
|
assert dict_persistence.bot_data_json == json.dumps(bot_data_two)
|
||||||
|
|
||||||
conversations_two = conversations.copy()
|
conversations_two = conversations.copy()
|
||||||
conversations_two.update({'name3': {(1, 2): 3}})
|
conversations_two.update({'name4': {(1, 2): 3}})
|
||||||
dict_persistence.update_conversation('name3', (1, 2), 3)
|
dict_persistence.update_conversation('name4', (1, 2), 3)
|
||||||
assert dict_persistence.conversations == conversations_two
|
assert dict_persistence.conversations == conversations_two
|
||||||
assert dict_persistence.conversations_json != conversations_json
|
assert dict_persistence.conversations_json != conversations_json
|
||||||
assert dict_persistence.conversations_json == encode_conversations_to_json(
|
assert dict_persistence.conversations_json == encode_conversations_to_json(
|
||||||
|
@ -1046,3 +1098,53 @@ class TestDictPersistence(object):
|
||||||
dp.process_update(update)
|
dp.process_update(update)
|
||||||
assert ch.conversations[ch._get_key(update)] == 0
|
assert ch.conversations[ch._get_key(update)] == 0
|
||||||
assert ch.conversations == dict_persistence.conversations['name2']
|
assert ch.conversations == dict_persistence.conversations['name2']
|
||||||
|
|
||||||
|
def test_with_nested_conversationHandler(self, dp, update, conversations_json):
|
||||||
|
dict_persistence = DictPersistence(conversations_json=conversations_json)
|
||||||
|
dp.persistence = dict_persistence
|
||||||
|
dp.use_context = True
|
||||||
|
NEXT2, NEXT3 = range(1, 3)
|
||||||
|
|
||||||
|
def start(update, context):
|
||||||
|
return NEXT2
|
||||||
|
|
||||||
|
start = CommandHandler('start', start)
|
||||||
|
|
||||||
|
def next(update, context):
|
||||||
|
return NEXT2
|
||||||
|
|
||||||
|
next = MessageHandler(None, next)
|
||||||
|
|
||||||
|
def next2(update, context):
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
next2 = MessageHandler(None, next2)
|
||||||
|
|
||||||
|
nested_ch = ConversationHandler(
|
||||||
|
[next],
|
||||||
|
{NEXT2: [next2]},
|
||||||
|
[],
|
||||||
|
name='name3',
|
||||||
|
persistent=True,
|
||||||
|
map_to_parent={ConversationHandler.END: ConversationHandler.END},
|
||||||
|
)
|
||||||
|
|
||||||
|
ch = ConversationHandler([start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2',
|
||||||
|
persistent=True)
|
||||||
|
dp.add_handler(ch)
|
||||||
|
assert ch.conversations[ch._get_key(update)] == 1
|
||||||
|
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||||
|
dp.process_update(update)
|
||||||
|
assert ch._get_key(update) not in ch.conversations
|
||||||
|
assert nested_ch._get_key(update) not in nested_ch.conversations
|
||||||
|
update.message.text = '/start'
|
||||||
|
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)]
|
||||||
|
dp.process_update(update)
|
||||||
|
assert ch.conversations[ch._get_key(update)] == 1
|
||||||
|
assert ch.conversations == dict_persistence.conversations['name2']
|
||||||
|
assert nested_ch._get_key(update) not in nested_ch.conversations
|
||||||
|
dp.process_update(update)
|
||||||
|
assert ch.conversations[ch._get_key(update)] == 1
|
||||||
|
assert ch.conversations == dict_persistence.conversations['name2']
|
||||||
|
assert nested_ch.conversations[nested_ch._get_key(update)] == 1
|
||||||
|
assert nested_ch.conversations == dict_persistence.conversations['name3']
|
||||||
|
|
Loading…
Reference in a new issue