diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index 5ea1ee491..d7b5dba00 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -22,7 +22,7 @@ import re from future.utils import string_types -from telegram import Chat, Update +from telegram import Chat, Update, MessageEntity __all__ = ['Filters', 'BaseFilter', 'InvertedFilter', 'MergedFilter'] @@ -249,10 +249,7 @@ class Filters(object): def __call__(self, update): if isinstance(update, Update): - if self.update_filter: - return self.filter(update) - else: - return self.filter(update.effective_message) + return self.filter(update.effective_message) else: return self._TextIterable(update) @@ -296,10 +293,7 @@ class Filters(object): def __call__(self, update): if isinstance(update, Update): - if self.update_filter: - return self.filter(update) - else: - return self.filter(update.effective_message) + return self.filter(update.effective_message) else: return self._CaptionIterable(update) @@ -321,11 +315,41 @@ class Filters(object): class _Command(BaseFilter): name = 'Filters.command' + class _CommandOnlyStart(BaseFilter): + + def __init__(self, only_start): + self.only_start = only_start + self.name = 'Filters.command({})'.format(only_start) + + def filter(self, message): + return (message.entities + and any([e.type == MessageEntity.BOT_COMMAND for e in message.entities])) + + def __call__(self, update): + if isinstance(update, Update): + return self.filter(update.effective_message) + else: + return self._CommandOnlyStart(update) + def filter(self, message): - return bool(message.text and message.text.startswith('/')) + return (message.entities and message.entities[0].type == MessageEntity.BOT_COMMAND + and message.entities[0].offset == 0) command = _Command() - """Messages starting with ``/``.""" + """ + Messages with a :attr:`telegram.MessageEntity.BOT_COMMAND`. By default only allows + messages `starting` with a bot command. Pass ``False`` to also allow messages that contain a + bot command `anywhere` in the text. + + Examples:: + + MessageHandler(Filters.command, command_at_start_callback) + MessageHandler(Filters.command(False), command_anywhere_callback) + + Args: + update (:obj:`bool`, optional): Whether to only allow messages that `start` with a bot + command. Defaults to ``True``. + """ class regex(BaseFilter): """ diff --git a/tests/test_filters.py b/tests/test_filters.py index e7f1515c5..55ee67f25 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -63,12 +63,23 @@ class TestFilters(object): assert Filters.caption({'test', 'test1'})(update) assert not Filters.caption(['test1', 'test2'])(update) - def test_filters_command(self, update): + def test_filters_command_default(self, update): update.message.text = 'test' assert not Filters.command(update) update.message.text = '/test' + assert not Filters.command(update) + # Only accept commands at the beginning + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 3, 5)] + assert not Filters.command(update) + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] assert Filters.command(update) + def test_filters_command_anywhere(self, update): + update.message.text = 'test /cmd' + assert not (Filters.command(False))(update) + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 5, 4)] + assert (Filters.command(False))(update) + def test_filters_regex(self, update): SRE_TYPE = type(re.match("", "")) update.message.text = '/start deep-linked param' @@ -120,6 +131,7 @@ class TestFilters(object): def test_filters_merged_with_regex(self, update): SRE_TYPE = type(re.match("", "")) update.message.text = '/start deep-linked param' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] result = (Filters.command & Filters.regex(r'linked param'))(update) assert result assert isinstance(result, dict) @@ -216,6 +228,7 @@ class TestFilters(object): result = filter(update) assert not result update.message.text = '/start' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] result = filter(update) assert result assert isinstance(result, bool) @@ -230,6 +243,7 @@ class TestFilters(object): def test_regex_inverted(self, update): update.message.text = '/start deep-linked param' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] filter = ~Filters.regex(r'deep-linked param') result = filter(update) assert not result @@ -243,6 +257,7 @@ class TestFilters(object): result = filter(update) assert not result update.message.text = '/start' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] result = filter(update) assert result update.message.text = '/linked' @@ -251,15 +266,18 @@ class TestFilters(object): filter = (~Filters.regex('linked') | Filters.command) update.message.text = "it's linked" + update.message.entities = [] result = filter(update) assert not result update.message.text = '/start linked' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] result = filter(update) assert result update.message.text = '/start' result = filter(update) assert result update.message.text = 'nothig' + update.message.entities = [] result = filter(update) assert result @@ -664,14 +682,17 @@ class TestFilters(object): def test_inverted_filters(self, update): update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] assert Filters.command(update) assert not (~Filters.command)(update) update.message.text = 'test' + update.message.entities = [] assert not Filters.command(update) assert (~Filters.command)(update) def test_inverted_and_filters(self, update): update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] update.message.forward_date = 1 assert (Filters.forwarded & Filters.command)(update) assert not (~Filters.forwarded & Filters.command)(update) @@ -683,6 +704,7 @@ class TestFilters(object): assert not (Filters.forwarded & ~Filters.command)(update) assert (~(Filters.forwarded & Filters.command))(update) update.message.text = 'test' + update.message.entities = [] assert not (Filters.forwarded & Filters.command)(update) assert not (~Filters.forwarded & Filters.command)(update) assert not (Filters.forwarded & ~Filters.command)(update) @@ -746,6 +768,7 @@ class TestFilters(object): def test_merged_short_circuit_and(self, update): update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] class TestException(Exception): pass @@ -760,6 +783,7 @@ class TestFilters(object): (Filters.command & raising_filter)(update) update.message.text = 'test' + update.message.entities = [] (Filters.command & raising_filter)(update) def test_merged_short_circuit_or(self, update): @@ -778,10 +802,12 @@ class TestFilters(object): (Filters.command | raising_filter)(update) update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] (Filters.command | raising_filter)(update) def test_merged_data_merging_and(self, update): update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] class DataFilter(BaseFilter): data_filter = True @@ -799,6 +825,7 @@ class TestFilters(object): assert result['test'] == ['blah1', 'blah2'] update.message.text = 'test' + update.message.entities = [] result = (Filters.command & DataFilter('blah'))(update) assert not result