diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index 8a950d402..6afa90575 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -215,9 +215,11 @@ class ConversationHandler(Handler): if new_state == self.END: if key in self.conversations: del self.conversations[key] + else: + pass elif isinstance(new_state, Promise): - self.conversations[key] = (self.conversations[key], new_state) + self.conversations[key] = (self.conversations.get(key), new_state) elif new_state is not None: self.conversations[key] = new_state diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py index 6f6cd4ba3..f9ee4623b 100644 --- a/tests/test_conversationhandler.py +++ b/tests/test_conversationhandler.py @@ -36,7 +36,7 @@ except ImportError: sys.path.append('.') from telegram import Update, Message, TelegramError, User, Chat, Bot -from telegram.ext import * +from telegram.ext import Updater, ConversationHandler, CommandHandler from tests.base import BaseTest from tests.test_updater import MockBot @@ -109,6 +109,9 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase): def start(self, bot, update): return self._set_state(update, self.THIRSTY) + def start_end(self, bot, update): + return self._set_state(update, self.END) + def brew(self, bot, update): return self._set_state(update, self.BREWING) @@ -161,6 +164,48 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase): sleep(.1) self.assertRaises(KeyError, self._get_state, user_id=second_user.id) + def test_endOnFirstMessage(self): + self._setup_updater('', messages=0) + d = self.updater.dispatcher + user = User(first_name="Misses Test", id=123) + + handler = ConversationHandler( + entry_points=[CommandHandler('start', self.start_end)], states={}, fallbacks=[]) + d.add_handler(handler) + queue = self.updater.start_polling(0.01) + + # User starts the state machine and immediately ends it. + message = Message(0, user, None, None, text="/start") + queue.put(Update(update_id=0, message=message)) + sleep(.1) + self.assertEquals(len(handler.conversations), 0) + + def test_endOnFirstMessageAsync(self): + self._setup_updater('', messages=0) + d = self.updater.dispatcher + user = User(first_name="Misses Test", id=123) + + start_end_async = (lambda bot, update: d.run_async(self.start_end, bot, update)) + + handler = ConversationHandler( + entry_points=[CommandHandler('start', start_end_async)], states={}, fallbacks=[]) + d.add_handler(handler) + queue = self.updater.start_polling(0.01) + + # User starts the state machine with an async function that immediately ends the + # conversation. Async results are resolved when the users state is queried next time. + message = Message(0, user, None, None, text="/start") + queue.put(Update(update_id=0, message=message)) + sleep(.1) + # Assert that the Promise has been accepted as the new state + self.assertEquals(len(handler.conversations), 1) + + message = Message(0, user, None, None, text="resolve promise pls") + queue.put(Update(update_id=0, message=message)) + sleep(.1) + # Assert that the Promise has been resolved and the conversation ended. + self.assertEquals(len(handler.conversations), 0) + if __name__ == '__main__': unittest.main()