Merge pull request #675 from python-telegram-bot/name-filters

Allow filters to have a name.
This commit is contained in:
Noam Meltzer 2017-06-21 23:11:26 +03:00 committed by GitHub
commit 3ea16cb1c7
2 changed files with 61 additions and 17 deletions

View file

@ -51,8 +51,14 @@ class BaseFilter(object):
a `filter` method that returns a boolean: `True` if the message should be handled, `False`
otherwise. Note that the filters work only as class instances, not actual class objects
(so remember to initialize your filter classes).
By default the filters name (what will get printed when converted to a string for display)
will be the class name. If you want to overwrite this assign a better name to the `name`
class variable.
"""
name = None
def __call__(self, message):
return self.filter(message)
@ -65,6 +71,12 @@ class BaseFilter(object):
def __invert__(self):
return InvertedFilter(self)
def __repr__(self):
# We do this here instead of in a __init__ so filter don't have to call __init__ or super()
if self.name is None:
self.name = self.__class__.__name__
return self.name
def filter(self, message):
raise NotImplementedError
@ -82,10 +94,8 @@ class InvertedFilter(BaseFilter):
def filter(self, message):
return not self.f(message)
def __str__(self):
return "<telegram.ext.filters.InvertedFilter inverting {}>".format(self.f)
__repr__ = __str__
def __repr__(self):
return "<inverted {}>".format(self.f)
class MergedFilter(BaseFilter):
@ -108,12 +118,9 @@ class MergedFilter(BaseFilter):
elif self.or_filter:
return self.base_filter(message) or self.or_filter(message)
def __str__(self):
return ("<telegram.ext.filters.MergedFilter consisting of"
" {} {} {}>").format(self.base_filter, "and" if self.and_filter else "or",
self.and_filter or self.or_filter)
__repr__ = __str__
def __repr__(self):
return "<{} {} {}>".format(self.base_filter, "and" if self.and_filter else "or",
self.and_filter or self.or_filter)
class Filters(object):
@ -122,6 +129,7 @@ class Filters(object):
"""
class _All(BaseFilter):
name = 'Filters.all'
def filter(self, message):
return True
@ -129,6 +137,7 @@ class Filters(object):
all = _All()
class _Text(BaseFilter):
name = 'Filters.text'
def filter(self, message):
return bool(message.text and not message.text.startswith('/'))
@ -136,20 +145,23 @@ class Filters(object):
text = _Text()
class _Command(BaseFilter):
name = 'Filters.command'
def filter(self, message):
return bool(message.text and message.text.startswith('/'))
command = _Command()
class _Reply(BaseFilter):
name = 'Filters.reply'
def filter(self, message):
return bool(message.reply_to_message)
reply = _Reply()
command = _Command()
class _Audio(BaseFilter):
name = 'Filters.audio'
def filter(self, message):
return bool(message.audio)
@ -157,6 +169,7 @@ class Filters(object):
audio = _Audio()
class _Document(BaseFilter):
name = 'Filters.document'
def filter(self, message):
return bool(message.document)
@ -164,6 +177,7 @@ class Filters(object):
document = _Document()
class _Photo(BaseFilter):
name = 'Filters.photo'
def filter(self, message):
return bool(message.photo)
@ -171,6 +185,7 @@ class Filters(object):
photo = _Photo()
class _Sticker(BaseFilter):
name = 'Filters.sticker'
def filter(self, message):
return bool(message.sticker)
@ -178,6 +193,7 @@ class Filters(object):
sticker = _Sticker()
class _Video(BaseFilter):
name = 'Filters.video'
def filter(self, message):
return bool(message.video)
@ -185,6 +201,7 @@ class Filters(object):
video = _Video()
class _Voice(BaseFilter):
name = 'Filters.voice'
def filter(self, message):
return bool(message.voice)
@ -192,6 +209,7 @@ class Filters(object):
voice = _Voice()
class _Contact(BaseFilter):
name = 'Filters.contact'
def filter(self, message):
return bool(message.contact)
@ -199,6 +217,7 @@ class Filters(object):
contact = _Contact()
class _Location(BaseFilter):
name = 'Filters.location'
def filter(self, message):
return bool(message.location)
@ -206,6 +225,7 @@ class Filters(object):
location = _Location()
class _Venue(BaseFilter):
name = 'Filters.venue'
def filter(self, message):
return bool(message.venue)
@ -215,6 +235,7 @@ class Filters(object):
class _StatusUpdate(BaseFilter):
class _NewChatMembers(BaseFilter):
name = 'Filters.status_update.new_chat_members'
def filter(self, message):
return bool(message.new_chat_members)
@ -222,6 +243,7 @@ class Filters(object):
new_chat_members = _NewChatMembers()
class _LeftChatMember(BaseFilter):
name = 'Filters.status_update.left_chat_member'
def filter(self, message):
return bool(message.left_chat_member)
@ -229,6 +251,7 @@ class Filters(object):
left_chat_member = _LeftChatMember()
class _NewChatTitle(BaseFilter):
name = 'Filters.status_update.new_chat_title'
def filter(self, message):
return bool(message.new_chat_title)
@ -236,6 +259,7 @@ class Filters(object):
new_chat_title = _NewChatTitle()
class _NewChatPhoto(BaseFilter):
name = 'Filters.status_update.new_chat_photo'
def filter(self, message):
return bool(message.new_chat_photo)
@ -243,6 +267,7 @@ class Filters(object):
new_chat_photo = _NewChatPhoto()
class _DeleteChatPhoto(BaseFilter):
name = 'Filters.status_update.delete_chat_photo'
def filter(self, message):
return bool(message.delete_chat_photo)
@ -250,6 +275,7 @@ class Filters(object):
delete_chat_photo = _DeleteChatPhoto()
class _ChatCreated(BaseFilter):
name = 'Filters.status_update.chat_created'
def filter(self, message):
return bool(message.group_chat_created or message.supergroup_chat_created or
@ -258,6 +284,7 @@ class Filters(object):
chat_created = _ChatCreated()
class _Migrate(BaseFilter):
name = 'Filters.status_update.migrate'
def filter(self, message):
return bool(message.migrate_from_chat_id or message.migrate_to_chat_id)
@ -265,12 +292,15 @@ class Filters(object):
migrate = _Migrate()
class _PinnedMessage(BaseFilter):
name = 'Filters.status_update.pinned_message'
def filter(self, message):
return bool(message.pinned_message)
pinned_message = _PinnedMessage()
name = 'Filters.status_update'
def filter(self, message):
return bool(self.new_chat_members(message) or self.left_chat_member(message) or
self.new_chat_title(message) or self.new_chat_photo(message) or
@ -280,6 +310,7 @@ class Filters(object):
status_update = _StatusUpdate()
class _Forwarded(BaseFilter):
name = 'Filters.forwarded'
def filter(self, message):
return bool(message.forward_date)
@ -287,6 +318,7 @@ class Filters(object):
forwarded = _Forwarded()
class _Game(BaseFilter):
name = 'Filters.game'
def filter(self, message):
return bool(message.game)
@ -306,11 +338,13 @@ class Filters(object):
def __init__(self, entity_type):
self.entity_type = entity_type
self.name = 'Filters.entity({})'.format(self.entity_type)
def filter(self, message):
return any([entity.type == self.entity_type for entity in message.entities])
class _Private(BaseFilter):
name = 'Filters.private'
def filter(self, message):
return message.chat.type == Chat.PRIVATE
@ -318,6 +352,7 @@ class Filters(object):
private = _Private()
class _Group(BaseFilter):
name = 'Filters.group'
def filter(self, message):
return message.chat.type in [Chat.GROUP, Chat.SUPERGROUP]
@ -325,6 +360,7 @@ class Filters(object):
group = _Group()
class _Invoice(BaseFilter):
name = 'Filters.invoice'
def filter(self, message):
return bool(message.invoice)
@ -332,6 +368,7 @@ class Filters(object):
invoice = _Invoice()
class _SuccessfulPayment(BaseFilter):
name = 'Filters.successful_payment'
def filter(self, message):
return bool(message.successful_payment)
@ -354,6 +391,7 @@ class Filters(object):
self.lang = [lang]
else:
self.lang = lang
self.name = 'Filters.language({})'.format(self.lang)
def filter(self, message):
return message.from_user.language_code and any(

View file

@ -254,12 +254,10 @@ class FiltersTest(BaseTest, unittest.TestCase):
self.assertTrue((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION))
)(self.message))
self.assertRegexpMatches(
self.assertEqual(
str((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION)))),
r"<telegram.ext.filters.MergedFilter consisting of <telegram.ext.filters.(Filters.)?_"
r"Text object at .*?> and <telegram.ext.filters.MergedFilter consisting of "
r"<telegram.ext.filters.(Filters.)?_Forwarded object at .*?> or "
r"<telegram.ext.filters.(Filters.)?entity object at .*?>>>")
'<Filters.text and <Filters.forwarded or Filters.entity(mention)>>'
)
def test_inverted_filters(self):
self.message.text = '/test'
@ -323,6 +321,14 @@ class FiltersTest(BaseTest, unittest.TestCase):
self.message.from_user.language_code = 'da'
self.assertTrue(f(self.message))
def test_custom_unnamed_filter(self):
class Unnamed(BaseFilter):
def filter(self, message):
return True
unnamed = Unnamed()
self.assertEqual(str(unnamed), Unnamed.__name__)
if __name__ == '__main__':
unittest.main()