add allow_edited parameter to MessageHandler and CommandHandler

This commit is contained in:
Jannes Höke 2016-05-27 11:06:29 +02:00
parent 9a13de4a96
commit a0bb5730c6
3 changed files with 116 additions and 41 deletions

View file

@ -34,6 +34,8 @@ class CommandHandler(Handler):
callback (function): A function that takes ``bot, update`` as
positional arguments. It will be called when the ``check_update``
has determined that an update should be processed by this handler.
allow_edited (Optional[bool]): If the handler should also accept edited messages.
Default is ``False``
pass_args (optional[bool]): If the handler should be passed the
arguments passed to the command as a keyword argument called `
``args``. It will contain a list of strings, which is the text
@ -43,21 +45,35 @@ class CommandHandler(Handler):
be used to insert updates. Default is ``False``
"""
def __init__(self, command, callback, pass_args=False, pass_update_queue=False):
def __init__(self,
command,
callback,
allow_edited=False,
pass_args=False,
pass_update_queue=False):
super(CommandHandler, self).__init__(callback, pass_update_queue)
self.command = command
self.allow_edited = allow_edited
self.pass_args = pass_args
def check_update(self, update):
return (isinstance(update, Update) and update.message and update.message.text
and update.message.text.startswith('/')
and update.message.text[1:].split(' ')[0].split('@')[0] == self.command)
if (isinstance(update, Update)
and (update.message or update.edited_message and self.allow_edited)):
message = update.message or update.edited_message
return (message.text and message.text.startswith('/')
and message.text[1:].split(' ')[0].split('@')[0] == self.command)
else:
return False
def handle_update(self, update, dispatcher):
optional_args = self.collect_optional_args(dispatcher)
message = update.message or update.edited_message
if self.pass_args:
optional_args['args'] = update.message.text.split(' ')[1:]
optional_args['args'] = message.text.split(' ')[1:]
self.callback(dispatcher.bot, update, **optional_args)

View file

@ -30,57 +30,56 @@ class Filters(object):
"""
@staticmethod
def text(update):
return update.message.text and not update.message.text.startswith('/')
def text(message):
return message.text and not message.text.startswith('/')
@staticmethod
def command(update):
return update.message.text and update.message.text.startswith('/')
def command(message):
return message.text and message.text.startswith('/')
@staticmethod
def audio(update):
return bool(update.message.audio)
def audio(message):
return bool(message.audio)
@staticmethod
def document(update):
return bool(update.message.document)
def document(message):
return bool(message.document)
@staticmethod
def photo(update):
return bool(update.message.photo)
def photo(message):
return bool(message.photo)
@staticmethod
def sticker(update):
return bool(update.message.sticker)
def sticker(message):
return bool(message.sticker)
@staticmethod
def video(update):
return bool(update.message.video)
def video(message):
return bool(message.video)
@staticmethod
def voice(update):
return bool(update.message.voice)
def voice(message):
return bool(message.voice)
@staticmethod
def contact(update):
return bool(update.message.contact)
def contact(message):
return bool(message.contact)
@staticmethod
def location(update):
return bool(update.message.location)
def location(message):
return bool(message.location)
@staticmethod
def venue(update):
return bool(update.message.venue)
def venue(message):
return bool(message.venue)
@staticmethod
def status_update(update):
return bool(update.message.new_chat_member or update.message.left_chat_member
or update.message.new_chat_title or update.message.new_chat_photo
or update.message.delete_chat_photo or update.message.group_chat_created
or update.message.supergroup_chat_created
or update.message.channel_chat_created or update.message.migrate_to_chat_id
or update.message.migrate_from_chat_id or update.message.pinned_message)
def status_update(message):
return bool(message.new_chat_member or message.left_chat_member or message.new_chat_title
or message.new_chat_photo or message.delete_chat_photo
or message.group_chat_created or message.supergroup_chat_created
or message.channel_chat_created or message.migrate_to_chat_id
or message.migrate_from_chat_id or message.pinned_message)
class MessageHandler(Handler):
@ -99,23 +98,32 @@ class MessageHandler(Handler):
callback (function): A function that takes ``bot, update`` as
positional arguments. It will be called when the ``check_update``
has determined that an update should be processed by this handler.
allow_edited (Optional[bool]): If the handler should also accept edited messages.
Default is ``False``
pass_update_queue (optional[bool]): If the handler should be passed the
update queue as a keyword argument called ``update_queue``. It can
be used to insert updates. Default is ``False``
"""
def __init__(self, filters, callback, pass_update_queue=False):
def __init__(self, filters, callback, allow_edited=False, pass_update_queue=False):
super(MessageHandler, self).__init__(callback, pass_update_queue)
self.filters = filters
self.allow_edited = allow_edited
def check_update(self, update):
if isinstance(update, Update) and update.message:
if (isinstance(update, Update)
and (update.message or update.edited_message and self.allow_edited)):
if not self.filters:
res = True
else:
res = any(func(update) for func in self.filters)
message = update.message or update.edited_message
res = any(func(message) for func in self.filters)
else:
res = False
return res
def handle_update(self, update, dispatcher):

View file

@ -93,6 +93,10 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self.received_message = update.message.text
self.message_count += 1
def telegramHandlerEditedTest(self, bot, update):
self.received_message = update.edited_message.text
self.message_count += 1
def telegramInlineHandlerTest(self, bot, update):
self.received_message = (update.inline_query, update.chosen_inline_result)
self.message_count += 1
@ -157,6 +161,28 @@ class UpdaterTest(BaseTest, unittest.TestCase):
sleep(.1)
self.assertTrue(None is self.received_message)
def test_editedMessageHandler(self):
self._setup_updater('Test', edited=True)
d = self.updater.dispatcher
from telegram.ext import Filters
handler = MessageHandler([Filters.text], self.telegramHandlerEditedTest, allow_edited=True)
d.addHandler(handler)
self.updater.start_polling(0.01)
sleep(.1)
self.assertEqual(self.received_message, 'Test')
# Remove handler
d.removeHandler(handler)
handler = MessageHandler([Filters.text],
self.telegramHandlerEditedTest,
allow_edited=False)
d.addHandler(handler)
self.reset()
self.updater.bot.send_messages = 1
sleep(.1)
self.assertTrue(None is self.received_message)
def test_addTelegramMessageHandlerMultipleMessages(self):
self._setup_updater('Multiple', 100)
self.updater.dispatcher.addHandler(MessageHandler([], self.telegramHandlerTest))
@ -200,6 +226,25 @@ class UpdaterTest(BaseTest, unittest.TestCase):
sleep(.1)
self.assertTrue(None is self.received_message)
def test_editedCommandHandler(self):
self._setup_updater('/test', edited=True)
d = self.updater.dispatcher
handler = CommandHandler('test', self.telegramHandlerEditedTest, allow_edited=True)
d.addHandler(handler)
self.updater.start_polling(0.01)
sleep(.1)
self.assertEqual(self.received_message, '/test')
# Remove handler
d.removeHandler(handler)
handler = CommandHandler('test', self.telegramHandlerEditedTest, allow_edited=False)
d.addHandler(handler)
self.reset()
self.updater.bot.send_messages = 1
sleep(.1)
self.assertTrue(None is self.received_message)
def test_addRemoveStringRegexHandler(self):
self._setup_updater('', messages=0)
d = self.updater.dispatcher
@ -612,7 +657,8 @@ class MockBot(object):
messages=1,
raise_error=False,
bootstrap_retries=None,
bootstrap_err=TelegramError('test')):
bootstrap_err=TelegramError('test'),
edited=False):
self.text = text
self.send_messages = messages
self.raise_error = raise_error
@ -620,13 +666,18 @@ class MockBot(object):
self.bootstrap_retries = bootstrap_retries
self.bootstrap_attempts = 0
self.bootstrap_err = bootstrap_err
self.edited = edited
@staticmethod
def mockUpdate(text):
def mockUpdate(self, text):
message = Message(0, None, None, None)
message.text = text
update = Update(0)
update.message = message
if self.edited:
update.edited_message = message
else:
update.message = message
return update
def setWebhook(self, webhook_url=None, certificate=None):