Allow filters to have a name.

So their string representation is user friendly.
This commit is contained in:
Jacob Bom 2017-06-18 15:28:48 +02:00
parent faddb92395
commit 4c41f3870c

View file

@ -53,6 +53,9 @@ class BaseFilter(object):
(so remember to initialize your filter classes).
"""
def __init__(self, *, name=None):
self.name = name
def __call__(self, message):
return self.filter(message)
@ -65,6 +68,15 @@ class BaseFilter(object):
def __invert__(self):
return InvertedFilter(self)
def __str__(self):
# Do not rely on classes overwriting __init__ to set a name
# so we can keep backwards compatibility
if not hasattr(self, 'name') or self.name is None:
self.name = self.__class__.__name__
return self.name
__repr__ = __str__
def filter(self, message):
raise NotImplementedError
@ -77,13 +89,14 @@ class InvertedFilter(BaseFilter):
"""
def __init__(self, f):
super().__init__()
self.f = f
def filter(self, message):
return not self.f(message)
def __str__(self):
return "<telegram.ext.filters.InvertedFilter inverting {}>".format(self.f)
return "<inverted {}>".format(self.f)
__repr__ = __str__
@ -98,6 +111,7 @@ class MergedFilter(BaseFilter):
"""
def __init__(self, base_filter, and_filter=None, or_filter=None):
super().__init__()
self.base_filter = base_filter
self.and_filter = and_filter
self.or_filter = or_filter
@ -109,9 +123,8 @@ class MergedFilter(BaseFilter):
return self.base_filter(message) or self.or_filter(message)
def __str__(self):
return ("<telegram.ext.filters.MergedFilter consisting of"
" {} {} {}>").format(self.base_filter, "and" if self.and_filter else "or",
self.and_filter or self.or_filter)
return "<{} {} {}>".format(self.base_filter, "and" if self.and_filter else "or",
self.and_filter or self.or_filter)
__repr__ = __str__
@ -126,91 +139,91 @@ class Filters(object):
def filter(self, message):
return True
all = _All()
all = _All(name='Filters.all')
class _Text(BaseFilter):
def filter(self, message):
return bool(message.text and not message.text.startswith('/'))
text = _Text()
text = _Text(name='Filters.text')
class _Command(BaseFilter):
def filter(self, message):
return bool(message.text and message.text.startswith('/'))
command = _Command(name='Filters.command')
class _Reply(BaseFilter):
def filter(self, message):
return bool(message.reply_to_message)
reply = _Reply()
command = _Command()
reply = _Reply(name='Filters.reply')
class _Audio(BaseFilter):
def filter(self, message):
return bool(message.audio)
audio = _Audio()
audio = _Audio(name='Filters.audio')
class _Document(BaseFilter):
def filter(self, message):
return bool(message.document)
document = _Document()
document = _Document(name='Filters.document')
class _Photo(BaseFilter):
def filter(self, message):
return bool(message.photo)
photo = _Photo()
photo = _Photo(name='Filters.photo')
class _Sticker(BaseFilter):
def filter(self, message):
return bool(message.sticker)
sticker = _Sticker()
sticker = _Sticker(name='Filters.sticker')
class _Video(BaseFilter):
def filter(self, message):
return bool(message.video)
video = _Video()
video = _Video(name='Filters.video')
class _Voice(BaseFilter):
def filter(self, message):
return bool(message.voice)
voice = _Voice()
voice = _Voice(name='Filters.voice')
class _Contact(BaseFilter):
def filter(self, message):
return bool(message.contact)
contact = _Contact()
contact = _Contact(name='Filters.contact')
class _Location(BaseFilter):
def filter(self, message):
return bool(message.location)
location = _Location()
location = _Location(name='Filters.location')
class _Venue(BaseFilter):
def filter(self, message):
return bool(message.venue)
venue = _Venue()
venue = _Venue(name='Filters.venue')
class _StatusUpdate(BaseFilter):
@ -219,35 +232,35 @@ class Filters(object):
def filter(self, message):
return bool(message.new_chat_members)
new_chat_members = _NewChatMembers()
new_chat_members = _NewChatMembers(name='Filters.status_update.new_chat_members')
class _LeftChatMember(BaseFilter):
def filter(self, message):
return bool(message.left_chat_member)
left_chat_member = _LeftChatMember()
left_chat_member = _LeftChatMember(name='Filters.status_update.left_chat_member')
class _NewChatTitle(BaseFilter):
def filter(self, message):
return bool(message.new_chat_title)
new_chat_title = _NewChatTitle()
new_chat_title = _NewChatTitle(name='Filters.status_update.new_chat_title')
class _NewChatPhoto(BaseFilter):
def filter(self, message):
return bool(message.new_chat_photo)
new_chat_photo = _NewChatPhoto()
new_chat_photo = _NewChatPhoto(name='Filters.status_update.new_chat_photo')
class _DeleteChatPhoto(BaseFilter):
def filter(self, message):
return bool(message.delete_chat_photo)
delete_chat_photo = _DeleteChatPhoto()
delete_chat_photo = _DeleteChatPhoto(name='Filters.status_update.delete_chat_photo')
class _ChatCreated(BaseFilter):
@ -255,21 +268,21 @@ class Filters(object):
return bool(message.group_chat_created or message.supergroup_chat_created or
message.channel_chat_created)
chat_created = _ChatCreated()
chat_created = _ChatCreated(name='Filters.status_update.chat_created')
class _Migrate(BaseFilter):
def filter(self, message):
return bool(message.migrate_from_chat_id or message.migrate_to_chat_id)
migrate = _Migrate()
migrate = _Migrate(name='Filters.status_update.migrate')
class _PinnedMessage(BaseFilter):
def filter(self, message):
return bool(message.pinned_message)
pinned_message = _PinnedMessage()
pinned_message = _PinnedMessage(name='Filters.status_update.pinned_message')
def filter(self, message):
return bool(self.new_chat_members(message) or self.left_chat_member(message) or
@ -277,21 +290,21 @@ class Filters(object):
self.delete_chat_photo(message) or self.chat_created(message) or
self.migrate(message) or self.pinned_message(message))
status_update = _StatusUpdate()
status_update = _StatusUpdate(name='Filters.status_update')
class _Forwarded(BaseFilter):
def filter(self, message):
return bool(message.forward_date)
forwarded = _Forwarded()
forwarded = _Forwarded(name='Filters.forwarded')
class _Game(BaseFilter):
def filter(self, message):
return bool(message.game)
game = _Game()
game = _Game(name='Filters.game')
class entity(BaseFilter):
"""Filters messages to only allow those which have a :class:`telegram.MessageEntity`
@ -306,6 +319,7 @@ class Filters(object):
def __init__(self, entity_type):
self.entity_type = entity_type
super().__init__(name='Filters.entity({})'.format(self.entity_type))
def filter(self, message):
return any([entity.type == self.entity_type for entity in message.entities])
@ -315,28 +329,28 @@ class Filters(object):
def filter(self, message):
return message.chat.type == Chat.PRIVATE
private = _Private()
private = _Private(name='Filters.private')
class _Group(BaseFilter):
def filter(self, message):
return message.chat.type in [Chat.GROUP, Chat.SUPERGROUP]
group = _Group()
group = _Group(name='Filters.group')
class _Invoice(BaseFilter):
def filter(self, message):
return bool(message.invoice)
invoice = _Invoice()
invoice = _Invoice(name='Filters.invoice')
class _SuccessfulPayment(BaseFilter):
def filter(self, message):
return bool(message.successful_payment)
successful_payment = _SuccessfulPayment()
successful_payment = _SuccessfulPayment(name='Filters.successful_payment')
class language(BaseFilter):
"""
@ -354,6 +368,7 @@ class Filters(object):
self.lang = [lang]
else:
self.lang = lang
super().__init__(name='Filters.language({})'.format(self.lang))
def filter(self, message):
return message.from_user.language_code and any(