Merge pull request #411 from python-telegram-bot/bitwise-filters

Make filters and/or-able using bitwise operators.
This commit is contained in:
Jacob Bom 2016-10-19 11:40:35 +02:00 committed by GitHub
commit 4e5f4582dd
7 changed files with 303 additions and 100 deletions

View file

@ -0,0 +1,7 @@
telegram.ext.filters module
===========================
.. automodule:: telegram.ext.filters
:members:
:undoc-members:
:show-inheritance:

View file

@ -15,6 +15,7 @@ Submodules
telegram.ext.commandhandler
telegram.ext.inlinequeryhandler
telegram.ext.messagehandler
telegram.ext.filters
telegram.ext.regexhandler
telegram.ext.stringcommandhandler
telegram.ext.stringregexhandler

View file

@ -26,7 +26,8 @@ from .choseninlineresulthandler import ChosenInlineResultHandler
from .commandhandler import CommandHandler
from .handler import Handler
from .inlinequeryhandler import InlineQueryHandler
from .messagehandler import MessageHandler, Filters
from .messagehandler import MessageHandler
from .filters import BaseFilter, Filters
from .regexhandler import RegexHandler
from .stringcommandhandler import StringCommandHandler
from .stringregexhandler import StringRegexHandler
@ -35,5 +36,5 @@ from .conversationhandler import ConversationHandler
__all__ = ('Dispatcher', 'JobQueue', 'Job', 'Updater', 'CallbackQueryHandler',
'ChosenInlineResultHandler', 'CommandHandler', 'Handler', 'InlineQueryHandler',
'MessageHandler', 'Filters', 'RegexHandler', 'StringCommandHandler',
'MessageHandler', 'BaseFilter', 'Filters', 'RegexHandler', 'StringCommandHandler',
'StringRegexHandler', 'TypeHandler', 'ConversationHandler')

209
telegram/ext/filters.py Normal file
View file

@ -0,0 +1,209 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2016
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
""" This module contains the Filters for use with the MessageHandler class """
class BaseFilter(object):
"""Base class for all Message Filters
Subclassing from this class filters to be combined using bitwise operators:
And:
>>> (Filters.text & Filters.entity(MENTION))
Or:
>>> (Filters.audio | Filters.video)
Also works with more than two filters:
>>> (Filters.text & (Filters.entity(URL) | Filters.entity(TEXT_LINK)))
If you want to create your own filters create a class inheriting from this class and implement
a `filter` method that returns a boolean: `True` if the message should be handled, `False`
otherwise. Note that the filters work only as class instances, not actual class objects
(so remember to initialize your filter classes).
"""
def __call__(self, message):
return self.filter(message)
def __and__(self, other):
return MergedFilter(self, and_filter=other)
def __or__(self, other):
return MergedFilter(self, or_filter=other)
def filter(self, message):
raise NotImplementedError
class MergedFilter(BaseFilter):
"""Represents a filter consisting of two other filters.
Args:
base_filter: Filter 1 of the merged filter
and_filter: Optional filter to "and" with base_filter. Mutually exclusive with or_filter.
or_filter: Optional filter to "or" with base_filter. Mutually exclusive with and_filter.
"""
def __init__(self, base_filter, and_filter=None, or_filter=None):
self.base_filter = base_filter
self.and_filter = and_filter
self.or_filter = or_filter
def filter(self, message):
if self.and_filter:
return self.base_filter(message) and self.and_filter(message)
elif self.or_filter:
return self.base_filter(message) or self.or_filter(message)
def __str__(self):
return ("<telegram.ext.filters.MergedFilter consisting of"
" {} {} {}>").format(self.base_filter, "and" if self.and_filter else "or",
self.and_filter or self.or_filter)
__repr__ = __str__
class Filters(object):
"""
Predefined filters for use with the `filter` argument of :class:`telegram.ext.MessageHandler`.
"""
class _All(BaseFilter):
def filter(self, message):
return True
all = _All()
class _Text(BaseFilter):
def filter(self, message):
return bool(message.text and not message.text.startswith('/'))
text = _Text()
class _Command(BaseFilter):
def filter(self, message):
return bool(message.text and message.text.startswith('/'))
command = _Command()
class _Audio(BaseFilter):
def filter(self, message):
return bool(message.audio)
audio = _Audio()
class _Document(BaseFilter):
def filter(self, message):
return bool(message.document)
document = _Document()
class _Photo(BaseFilter):
def filter(self, message):
return bool(message.photo)
photo = _Photo()
class _Sticker(BaseFilter):
def filter(self, message):
return bool(message.sticker)
sticker = _Sticker()
class _Video(BaseFilter):
def filter(self, message):
return bool(message.video)
video = _Video()
class _Voice(BaseFilter):
def filter(self, message):
return bool(message.voice)
voice = _Voice()
class _Contact(BaseFilter):
def filter(self, message):
return bool(message.contact)
contact = _Contact()
class _Location(BaseFilter):
def filter(self, message):
return bool(message.location)
location = _Location()
class _Venue(BaseFilter):
def filter(self, message):
return bool(message.venue)
venue = _Venue()
class _StatusUpdate(BaseFilter):
def filter(self, message):
return bool(message.new_chat_member or message.left_chat_member
or message.new_chat_title or message.new_chat_photo
or message.delete_chat_photo or message.group_chat_created
or message.supergroup_chat_created or message.channel_chat_created
or message.migrate_to_chat_id or message.migrate_from_chat_id
or message.pinned_message)
status_update = _StatusUpdate()
class _Forwarded(BaseFilter):
def filter(self, message):
return bool(message.forward_date)
forwarded = _Forwarded()
class entity(BaseFilter):
"""Filters messages to only allow those which have a :class:`telegram.MessageEntity`
where their `type` matches `entity_type`.
Args:
entity_type: Entity type to check for. All types can be found as constants
in :class:`telegram.MessageEntity`.
Returns: function to use as filter
"""
def __init__(self, entity_type):
self.entity_type = entity_type
def filter(self, message):
return any([entity.type == self.entity_type for entity in message.entities])

View file

@ -17,92 +17,13 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
""" This module contains the MessageHandler class """
import warnings
from .handler import Handler
from telegram import Update
from telegram.utils.deprecate import deprecate
class Filters(object):
"""
Convenient namespace (class) & methods for the filter funcs of the
MessageHandler class.
"""
@staticmethod
def text(message):
return message.text and not message.text.startswith('/')
@staticmethod
def command(message):
return message.text and message.text.startswith('/')
@staticmethod
def audio(message):
return bool(message.audio)
@staticmethod
def document(message):
return bool(message.document)
@staticmethod
def photo(message):
return bool(message.photo)
@staticmethod
def sticker(message):
return bool(message.sticker)
@staticmethod
def video(message):
return bool(message.video)
@staticmethod
def voice(message):
return bool(message.voice)
@staticmethod
def contact(message):
return bool(message.contact)
@staticmethod
def location(message):
return bool(message.location)
@staticmethod
def venue(message):
return bool(message.venue)
@staticmethod
def status_update(message):
return bool(message.new_chat_member or message.left_chat_member or message.new_chat_title
or message.new_chat_photo or message.delete_chat_photo
or message.group_chat_created or message.supergroup_chat_created
or message.channel_chat_created or message.migrate_to_chat_id
or message.migrate_from_chat_id or message.pinned_message)
@staticmethod
def forwarded(message):
return bool(message.forward_date)
@staticmethod
def entity(entity_type):
"""Filters messages to only allow those which have a :class:`telegram.MessageEntity`
where their `type` matches `entity_type`.
Args:
entity_type: Entity type to check for. All types can be found as constants
in :class:`telegram.MessageEntity`.
Returns: function to use as filter
"""
def entities_filter(message):
return any([entity.type == entity_type for entity in message.entities])
return entities_filter
class MessageHandler(Handler):
"""
Handler class to handle telegram messages. Messages are Telegram Updates
@ -110,12 +31,10 @@ class MessageHandler(Handler):
updates.
Args:
filters (list[function]): A list of filter functions. Standard filters
can be found in the Filters class above.
| Each `function` takes ``Update`` as arg and returns ``bool``.
| All messages that match at least one of those filters will be
accepted. If ``bool(filters)`` evaluates to ``False``, messages are
not filtered.
filters (telegram.ext.BaseFilter): A filter inheriting from
:class:`telegram.filters.BaseFilter`. Standard filters can be found in
:class:`telegram.filters.Filters`. Filters can be combined using bitwise
operators (& for and, | for or).
callback (function): A function that takes ``bot, update`` as
positional arguments. It will be called when the ``check_update``
has determined that an update should be processed by this handler.
@ -137,6 +56,13 @@ class MessageHandler(Handler):
self.filters = filters
self.allow_edited = allow_edited
# We put this up here instead of with the rest of checking code
# in check_update since we don't wanna spam a ton
if isinstance(self.filters, list):
warnings.warn('Using a list of filters in MessageHandler is getting '
'deprecated, please use bitwise operators (& and |) '
'instead. More info: https://git.io/vPTbc.')
def check_update(self, update):
if (isinstance(update, Update)
and (update.message or update.edited_message and self.allow_edited)):
@ -146,7 +72,10 @@ class MessageHandler(Handler):
else:
message = update.message or update.edited_message
res = any(func(message) for func in self.filters)
if isinstance(self.filters, list):
res = any(func(message) for func in self.filters)
else:
res = self.filters(message)
else:
res = False

View file

@ -17,7 +17,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""
This module contains a object that represents Tests for MessageHandler.Filters
This module contains a object that represents Tests for Filters for use with MessageHandler
"""
import sys
@ -28,7 +28,7 @@ import functools
sys.path.append('.')
from telegram import Message, User, Chat, MessageEntity
from telegram.ext import Filters
from telegram.ext import Filters, BaseFilter
from tests.base import BaseTest
@ -37,6 +37,7 @@ class FiltersTest(BaseTest, unittest.TestCase):
def setUp(self):
self.message = Message(0, User(0, "Testuser"), datetime.now(), Chat(0, 'private'))
self.e = functools.partial(MessageEntity, offset=0, length=0)
def test_filters_text(self):
self.message.text = 'test'
@ -152,20 +153,76 @@ class FiltersTest(BaseTest, unittest.TestCase):
self.message.pinned_message = None
def test_entities_filter(self):
e = functools.partial(MessageEntity, offset=0, length=0)
self.message.entities = [e(MessageEntity.MENTION)]
self.message.entities = [self.e(MessageEntity.MENTION)]
self.assertTrue(Filters.entity(MessageEntity.MENTION)(self.message))
self.message.entities = []
self.assertFalse(Filters.entity(MessageEntity.MENTION)(self.message))
self.message.entities = [e(MessageEntity.BOLD)]
self.message.entities = [self.e(MessageEntity.BOLD)]
self.assertFalse(Filters.entity(MessageEntity.MENTION)(self.message))
self.message.entities = [e(MessageEntity.BOLD), e(MessageEntity.MENTION)]
self.message.entities = [self.e(MessageEntity.BOLD), self.e(MessageEntity.MENTION)]
self.assertTrue(Filters.entity(MessageEntity.MENTION)(self.message))
def test_and_filters(self):
self.message.text = 'test'
self.message.forward_date = True
self.assertTrue((Filters.text & Filters.forwarded)(self.message))
self.message.text = '/test'
self.assertFalse((Filters.text & Filters.forwarded)(self.message))
self.message.text = 'test'
self.message.forward_date = None
self.assertFalse((Filters.text & Filters.forwarded)(self.message))
self.message.text = 'test'
self.message.forward_date = True
self.message.entities = [self.e(MessageEntity.MENTION)]
self.assertTrue((Filters.text & Filters.forwarded & Filters.entity(MessageEntity.MENTION))(
self.message))
self.message.entities = [self.e(MessageEntity.BOLD)]
self.assertFalse((Filters.text & Filters.forwarded & Filters.entity(MessageEntity.MENTION)
)(self.message))
def test_or_filters(self):
self.message.text = 'test'
self.assertTrue((Filters.text | Filters.status_update)(self.message))
self.message.group_chat_created = True
self.assertTrue((Filters.text | Filters.status_update)(self.message))
self.message.text = None
self.assertTrue((Filters.text | Filters.status_update)(self.message))
self.message.group_chat_created = False
self.assertFalse((Filters.text | Filters.status_update)(self.message))
def test_and_or_filters(self):
self.message.text = 'test'
self.message.forward_date = True
self.assertTrue((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION))
)(self.message))
self.message.forward_date = False
self.assertFalse((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION)
))(self.message))
self.message.entities = [self.e(MessageEntity.MENTION)]
self.assertTrue((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION))
)(self.message))
self.assertRegexpMatches(
str((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION)))),
r"<telegram.ext.filters.MergedFilter consisting of <telegram.ext.filters.(Filters.)?_"
r"Text object at .*?> and <telegram.ext.filters.MergedFilter consisting of "
r"<telegram.ext.filters.(Filters.)?_Forwarded object at .*?> or "
r"<telegram.ext.filters.(Filters.)?entity object at .*?>>>")
def test_faulty_custom_filter(self):
class _CustomFilter(BaseFilter):
pass
custom = _CustomFilter()
with self.assertRaises(NotImplementedError):
(custom & Filters.text)(self.message)
if __name__ == '__main__':
unittest.main()

View file

@ -182,7 +182,7 @@ class UpdaterTest(BaseTest, unittest.TestCase):
self._setup_updater('Test', edited=True)
d = self.updater.dispatcher
from telegram.ext import Filters
handler = MessageHandler([Filters.text], self.telegramHandlerEditedTest, allow_edited=True)
handler = MessageHandler(Filters.text, self.telegramHandlerEditedTest, allow_edited=True)
d.addHandler(handler)
self.updater.start_polling(0.01)
sleep(.1)
@ -190,8 +190,7 @@ class UpdaterTest(BaseTest, unittest.TestCase):
# Remove handler
d.removeHandler(handler)
handler = MessageHandler(
[Filters.text], self.telegramHandlerEditedTest, allow_edited=False)
handler = MessageHandler(Filters.text, self.telegramHandlerEditedTest, allow_edited=False)
d.addHandler(handler)
self.reset()
@ -201,7 +200,7 @@ class UpdaterTest(BaseTest, unittest.TestCase):
def test_addTelegramMessageHandlerMultipleMessages(self):
self._setup_updater('Multiple', 100)
self.updater.dispatcher.add_handler(MessageHandler([], self.telegramHandlerTest))
self.updater.dispatcher.add_handler(MessageHandler(Filters.all, self.telegramHandlerTest))
self.updater.start_polling(0.0)
sleep(2)
self.assertEqual(self.received_message, 'Multiple')