From a0bb5730c653810a2372ee3d27ac2c923da24525 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannes=20H=C3=B6ke?= Date: Fri, 27 May 2016 11:06:29 +0200 Subject: [PATCH] add allow_edited parameter to MessageHandler and CommandHandler --- telegram/ext/commandhandler.py | 26 +++++++++--- telegram/ext/messagehandler.py | 72 +++++++++++++++++++--------------- tests/test_updater.py | 59 ++++++++++++++++++++++++++-- 3 files changed, 116 insertions(+), 41 deletions(-) diff --git a/telegram/ext/commandhandler.py b/telegram/ext/commandhandler.py index 2d22f9221..4c2a98c56 100644 --- a/telegram/ext/commandhandler.py +++ b/telegram/ext/commandhandler.py @@ -34,6 +34,8 @@ class CommandHandler(Handler): callback (function): A function that takes ``bot, update`` as positional arguments. It will be called when the ``check_update`` has determined that an update should be processed by this handler. + allow_edited (Optional[bool]): If the handler should also accept edited messages. + Default is ``False`` pass_args (optional[bool]): If the handler should be passed the arguments passed to the command as a keyword argument called ` ``args``. It will contain a list of strings, which is the text @@ -43,21 +45,35 @@ class CommandHandler(Handler): be used to insert updates. Default is ``False`` """ - def __init__(self, command, callback, pass_args=False, pass_update_queue=False): + def __init__(self, + command, + callback, + allow_edited=False, + pass_args=False, + pass_update_queue=False): super(CommandHandler, self).__init__(callback, pass_update_queue) self.command = command + self.allow_edited = allow_edited self.pass_args = pass_args def check_update(self, update): - return (isinstance(update, Update) and update.message and update.message.text - and update.message.text.startswith('/') - and update.message.text[1:].split(' ')[0].split('@')[0] == self.command) + if (isinstance(update, Update) + and (update.message or update.edited_message and self.allow_edited)): + message = update.message or update.edited_message + + return (message.text and message.text.startswith('/') + and message.text[1:].split(' ')[0].split('@')[0] == self.command) + + else: + return False def handle_update(self, update, dispatcher): optional_args = self.collect_optional_args(dispatcher) + message = update.message or update.edited_message + if self.pass_args: - optional_args['args'] = update.message.text.split(' ')[1:] + optional_args['args'] = message.text.split(' ')[1:] self.callback(dispatcher.bot, update, **optional_args) diff --git a/telegram/ext/messagehandler.py b/telegram/ext/messagehandler.py index fd3248762..ddef6bdb1 100644 --- a/telegram/ext/messagehandler.py +++ b/telegram/ext/messagehandler.py @@ -30,57 +30,56 @@ class Filters(object): """ @staticmethod - def text(update): - return update.message.text and not update.message.text.startswith('/') + def text(message): + return message.text and not message.text.startswith('/') @staticmethod - def command(update): - return update.message.text and update.message.text.startswith('/') + def command(message): + return message.text and message.text.startswith('/') @staticmethod - def audio(update): - return bool(update.message.audio) + def audio(message): + return bool(message.audio) @staticmethod - def document(update): - return bool(update.message.document) + def document(message): + return bool(message.document) @staticmethod - def photo(update): - return bool(update.message.photo) + def photo(message): + return bool(message.photo) @staticmethod - def sticker(update): - return bool(update.message.sticker) + def sticker(message): + return bool(message.sticker) @staticmethod - def video(update): - return bool(update.message.video) + def video(message): + return bool(message.video) @staticmethod - def voice(update): - return bool(update.message.voice) + def voice(message): + return bool(message.voice) @staticmethod - def contact(update): - return bool(update.message.contact) + def contact(message): + return bool(message.contact) @staticmethod - def location(update): - return bool(update.message.location) + def location(message): + return bool(message.location) @staticmethod - def venue(update): - return bool(update.message.venue) + def venue(message): + return bool(message.venue) @staticmethod - def status_update(update): - return bool(update.message.new_chat_member or update.message.left_chat_member - or update.message.new_chat_title or update.message.new_chat_photo - or update.message.delete_chat_photo or update.message.group_chat_created - or update.message.supergroup_chat_created - or update.message.channel_chat_created or update.message.migrate_to_chat_id - or update.message.migrate_from_chat_id or update.message.pinned_message) + def status_update(message): + return bool(message.new_chat_member or message.left_chat_member or message.new_chat_title + or message.new_chat_photo or message.delete_chat_photo + or message.group_chat_created or message.supergroup_chat_created + or message.channel_chat_created or message.migrate_to_chat_id + or message.migrate_from_chat_id or message.pinned_message) class MessageHandler(Handler): @@ -99,23 +98,32 @@ class MessageHandler(Handler): callback (function): A function that takes ``bot, update`` as positional arguments. It will be called when the ``check_update`` has determined that an update should be processed by this handler. + allow_edited (Optional[bool]): If the handler should also accept edited messages. + Default is ``False`` pass_update_queue (optional[bool]): If the handler should be passed the update queue as a keyword argument called ``update_queue``. It can be used to insert updates. Default is ``False`` """ - def __init__(self, filters, callback, pass_update_queue=False): + def __init__(self, filters, callback, allow_edited=False, pass_update_queue=False): super(MessageHandler, self).__init__(callback, pass_update_queue) self.filters = filters + self.allow_edited = allow_edited def check_update(self, update): - if isinstance(update, Update) and update.message: + if (isinstance(update, Update) + and (update.message or update.edited_message and self.allow_edited)): + if not self.filters: res = True + else: - res = any(func(update) for func in self.filters) + message = update.message or update.edited_message + res = any(func(message) for func in self.filters) + else: res = False + return res def handle_update(self, update, dispatcher): diff --git a/tests/test_updater.py b/tests/test_updater.py index ab12d9bda..192416ba1 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -93,6 +93,10 @@ class UpdaterTest(BaseTest, unittest.TestCase): self.received_message = update.message.text self.message_count += 1 + def telegramHandlerEditedTest(self, bot, update): + self.received_message = update.edited_message.text + self.message_count += 1 + def telegramInlineHandlerTest(self, bot, update): self.received_message = (update.inline_query, update.chosen_inline_result) self.message_count += 1 @@ -157,6 +161,28 @@ class UpdaterTest(BaseTest, unittest.TestCase): sleep(.1) self.assertTrue(None is self.received_message) + def test_editedMessageHandler(self): + self._setup_updater('Test', edited=True) + d = self.updater.dispatcher + from telegram.ext import Filters + handler = MessageHandler([Filters.text], self.telegramHandlerEditedTest, allow_edited=True) + d.addHandler(handler) + self.updater.start_polling(0.01) + sleep(.1) + self.assertEqual(self.received_message, 'Test') + + # Remove handler + d.removeHandler(handler) + handler = MessageHandler([Filters.text], + self.telegramHandlerEditedTest, + allow_edited=False) + d.addHandler(handler) + self.reset() + + self.updater.bot.send_messages = 1 + sleep(.1) + self.assertTrue(None is self.received_message) + def test_addTelegramMessageHandlerMultipleMessages(self): self._setup_updater('Multiple', 100) self.updater.dispatcher.addHandler(MessageHandler([], self.telegramHandlerTest)) @@ -200,6 +226,25 @@ class UpdaterTest(BaseTest, unittest.TestCase): sleep(.1) self.assertTrue(None is self.received_message) + def test_editedCommandHandler(self): + self._setup_updater('/test', edited=True) + d = self.updater.dispatcher + handler = CommandHandler('test', self.telegramHandlerEditedTest, allow_edited=True) + d.addHandler(handler) + self.updater.start_polling(0.01) + sleep(.1) + self.assertEqual(self.received_message, '/test') + + # Remove handler + d.removeHandler(handler) + handler = CommandHandler('test', self.telegramHandlerEditedTest, allow_edited=False) + d.addHandler(handler) + self.reset() + + self.updater.bot.send_messages = 1 + sleep(.1) + self.assertTrue(None is self.received_message) + def test_addRemoveStringRegexHandler(self): self._setup_updater('', messages=0) d = self.updater.dispatcher @@ -612,7 +657,8 @@ class MockBot(object): messages=1, raise_error=False, bootstrap_retries=None, - bootstrap_err=TelegramError('test')): + bootstrap_err=TelegramError('test'), + edited=False): self.text = text self.send_messages = messages self.raise_error = raise_error @@ -620,13 +666,18 @@ class MockBot(object): self.bootstrap_retries = bootstrap_retries self.bootstrap_attempts = 0 self.bootstrap_err = bootstrap_err + self.edited = edited - @staticmethod - def mockUpdate(text): + def mockUpdate(self, text): message = Message(0, None, None, None) message.text = text update = Update(0) - update.message = message + + if self.edited: + update.edited_message = message + else: + update.message = message + return update def setWebhook(self, webhook_url=None, certificate=None):