mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-01-19 15:43:24 +01:00
Merge pull request #675 from python-telegram-bot/name-filters
Allow filters to have a name.
This commit is contained in:
commit
3ea16cb1c7
2 changed files with 61 additions and 17 deletions
|
@ -51,8 +51,14 @@ class BaseFilter(object):
|
|||
a `filter` method that returns a boolean: `True` if the message should be handled, `False`
|
||||
otherwise. Note that the filters work only as class instances, not actual class objects
|
||||
(so remember to initialize your filter classes).
|
||||
|
||||
By default the filters name (what will get printed when converted to a string for display)
|
||||
will be the class name. If you want to overwrite this assign a better name to the `name`
|
||||
class variable.
|
||||
"""
|
||||
|
||||
name = None
|
||||
|
||||
def __call__(self, message):
|
||||
return self.filter(message)
|
||||
|
||||
|
@ -65,6 +71,12 @@ class BaseFilter(object):
|
|||
def __invert__(self):
|
||||
return InvertedFilter(self)
|
||||
|
||||
def __repr__(self):
|
||||
# We do this here instead of in a __init__ so filter don't have to call __init__ or super()
|
||||
if self.name is None:
|
||||
self.name = self.__class__.__name__
|
||||
return self.name
|
||||
|
||||
def filter(self, message):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -82,10 +94,8 @@ class InvertedFilter(BaseFilter):
|
|||
def filter(self, message):
|
||||
return not self.f(message)
|
||||
|
||||
def __str__(self):
|
||||
return "<telegram.ext.filters.InvertedFilter inverting {}>".format(self.f)
|
||||
|
||||
__repr__ = __str__
|
||||
def __repr__(self):
|
||||
return "<inverted {}>".format(self.f)
|
||||
|
||||
|
||||
class MergedFilter(BaseFilter):
|
||||
|
@ -108,12 +118,9 @@ class MergedFilter(BaseFilter):
|
|||
elif self.or_filter:
|
||||
return self.base_filter(message) or self.or_filter(message)
|
||||
|
||||
def __str__(self):
|
||||
return ("<telegram.ext.filters.MergedFilter consisting of"
|
||||
" {} {} {}>").format(self.base_filter, "and" if self.and_filter else "or",
|
||||
self.and_filter or self.or_filter)
|
||||
|
||||
__repr__ = __str__
|
||||
def __repr__(self):
|
||||
return "<{} {} {}>".format(self.base_filter, "and" if self.and_filter else "or",
|
||||
self.and_filter or self.or_filter)
|
||||
|
||||
|
||||
class Filters(object):
|
||||
|
@ -122,6 +129,7 @@ class Filters(object):
|
|||
"""
|
||||
|
||||
class _All(BaseFilter):
|
||||
name = 'Filters.all'
|
||||
|
||||
def filter(self, message):
|
||||
return True
|
||||
|
@ -129,6 +137,7 @@ class Filters(object):
|
|||
all = _All()
|
||||
|
||||
class _Text(BaseFilter):
|
||||
name = 'Filters.text'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.text and not message.text.startswith('/'))
|
||||
|
@ -136,20 +145,23 @@ class Filters(object):
|
|||
text = _Text()
|
||||
|
||||
class _Command(BaseFilter):
|
||||
name = 'Filters.command'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.text and message.text.startswith('/'))
|
||||
|
||||
command = _Command()
|
||||
|
||||
class _Reply(BaseFilter):
|
||||
name = 'Filters.reply'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.reply_to_message)
|
||||
|
||||
reply = _Reply()
|
||||
|
||||
command = _Command()
|
||||
|
||||
class _Audio(BaseFilter):
|
||||
name = 'Filters.audio'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.audio)
|
||||
|
@ -157,6 +169,7 @@ class Filters(object):
|
|||
audio = _Audio()
|
||||
|
||||
class _Document(BaseFilter):
|
||||
name = 'Filters.document'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.document)
|
||||
|
@ -164,6 +177,7 @@ class Filters(object):
|
|||
document = _Document()
|
||||
|
||||
class _Photo(BaseFilter):
|
||||
name = 'Filters.photo'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.photo)
|
||||
|
@ -171,6 +185,7 @@ class Filters(object):
|
|||
photo = _Photo()
|
||||
|
||||
class _Sticker(BaseFilter):
|
||||
name = 'Filters.sticker'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.sticker)
|
||||
|
@ -178,6 +193,7 @@ class Filters(object):
|
|||
sticker = _Sticker()
|
||||
|
||||
class _Video(BaseFilter):
|
||||
name = 'Filters.video'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.video)
|
||||
|
@ -185,6 +201,7 @@ class Filters(object):
|
|||
video = _Video()
|
||||
|
||||
class _Voice(BaseFilter):
|
||||
name = 'Filters.voice'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.voice)
|
||||
|
@ -192,6 +209,7 @@ class Filters(object):
|
|||
voice = _Voice()
|
||||
|
||||
class _Contact(BaseFilter):
|
||||
name = 'Filters.contact'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.contact)
|
||||
|
@ -199,6 +217,7 @@ class Filters(object):
|
|||
contact = _Contact()
|
||||
|
||||
class _Location(BaseFilter):
|
||||
name = 'Filters.location'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.location)
|
||||
|
@ -206,6 +225,7 @@ class Filters(object):
|
|||
location = _Location()
|
||||
|
||||
class _Venue(BaseFilter):
|
||||
name = 'Filters.venue'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.venue)
|
||||
|
@ -215,6 +235,7 @@ class Filters(object):
|
|||
class _StatusUpdate(BaseFilter):
|
||||
|
||||
class _NewChatMembers(BaseFilter):
|
||||
name = 'Filters.status_update.new_chat_members'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.new_chat_members)
|
||||
|
@ -222,6 +243,7 @@ class Filters(object):
|
|||
new_chat_members = _NewChatMembers()
|
||||
|
||||
class _LeftChatMember(BaseFilter):
|
||||
name = 'Filters.status_update.left_chat_member'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.left_chat_member)
|
||||
|
@ -229,6 +251,7 @@ class Filters(object):
|
|||
left_chat_member = _LeftChatMember()
|
||||
|
||||
class _NewChatTitle(BaseFilter):
|
||||
name = 'Filters.status_update.new_chat_title'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.new_chat_title)
|
||||
|
@ -236,6 +259,7 @@ class Filters(object):
|
|||
new_chat_title = _NewChatTitle()
|
||||
|
||||
class _NewChatPhoto(BaseFilter):
|
||||
name = 'Filters.status_update.new_chat_photo'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.new_chat_photo)
|
||||
|
@ -243,6 +267,7 @@ class Filters(object):
|
|||
new_chat_photo = _NewChatPhoto()
|
||||
|
||||
class _DeleteChatPhoto(BaseFilter):
|
||||
name = 'Filters.status_update.delete_chat_photo'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.delete_chat_photo)
|
||||
|
@ -250,6 +275,7 @@ class Filters(object):
|
|||
delete_chat_photo = _DeleteChatPhoto()
|
||||
|
||||
class _ChatCreated(BaseFilter):
|
||||
name = 'Filters.status_update.chat_created'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.group_chat_created or message.supergroup_chat_created or
|
||||
|
@ -258,6 +284,7 @@ class Filters(object):
|
|||
chat_created = _ChatCreated()
|
||||
|
||||
class _Migrate(BaseFilter):
|
||||
name = 'Filters.status_update.migrate'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.migrate_from_chat_id or message.migrate_to_chat_id)
|
||||
|
@ -265,12 +292,15 @@ class Filters(object):
|
|||
migrate = _Migrate()
|
||||
|
||||
class _PinnedMessage(BaseFilter):
|
||||
name = 'Filters.status_update.pinned_message'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.pinned_message)
|
||||
|
||||
pinned_message = _PinnedMessage()
|
||||
|
||||
name = 'Filters.status_update'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(self.new_chat_members(message) or self.left_chat_member(message) or
|
||||
self.new_chat_title(message) or self.new_chat_photo(message) or
|
||||
|
@ -280,6 +310,7 @@ class Filters(object):
|
|||
status_update = _StatusUpdate()
|
||||
|
||||
class _Forwarded(BaseFilter):
|
||||
name = 'Filters.forwarded'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.forward_date)
|
||||
|
@ -287,6 +318,7 @@ class Filters(object):
|
|||
forwarded = _Forwarded()
|
||||
|
||||
class _Game(BaseFilter):
|
||||
name = 'Filters.game'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.game)
|
||||
|
@ -306,11 +338,13 @@ class Filters(object):
|
|||
|
||||
def __init__(self, entity_type):
|
||||
self.entity_type = entity_type
|
||||
self.name = 'Filters.entity({})'.format(self.entity_type)
|
||||
|
||||
def filter(self, message):
|
||||
return any([entity.type == self.entity_type for entity in message.entities])
|
||||
|
||||
class _Private(BaseFilter):
|
||||
name = 'Filters.private'
|
||||
|
||||
def filter(self, message):
|
||||
return message.chat.type == Chat.PRIVATE
|
||||
|
@ -318,6 +352,7 @@ class Filters(object):
|
|||
private = _Private()
|
||||
|
||||
class _Group(BaseFilter):
|
||||
name = 'Filters.group'
|
||||
|
||||
def filter(self, message):
|
||||
return message.chat.type in [Chat.GROUP, Chat.SUPERGROUP]
|
||||
|
@ -325,6 +360,7 @@ class Filters(object):
|
|||
group = _Group()
|
||||
|
||||
class _Invoice(BaseFilter):
|
||||
name = 'Filters.invoice'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.invoice)
|
||||
|
@ -332,6 +368,7 @@ class Filters(object):
|
|||
invoice = _Invoice()
|
||||
|
||||
class _SuccessfulPayment(BaseFilter):
|
||||
name = 'Filters.successful_payment'
|
||||
|
||||
def filter(self, message):
|
||||
return bool(message.successful_payment)
|
||||
|
@ -354,6 +391,7 @@ class Filters(object):
|
|||
self.lang = [lang]
|
||||
else:
|
||||
self.lang = lang
|
||||
self.name = 'Filters.language({})'.format(self.lang)
|
||||
|
||||
def filter(self, message):
|
||||
return message.from_user.language_code and any(
|
||||
|
|
|
@ -254,12 +254,10 @@ class FiltersTest(BaseTest, unittest.TestCase):
|
|||
self.assertTrue((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION))
|
||||
)(self.message))
|
||||
|
||||
self.assertRegexpMatches(
|
||||
self.assertEqual(
|
||||
str((Filters.text & (Filters.forwarded | Filters.entity(MessageEntity.MENTION)))),
|
||||
r"<telegram.ext.filters.MergedFilter consisting of <telegram.ext.filters.(Filters.)?_"
|
||||
r"Text object at .*?> and <telegram.ext.filters.MergedFilter consisting of "
|
||||
r"<telegram.ext.filters.(Filters.)?_Forwarded object at .*?> or "
|
||||
r"<telegram.ext.filters.(Filters.)?entity object at .*?>>>")
|
||||
'<Filters.text and <Filters.forwarded or Filters.entity(mention)>>'
|
||||
)
|
||||
|
||||
def test_inverted_filters(self):
|
||||
self.message.text = '/test'
|
||||
|
@ -323,6 +321,14 @@ class FiltersTest(BaseTest, unittest.TestCase):
|
|||
self.message.from_user.language_code = 'da'
|
||||
self.assertTrue(f(self.message))
|
||||
|
||||
def test_custom_unnamed_filter(self):
|
||||
class Unnamed(BaseFilter):
|
||||
def filter(self, message):
|
||||
return True
|
||||
|
||||
unnamed = Unnamed()
|
||||
self.assertEqual(str(unnamed), Unnamed.__name__)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in a new issue