mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-01-10 12:02:39 +01:00
Issue 502 (#530)
* conversationhandler.py: add per_chat, per_user and per_message
* test_conversationhandler.py: test case per_user=False
* test_conversationhandler.py: add test for callbackqueryhandlers
* ✏️ Fix accidental typo in logging format
This commit is contained in:
parent
853d823964
commit
cc73469dab
2 changed files with 200 additions and 17 deletions
|
@ -21,7 +21,8 @@
|
|||
import logging
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import Handler
|
||||
from telegram.ext import (Handler, CallbackQueryHandler, InlineQueryHandler,
|
||||
ChosenInlineResultHandler)
|
||||
from telegram.utils.promise import Promise
|
||||
|
||||
|
||||
|
@ -87,7 +88,10 @@ class ConversationHandler(Handler):
|
|||
fallbacks,
|
||||
allow_reentry=False,
|
||||
run_async_timeout=None,
|
||||
timed_out_behavior=None):
|
||||
timed_out_behavior=None,
|
||||
per_chat=True,
|
||||
per_user=True,
|
||||
per_message=False):
|
||||
|
||||
self.entry_points = entry_points
|
||||
""":type: list[telegram.ext.Handler]"""
|
||||
|
@ -105,22 +109,66 @@ class ConversationHandler(Handler):
|
|||
""":type: list[telegram.ext.Handler]"""
|
||||
|
||||
self.conversations = dict()
|
||||
""":type: dict[(int, int): str]"""
|
||||
self.per_user = per_user
|
||||
self.per_chat = per_chat
|
||||
self.per_message = per_message
|
||||
""":type: dict[tuple: object]"""
|
||||
|
||||
self.current_conversation = None
|
||||
self.current_handler = None
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
if not any((self.per_user, self.per_chat, self.per_message)):
|
||||
raise ValueError("'per_user', 'per_chat' and 'per_message' can't all be 'False'")
|
||||
|
||||
all_handlers = list()
|
||||
all_handlers.extend(entry_points)
|
||||
all_handlers.extend(fallbacks)
|
||||
|
||||
for state_handlers in states.values():
|
||||
all_handlers.extend(state_handlers)
|
||||
|
||||
if self.per_message:
|
||||
for handler in all_handlers:
|
||||
if not isinstance(handler, CallbackQueryHandler):
|
||||
raise ValueError("If 'per_message=True', all entry points and state handlers"
|
||||
" must be 'CallbackQueryHandler'")
|
||||
else:
|
||||
for handler in all_handlers:
|
||||
if isinstance(handler, CallbackQueryHandler):
|
||||
raise ValueError("If 'per_message=False', 'CallbackQueryHandler' doesn't work")
|
||||
|
||||
if self.per_chat:
|
||||
for handler in all_handlers:
|
||||
if isinstance(handler, (InlineQueryHandler, ChosenInlineResultHandler)):
|
||||
raise ValueError("If 'per_chat=True', 'InlineQueryHandler' doesn't work")
|
||||
|
||||
def _get_key(self, update):
|
||||
chat, user = update.extract_chat_and_user()
|
||||
key = list()
|
||||
|
||||
if self.per_chat:
|
||||
key.append(chat.id)
|
||||
|
||||
if self.per_user:
|
||||
key.append(user.id)
|
||||
|
||||
if self.per_message:
|
||||
key.append(update.callback_query.inline_message_id
|
||||
or update.callback_query.message.message_id)
|
||||
|
||||
return tuple(key)
|
||||
|
||||
def check_update(self, update):
|
||||
|
||||
# Ignore messages in channels
|
||||
if not isinstance(update, Update) or update.channel_post:
|
||||
if (not isinstance(update, Update) or update.channel_post or self.per_chat
|
||||
and (update.inline_query or update.chosen_inline_result) or self.per_message
|
||||
and not update.callback_query):
|
||||
return False
|
||||
|
||||
chat, user = update.extract_chat_and_user()
|
||||
|
||||
key = (chat.id, user.id) if chat else (None, user.id)
|
||||
key = self._get_key(update)
|
||||
state = self.conversations.get(key)
|
||||
|
||||
# Resolve promises
|
||||
|
|
|
@ -35,8 +35,9 @@ except ImportError:
|
|||
|
||||
sys.path.append('.')
|
||||
|
||||
from telegram import Update, Message, TelegramError, User, Chat, Bot
|
||||
from telegram.ext import Updater, ConversationHandler, CommandHandler
|
||||
from telegram import Update, Message, TelegramError, User, Chat, Bot, CallbackQuery
|
||||
from telegram.ext import (Updater, ConversationHandler, CommandHandler, CallbackQueryHandler,
|
||||
InlineQueryHandler)
|
||||
from tests.base import BaseTest
|
||||
from tests.test_updater import MockBot
|
||||
|
||||
|
@ -79,6 +80,12 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
|
|||
}
|
||||
self.fallbacks = [CommandHandler('eat', self.start)]
|
||||
|
||||
self.group = Chat(0, Chat.GROUP)
|
||||
self.second_group = Chat(1, Chat.GROUP)
|
||||
|
||||
def _chat(self, user):
|
||||
return Chat(user.id, Chat.GROUP)
|
||||
|
||||
def _setup_updater(self, *args, **kwargs):
|
||||
self.bot = MockBot(*args, **kwargs)
|
||||
self.updater = Updater(workers=2, bot=self.bot)
|
||||
|
@ -137,36 +144,164 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
|
|||
queue = self.updater.start_polling(0.01)
|
||||
|
||||
# User one, starts the state machine.
|
||||
message = Message(0, user, None, None, text="/start", bot=self.bot)
|
||||
message = Message(0, user, None, self.group, text="/start", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
self.assertTrue(self.current_state[user.id] == self.THIRSTY)
|
||||
|
||||
# The user is thirsty and wants to brew coffee.
|
||||
message = Message(0, user, None, None, text="/brew", bot=self.bot)
|
||||
message = Message(0, user, None, self.group, text="/brew", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
self.assertTrue(self.current_state[user.id] == self.BREWING)
|
||||
|
||||
# Lets see if an invalid command makes sure, no state is changed.
|
||||
message = Message(0, user, None, None, text="/nothing", bot=self.bot)
|
||||
message = Message(0, user, None, self.group, text="/nothing", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
self.assertTrue(self.current_state[user.id] == self.BREWING)
|
||||
|
||||
# Lets see if the state machine still works by pouring coffee.
|
||||
message = Message(0, user, None, None, text="/pourCoffee", bot=self.bot)
|
||||
message = Message(0, user, None, self.group, text="/pourCoffee", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
self.assertTrue(self.current_state[user.id] == self.DRINKING)
|
||||
|
||||
# Let's now verify that for another user, who did not start yet,
|
||||
# the state has not been changed.
|
||||
message = Message(0, second_user, None, None, text="/brew", bot=self.bot)
|
||||
message = Message(0, second_user, None, self.group, text="/brew", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
self.assertRaises(KeyError, self._get_state, user_id=second_user.id)
|
||||
|
||||
def test_addConversationHandlerPerChat(self):
|
||||
self._setup_updater('', messages=0)
|
||||
d = self.updater.dispatcher
|
||||
user = User(first_name="Misses Test", id=123)
|
||||
second_user = User(first_name="Mister Test", id=124)
|
||||
|
||||
handler = ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
states=self.states,
|
||||
fallbacks=self.fallbacks,
|
||||
per_user=False)
|
||||
d.add_handler(handler)
|
||||
queue = self.updater.start_polling(0.01)
|
||||
|
||||
# User one, starts the state machine.
|
||||
message = Message(0, user, None, self.group, text="/start", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
|
||||
# The user is thirsty and wants to brew coffee.
|
||||
message = Message(0, user, None, self.group, text="/brew", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
|
||||
# Let's now verify that for another user, who did not start yet,
|
||||
# the state will be changed because they are in the same group.
|
||||
message = Message(0, second_user, None, self.group, text="/pourCoffee", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
self.assertEquals(handler.conversations[(self.group.id,)], self.DRINKING)
|
||||
|
||||
def test_addConversationHandlerPerUser(self):
|
||||
self._setup_updater('', messages=0)
|
||||
d = self.updater.dispatcher
|
||||
user = User(first_name="Misses Test", id=123)
|
||||
|
||||
handler = ConversationHandler(
|
||||
entry_points=self.entry_points,
|
||||
states=self.states,
|
||||
fallbacks=self.fallbacks,
|
||||
per_chat=False)
|
||||
d.add_handler(handler)
|
||||
queue = self.updater.start_polling(0.01)
|
||||
|
||||
# User one, starts the state machine.
|
||||
message = Message(0, user, None, self.group, text="/start", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
|
||||
# The user is thirsty and wants to brew coffee.
|
||||
message = Message(0, user, None, self.group, text="/brew", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
|
||||
# Let's now verify that for the same user in a different group, the state will still be
|
||||
# updated
|
||||
message = Message(0, user, None, self.second_group, text="/pourCoffee", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
|
||||
self.assertEquals(handler.conversations[(user.id,)], self.DRINKING)
|
||||
|
||||
def test_addConversationHandlerPerMessage(self):
|
||||
self._setup_updater('', messages=0)
|
||||
d = self.updater.dispatcher
|
||||
user = User(first_name="Misses Test", id=123)
|
||||
second_user = User(first_name="Mister Test", id=124)
|
||||
|
||||
def entry(bot, update):
|
||||
return 1
|
||||
|
||||
def one(bot, update):
|
||||
return 2
|
||||
|
||||
def two(bot, update):
|
||||
return ConversationHandler.END
|
||||
|
||||
handler = ConversationHandler(
|
||||
entry_points=[CallbackQueryHandler(entry)],
|
||||
states={1: [CallbackQueryHandler(one)],
|
||||
2: [CallbackQueryHandler(two)]},
|
||||
fallbacks=[],
|
||||
per_message=True)
|
||||
d.add_handler(handler)
|
||||
queue = self.updater.start_polling(0.01)
|
||||
|
||||
# User one, starts the state machine.
|
||||
message = Message(0, user, None, self.group, text="msg w/ inlinekeyboard", bot=self.bot)
|
||||
|
||||
cbq = CallbackQuery(0, user, None, message=message, data='data', bot=self.bot)
|
||||
queue.put(Update(update_id=0, callback_query=cbq))
|
||||
sleep(.1)
|
||||
self.assertEquals(handler.conversations[(self.group.id, user.id, message.message_id)], 1)
|
||||
|
||||
cbq = CallbackQuery(0, user, None, message=message, data='data', bot=self.bot)
|
||||
queue.put(Update(update_id=0, callback_query=cbq))
|
||||
sleep(.1)
|
||||
self.assertEquals(handler.conversations[(self.group.id, user.id, message.message_id)], 2)
|
||||
|
||||
# Let's now verify that for a different user in the same group, the state will not be
|
||||
# updated
|
||||
cbq = CallbackQuery(0, second_user, None, message=message, data='data', bot=self.bot)
|
||||
queue.put(Update(update_id=0, callback_query=cbq))
|
||||
sleep(.1)
|
||||
self.assertEquals(handler.conversations[(self.group.id, user.id, message.message_id)], 2)
|
||||
|
||||
def test_illegal_handlers(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ConversationHandler(
|
||||
entry_points=[CommandHandler('/test', lambda bot, update: None)],
|
||||
states={},
|
||||
fallbacks=[],
|
||||
per_message=True)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
ConversationHandler(
|
||||
entry_points=[CallbackQueryHandler(lambda bot, update: None)],
|
||||
states={},
|
||||
fallbacks=[],
|
||||
per_message=False)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
ConversationHandler(
|
||||
entry_points=[InlineQueryHandler(lambda bot, update: None)],
|
||||
states={},
|
||||
fallbacks=[],
|
||||
per_chat=True)
|
||||
|
||||
def test_endOnFirstMessage(self):
|
||||
self._setup_updater('', messages=0)
|
||||
d = self.updater.dispatcher
|
||||
|
@ -178,7 +313,7 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
|
|||
queue = self.updater.start_polling(0.01)
|
||||
|
||||
# User starts the state machine and immediately ends it.
|
||||
message = Message(0, user, None, None, text="/start")
|
||||
message = Message(0, user, None, self.group, text="/start", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
self.assertEquals(len(handler.conversations), 0)
|
||||
|
@ -197,13 +332,13 @@ class ConversationHandlerTest(BaseTest, unittest.TestCase):
|
|||
|
||||
# User starts the state machine with an async function that immediately ends the
|
||||
# conversation. Async results are resolved when the users state is queried next time.
|
||||
message = Message(0, user, None, None, text="/start", bot=self.bot)
|
||||
message = Message(0, user, None, self.group, text="/start", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
# Assert that the Promise has been accepted as the new state
|
||||
self.assertEquals(len(handler.conversations), 1)
|
||||
|
||||
message = Message(0, user, None, None, text="resolve promise pls", bot=self.bot)
|
||||
message = Message(0, user, None, self.group, text="resolve promise pls", bot=self.bot)
|
||||
queue.put(Update(update_id=0, message=message))
|
||||
sleep(.1)
|
||||
# Assert that the Promise has been resolved and the conversation ended.
|
||||
|
|
Loading…
Reference in a new issue