Make filters and/or-able using bitwise operators.

See associated PR for more info.
This commit is contained in:
Jacob Bom 2016-09-14 19:29:15 +02:00
parent 5285f63e4a
commit 71e74da0a2
4 changed files with 176 additions and 64 deletions

View file

@ -26,7 +26,8 @@ from .choseninlineresulthandler import ChosenInlineResultHandler
from .commandhandler import CommandHandler
from .handler import Handler
from .inlinequeryhandler import InlineQueryHandler
from .messagehandler import MessageHandler, Filters
from .messagehandler import MessageHandler
from .filters import Filters
from .regexhandler import RegexHandler
from .stringcommandhandler import StringCommandHandler
from .stringregexhandler import StringRegexHandler

150
telegram/ext/filters.py Normal file
View file

@ -0,0 +1,150 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2016
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
""" This module contains the MessageHandler class """
class BaseFilter(object):
"""Base class for all Message Filters"""
def __call__(self, message):
raise NotImplementedError('Please implement a call method in your filter.')
def __and__(self, other):
return MergedFilter(self, and_filter=other)
def __or__(self, other):
return MergedFilter(self, or_filter=other)
class MergedFilter(BaseFilter):
"""Represents a filter consisting of two other filters."""
def __init__(self, base_filter, and_filter=None, or_filter=None):
self.base_filter = base_filter
self.and_filter = and_filter
self.or_filter = or_filter
def __call__(self, message):
if self.and_filter:
return self.base_filter(message) and self.and_filter(message)
elif self.or_filter:
return self.base_filter(message) or self.or_filter(message)
class Filters(object):
"""
Convenient namespace (class) & methods for the filter funcs of the
MessageHandler class.
"""
class Text(BaseFilter):
def __call__(self, message):
return bool(message.text and not message.text.startswith('/'))
text = Text()
class Command(BaseFilter):
def __call__(self, message):
return bool(message.text and message.text.startswith('/'))
command = Command()
class Audio(BaseFilter):
def __call__(self, message):
return bool(message.audio)
audio = Audio()
class Document(BaseFilter):
def __call__(self, message):
return bool(message.document)
document = Document()
class Photo(BaseFilter):
def __call__(self, message):
return bool(message.photo)
photo = Photo()
class Sticker(BaseFilter):
def __call__(self, message):
return bool(message.sticker)
sticker = Sticker()
class Video(BaseFilter):
def __call__(self, message):
return bool(message.video)
video = Video()
class Voice(BaseFilter):
def __call__(self, message):
return bool(message.voice)
voice = Voice()
class Contact(BaseFilter):
def __call__(self, message):
return bool(message.contact)
contact = Contact()
class Location(BaseFilter):
def __call__(self, message):
return bool(message.location)
location = Location()
class Venue(BaseFilter):
def __call__(self, message):
return bool(message.venue)
venue = Venue()
class StatusUpdate(BaseFilter):
def __call__(self, message):
return bool(message.new_chat_member or message.left_chat_member
or message.new_chat_title or message.new_chat_photo
or message.delete_chat_photo or message.group_chat_created
or message.supergroup_chat_created or message.channel_chat_created
or message.migrate_to_chat_id or message.migrate_from_chat_id
or message.pinned_message)
status_update = StatusUpdate()
class Forwarded(BaseFilter):
def __call__(self, message):
return bool(message.forward_date)
forwarded = Forwarded()

View file

@ -23,69 +23,6 @@ from telegram import Update
from telegram.utils.deprecate import deprecate
class Filters(object):
"""
Convenient namespace (class) & methods for the filter funcs of the
MessageHandler class.
"""
@staticmethod
def text(message):
return message.text and not message.text.startswith('/')
@staticmethod
def command(message):
return message.text and message.text.startswith('/')
@staticmethod
def audio(message):
return bool(message.audio)
@staticmethod
def document(message):
return bool(message.document)
@staticmethod
def photo(message):
return bool(message.photo)
@staticmethod
def sticker(message):
return bool(message.sticker)
@staticmethod
def video(message):
return bool(message.video)
@staticmethod
def voice(message):
return bool(message.voice)
@staticmethod
def contact(message):
return bool(message.contact)
@staticmethod
def location(message):
return bool(message.location)
@staticmethod
def venue(message):
return bool(message.venue)
@staticmethod
def status_update(message):
return bool(message.new_chat_member or message.left_chat_member or message.new_chat_title
or message.new_chat_photo or message.delete_chat_photo
or message.group_chat_created or message.supergroup_chat_created
or message.channel_chat_created or message.migrate_to_chat_id
or message.migrate_from_chat_id or message.pinned_message)
@staticmethod
def forwarded(message):
return bool(message.forward_date)
class MessageHandler(Handler):
"""
Handler class to handle telegram messages. Messages are Telegram Updates

View file

@ -150,6 +150,30 @@ class FiltersTest(BaseTest, unittest.TestCase):
self.assertTrue(Filters.status_update(self.message))
self.message.pinned_message = None
def test_and_filters(self):
# For now just test with forwarded as that's the only one that makes sense
# That'll change when we get a entities filter
self.message.text = 'test'
self.message.forward_date = True
self.assertTrue((Filters.text & Filters.forwarded)(self.message))
self.message.text = '/test'
self.assertFalse((Filters.text & Filters.forwarded)(self.message))
self.message.text = 'test'
self.message.forward_date = None
self.assertFalse((Filters.text & Filters.forwarded)(self.message))
def test_or_filters(self):
# For now just test with forwarded as that's the only one that makes sense
# That'll change when we get a entities filter
self.message.text = 'test'
self.assertTrue((Filters.text | Filters.status_update)(self.message))
self.message.group_chat_created = True
self.assertTrue((Filters.text | Filters.status_update)(self.message))
self.message.text = None
self.assertTrue((Filters.text | Filters.status_update)(self.message))
self.message.group_chat_created = False
self.assertFalse((Filters.text | Filters.status_update)(self.message))
if __name__ == '__main__':
unittest.main()