mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-03-16 12:25:45 +01:00
Make filters and/or-able using bitwise operators.
See associated PR for more info.
This commit is contained in:
parent
5285f63e4a
commit
71e74da0a2
4 changed files with 176 additions and 64 deletions
|
@ -26,7 +26,8 @@ from .choseninlineresulthandler import ChosenInlineResultHandler
|
|||
from .commandhandler import CommandHandler
|
||||
from .handler import Handler
|
||||
from .inlinequeryhandler import InlineQueryHandler
|
||||
from .messagehandler import MessageHandler, Filters
|
||||
from .messagehandler import MessageHandler
|
||||
from .filters import Filters
|
||||
from .regexhandler import RegexHandler
|
||||
from .stringcommandhandler import StringCommandHandler
|
||||
from .stringregexhandler import StringRegexHandler
|
||||
|
|
150
telegram/ext/filters.py
Normal file
150
telegram/ext/filters.py
Normal file
|
@ -0,0 +1,150 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# A library that provides a Python interface to the Telegram Bot API
|
||||
# Copyright (C) 2015-2016
|
||||
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
""" This module contains the MessageHandler class """
|
||||
|
||||
|
||||
class BaseFilter(object):
|
||||
"""Base class for all Message Filters"""
|
||||
|
||||
def __call__(self, message):
|
||||
raise NotImplementedError('Please implement a call method in your filter.')
|
||||
|
||||
def __and__(self, other):
|
||||
return MergedFilter(self, and_filter=other)
|
||||
|
||||
def __or__(self, other):
|
||||
return MergedFilter(self, or_filter=other)
|
||||
|
||||
|
||||
class MergedFilter(BaseFilter):
|
||||
"""Represents a filter consisting of two other filters."""
|
||||
|
||||
def __init__(self, base_filter, and_filter=None, or_filter=None):
|
||||
self.base_filter = base_filter
|
||||
self.and_filter = and_filter
|
||||
self.or_filter = or_filter
|
||||
|
||||
def __call__(self, message):
|
||||
if self.and_filter:
|
||||
return self.base_filter(message) and self.and_filter(message)
|
||||
elif self.or_filter:
|
||||
return self.base_filter(message) or self.or_filter(message)
|
||||
|
||||
|
||||
class Filters(object):
|
||||
"""
|
||||
Convenient namespace (class) & methods for the filter funcs of the
|
||||
MessageHandler class.
|
||||
"""
|
||||
|
||||
class Text(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.text and not message.text.startswith('/'))
|
||||
|
||||
text = Text()
|
||||
|
||||
class Command(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.text and message.text.startswith('/'))
|
||||
|
||||
command = Command()
|
||||
|
||||
class Audio(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.audio)
|
||||
|
||||
audio = Audio()
|
||||
|
||||
class Document(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.document)
|
||||
|
||||
document = Document()
|
||||
|
||||
class Photo(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.photo)
|
||||
|
||||
photo = Photo()
|
||||
|
||||
class Sticker(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.sticker)
|
||||
|
||||
sticker = Sticker()
|
||||
|
||||
class Video(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.video)
|
||||
|
||||
video = Video()
|
||||
|
||||
class Voice(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.voice)
|
||||
|
||||
voice = Voice()
|
||||
|
||||
class Contact(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.contact)
|
||||
|
||||
contact = Contact()
|
||||
|
||||
class Location(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.location)
|
||||
|
||||
location = Location()
|
||||
|
||||
class Venue(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.venue)
|
||||
|
||||
venue = Venue()
|
||||
|
||||
class StatusUpdate(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.new_chat_member or message.left_chat_member
|
||||
or message.new_chat_title or message.new_chat_photo
|
||||
or message.delete_chat_photo or message.group_chat_created
|
||||
or message.supergroup_chat_created or message.channel_chat_created
|
||||
or message.migrate_to_chat_id or message.migrate_from_chat_id
|
||||
or message.pinned_message)
|
||||
|
||||
status_update = StatusUpdate()
|
||||
|
||||
class Forwarded(BaseFilter):
|
||||
|
||||
def __call__(self, message):
|
||||
return bool(message.forward_date)
|
||||
|
||||
forwarded = Forwarded()
|
|
@ -23,69 +23,6 @@ from telegram import Update
|
|||
from telegram.utils.deprecate import deprecate
|
||||
|
||||
|
||||
class Filters(object):
|
||||
"""
|
||||
Convenient namespace (class) & methods for the filter funcs of the
|
||||
MessageHandler class.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def text(message):
|
||||
return message.text and not message.text.startswith('/')
|
||||
|
||||
@staticmethod
|
||||
def command(message):
|
||||
return message.text and message.text.startswith('/')
|
||||
|
||||
@staticmethod
|
||||
def audio(message):
|
||||
return bool(message.audio)
|
||||
|
||||
@staticmethod
|
||||
def document(message):
|
||||
return bool(message.document)
|
||||
|
||||
@staticmethod
|
||||
def photo(message):
|
||||
return bool(message.photo)
|
||||
|
||||
@staticmethod
|
||||
def sticker(message):
|
||||
return bool(message.sticker)
|
||||
|
||||
@staticmethod
|
||||
def video(message):
|
||||
return bool(message.video)
|
||||
|
||||
@staticmethod
|
||||
def voice(message):
|
||||
return bool(message.voice)
|
||||
|
||||
@staticmethod
|
||||
def contact(message):
|
||||
return bool(message.contact)
|
||||
|
||||
@staticmethod
|
||||
def location(message):
|
||||
return bool(message.location)
|
||||
|
||||
@staticmethod
|
||||
def venue(message):
|
||||
return bool(message.venue)
|
||||
|
||||
@staticmethod
|
||||
def status_update(message):
|
||||
return bool(message.new_chat_member or message.left_chat_member or message.new_chat_title
|
||||
or message.new_chat_photo or message.delete_chat_photo
|
||||
or message.group_chat_created or message.supergroup_chat_created
|
||||
or message.channel_chat_created or message.migrate_to_chat_id
|
||||
or message.migrate_from_chat_id or message.pinned_message)
|
||||
|
||||
@staticmethod
|
||||
def forwarded(message):
|
||||
return bool(message.forward_date)
|
||||
|
||||
|
||||
class MessageHandler(Handler):
|
||||
"""
|
||||
Handler class to handle telegram messages. Messages are Telegram Updates
|
||||
|
|
|
@ -150,6 +150,30 @@ class FiltersTest(BaseTest, unittest.TestCase):
|
|||
self.assertTrue(Filters.status_update(self.message))
|
||||
self.message.pinned_message = None
|
||||
|
||||
def test_and_filters(self):
|
||||
# For now just test with forwarded as that's the only one that makes sense
|
||||
# That'll change when we get a entities filter
|
||||
self.message.text = 'test'
|
||||
self.message.forward_date = True
|
||||
self.assertTrue((Filters.text & Filters.forwarded)(self.message))
|
||||
self.message.text = '/test'
|
||||
self.assertFalse((Filters.text & Filters.forwarded)(self.message))
|
||||
self.message.text = 'test'
|
||||
self.message.forward_date = None
|
||||
self.assertFalse((Filters.text & Filters.forwarded)(self.message))
|
||||
|
||||
def test_or_filters(self):
|
||||
# For now just test with forwarded as that's the only one that makes sense
|
||||
# That'll change when we get a entities filter
|
||||
self.message.text = 'test'
|
||||
self.assertTrue((Filters.text | Filters.status_update)(self.message))
|
||||
self.message.group_chat_created = True
|
||||
self.assertTrue((Filters.text | Filters.status_update)(self.message))
|
||||
self.message.text = None
|
||||
self.assertTrue((Filters.text | Filters.status_update)(self.message))
|
||||
self.message.group_chat_created = False
|
||||
self.assertFalse((Filters.text | Filters.status_update)(self.message))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Add table
Reference in a new issue