add allow_edited parameter to MessageHandler and CommandHandler

This commit is contained in:
Jannes Höke 2016-05-27 11:06:29 +02:00
parent 9a13de4a96
commit a0bb5730c6
3 changed files with 116 additions and 41 deletions

View file

@ -34,6 +34,8 @@ class CommandHandler(Handler):
callback (function): A function that takes ``bot, update`` as callback (function): A function that takes ``bot, update`` as
positional arguments. It will be called when the ``check_update`` positional arguments. It will be called when the ``check_update``
has determined that an update should be processed by this handler. has determined that an update should be processed by this handler.
allow_edited (Optional[bool]): If the handler should also accept edited messages.
Default is ``False``
pass_args (optional[bool]): If the handler should be passed the pass_args (optional[bool]): If the handler should be passed the
arguments passed to the command as a keyword argument called ` arguments passed to the command as a keyword argument called `
``args``. It will contain a list of strings, which is the text ``args``. It will contain a list of strings, which is the text
@ -43,21 +45,35 @@ class CommandHandler(Handler):
be used to insert updates. Default is ``False`` be used to insert updates. Default is ``False``
""" """
def __init__(self, command, callback, pass_args=False, pass_update_queue=False): def __init__(self,
command,
callback,
allow_edited=False,
pass_args=False,
pass_update_queue=False):
super(CommandHandler, self).__init__(callback, pass_update_queue) super(CommandHandler, self).__init__(callback, pass_update_queue)
self.command = command self.command = command
self.allow_edited = allow_edited
self.pass_args = pass_args self.pass_args = pass_args
def check_update(self, update): def check_update(self, update):
return (isinstance(update, Update) and update.message and update.message.text if (isinstance(update, Update)
and update.message.text.startswith('/') and (update.message or update.edited_message and self.allow_edited)):
and update.message.text[1:].split(' ')[0].split('@')[0] == self.command) message = update.message or update.edited_message
return (message.text and message.text.startswith('/')
and message.text[1:].split(' ')[0].split('@')[0] == self.command)
else:
return False
def handle_update(self, update, dispatcher): def handle_update(self, update, dispatcher):
optional_args = self.collect_optional_args(dispatcher) optional_args = self.collect_optional_args(dispatcher)
message = update.message or update.edited_message
if self.pass_args: if self.pass_args:
optional_args['args'] = update.message.text.split(' ')[1:] optional_args['args'] = message.text.split(' ')[1:]
self.callback(dispatcher.bot, update, **optional_args) self.callback(dispatcher.bot, update, **optional_args)

View file

@ -30,57 +30,56 @@ class Filters(object):
""" """
@staticmethod @staticmethod
def text(update): def text(message):
return update.message.text and not update.message.text.startswith('/') return message.text and not message.text.startswith('/')
@staticmethod @staticmethod
def command(update): def command(message):
return update.message.text and update.message.text.startswith('/') return message.text and message.text.startswith('/')
@staticmethod @staticmethod
def audio(update): def audio(message):
return bool(update.message.audio) return bool(message.audio)
@staticmethod @staticmethod
def document(update): def document(message):
return bool(update.message.document) return bool(message.document)
@staticmethod @staticmethod
def photo(update): def photo(message):
return bool(update.message.photo) return bool(message.photo)
@staticmethod @staticmethod
def sticker(update): def sticker(message):
return bool(update.message.sticker) return bool(message.sticker)
@staticmethod @staticmethod
def video(update): def video(message):
return bool(update.message.video) return bool(message.video)
@staticmethod @staticmethod
def voice(update): def voice(message):
return bool(update.message.voice) return bool(message.voice)
@staticmethod @staticmethod
def contact(update): def contact(message):
return bool(update.message.contact) return bool(message.contact)
@staticmethod @staticmethod
def location(update): def location(message):
return bool(update.message.location) return bool(message.location)
@staticmethod @staticmethod
def venue(update): def venue(message):
return bool(update.message.venue) return bool(message.venue)
@staticmethod @staticmethod
def status_update(update): def status_update(message):
return bool(update.message.new_chat_member or update.message.left_chat_member return bool(message.new_chat_member or message.left_chat_member or message.new_chat_title
or update.message.new_chat_title or update.message.new_chat_photo or message.new_chat_photo or message.delete_chat_photo
or update.message.delete_chat_photo or update.message.group_chat_created or message.group_chat_created or message.supergroup_chat_created
or update.message.supergroup_chat_created or message.channel_chat_created or message.migrate_to_chat_id
or update.message.channel_chat_created or update.message.migrate_to_chat_id or message.migrate_from_chat_id or message.pinned_message)
or update.message.migrate_from_chat_id or update.message.pinned_message)
class MessageHandler(Handler): class MessageHandler(Handler):
@ -99,23 +98,32 @@ class MessageHandler(Handler):
callback (function): A function that takes ``bot, update`` as callback (function): A function that takes ``bot, update`` as
positional arguments. It will be called when the ``check_update`` positional arguments. It will be called when the ``check_update``
has determined that an update should be processed by this handler. has determined that an update should be processed by this handler.
allow_edited (Optional[bool]): If the handler should also accept edited messages.
Default is ``False``
pass_update_queue (optional[bool]): If the handler should be passed the pass_update_queue (optional[bool]): If the handler should be passed the
update queue as a keyword argument called ``update_queue``. It can update queue as a keyword argument called ``update_queue``. It can
be used to insert updates. Default is ``False`` be used to insert updates. Default is ``False``
""" """
def __init__(self, filters, callback, pass_update_queue=False): def __init__(self, filters, callback, allow_edited=False, pass_update_queue=False):
super(MessageHandler, self).__init__(callback, pass_update_queue) super(MessageHandler, self).__init__(callback, pass_update_queue)
self.filters = filters self.filters = filters
self.allow_edited = allow_edited
def check_update(self, update): def check_update(self, update):
if isinstance(update, Update) and update.message: if (isinstance(update, Update)
and (update.message or update.edited_message and self.allow_edited)):
if not self.filters: if not self.filters:
res = True res = True
else: else:
res = any(func(update) for func in self.filters) message = update.message or update.edited_message
res = any(func(message) for func in self.filters)
else: else:
res = False res = False
return res return res
def handle_update(self, update, dispatcher): def handle_update(self, update, dispatcher):

View file

@ -93,6 +93,10 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.received_message = update.message.text self.received_message = update.message.text
self.message_count += 1 self.message_count += 1
def telegramHandlerEditedTest(self, bot, update):
self.received_message = update.edited_message.text
self.message_count += 1
def telegramInlineHandlerTest(self, bot, update): def telegramInlineHandlerTest(self, bot, update):
self.received_message = (update.inline_query, update.chosen_inline_result) self.received_message = (update.inline_query, update.chosen_inline_result)
self.message_count += 1 self.message_count += 1
@ -157,6 +161,28 @@ class UpdaterTest(BaseTest, unittest.TestCase):
sleep(.1) sleep(.1)
self.assertTrue(None is self.received_message) self.assertTrue(None is self.received_message)
def test_editedMessageHandler(self):
self._setup_updater('Test', edited=True)
d = self.updater.dispatcher
from telegram.ext import Filters
handler = MessageHandler([Filters.text], self.telegramHandlerEditedTest, allow_edited=True)
d.addHandler(handler)
self.updater.start_polling(0.01)
sleep(.1)
self.assertEqual(self.received_message, 'Test')
# Remove handler
d.removeHandler(handler)
handler = MessageHandler([Filters.text],
self.telegramHandlerEditedTest,
allow_edited=False)
d.addHandler(handler)
self.reset()
self.updater.bot.send_messages = 1
sleep(.1)
self.assertTrue(None is self.received_message)
def test_addTelegramMessageHandlerMultipleMessages(self): def test_addTelegramMessageHandlerMultipleMessages(self):
self._setup_updater('Multiple', 100) self._setup_updater('Multiple', 100)
self.updater.dispatcher.addHandler(MessageHandler([], self.telegramHandlerTest)) self.updater.dispatcher.addHandler(MessageHandler([], self.telegramHandlerTest))
@ -200,6 +226,25 @@ class UpdaterTest(BaseTest, unittest.TestCase):
sleep(.1) sleep(.1)
self.assertTrue(None is self.received_message) self.assertTrue(None is self.received_message)
def test_editedCommandHandler(self):
self._setup_updater('/test', edited=True)
d = self.updater.dispatcher
handler = CommandHandler('test', self.telegramHandlerEditedTest, allow_edited=True)
d.addHandler(handler)
self.updater.start_polling(0.01)
sleep(.1)
self.assertEqual(self.received_message, '/test')
# Remove handler
d.removeHandler(handler)
handler = CommandHandler('test', self.telegramHandlerEditedTest, allow_edited=False)
d.addHandler(handler)
self.reset()
self.updater.bot.send_messages = 1
sleep(.1)
self.assertTrue(None is self.received_message)
def test_addRemoveStringRegexHandler(self): def test_addRemoveStringRegexHandler(self):
self._setup_updater('', messages=0) self._setup_updater('', messages=0)
d = self.updater.dispatcher d = self.updater.dispatcher
@ -612,7 +657,8 @@ class MockBot(object):
messages=1, messages=1,
raise_error=False, raise_error=False,
bootstrap_retries=None, bootstrap_retries=None,
bootstrap_err=TelegramError('test')): bootstrap_err=TelegramError('test'),
edited=False):
self.text = text self.text = text
self.send_messages = messages self.send_messages = messages
self.raise_error = raise_error self.raise_error = raise_error
@ -620,13 +666,18 @@ class MockBot(object):
self.bootstrap_retries = bootstrap_retries self.bootstrap_retries = bootstrap_retries
self.bootstrap_attempts = 0 self.bootstrap_attempts = 0
self.bootstrap_err = bootstrap_err self.bootstrap_err = bootstrap_err
self.edited = edited
@staticmethod def mockUpdate(self, text):
def mockUpdate(text):
message = Message(0, None, None, None) message = Message(0, None, None, None)
message.text = text message.text = text
update = Update(0) update = Update(0)
if self.edited:
update.edited_message = message
else:
update.message = message update.message = message
return update return update
def setWebhook(self, webhook_url=None, certificate=None): def setWebhook(self, webhook_url=None, certificate=None):