Remove the need for calling super() in filters

This commit is contained in:
Jacob Bom 2017-06-21 13:46:03 +02:00
parent 6cc84b2c32
commit 04acbc4117

View file

@ -51,10 +51,13 @@ class BaseFilter(object):
a `filter` method that returns a boolean: `True` if the message should be handled, `False`
otherwise. Note that the filters work only as class instances, not actual class objects
(so remember to initialize your filter classes).
By default the filters name (what will get printed when converted to a string for display)
will be the class name. If you want to overwrite this assign a better name to the `name`
class variable.
"""
def __init__(self, name=None):
self.name = name
name = None
def __call__(self, message):
return self.filter(message)
@ -69,9 +72,8 @@ class BaseFilter(object):
return InvertedFilter(self)
def __repr__(self):
# Do not rely on classes overwriting __init__ to set a name
# so we can keep backwards compatibility
if not hasattr(self, 'name') or self.name is None:
# We do this here instead of in a __init__ so filter don't have to call __init__ or super()
if self.name is None:
self.name = self.__class__.__name__
return self.name
@ -87,7 +89,6 @@ class InvertedFilter(BaseFilter):
"""
def __init__(self, f):
super(InvertedFilter, self).__init__()
self.f = f
def filter(self, message):
@ -107,7 +108,6 @@ class MergedFilter(BaseFilter):
"""
def __init__(self, base_filter, and_filter=None, or_filter=None):
super(MergedFilter, self).__init__()
self.base_filter = base_filter
self.and_filter = and_filter
self.or_filter = or_filter
@ -129,154 +129,177 @@ class Filters(object):
"""
class _All(BaseFilter):
name = 'Filters.all'
def filter(self, message):
return True
all = _All(name='Filters.all')
all = _All()
class _Text(BaseFilter):
name = 'Filters.text'
def filter(self, message):
return bool(message.text and not message.text.startswith('/'))
text = _Text(name='Filters.text')
text = _Text()
class _Command(BaseFilter):
name = 'Filters.command'
def filter(self, message):
return bool(message.text and message.text.startswith('/'))
command = _Command(name='Filters.command')
command = _Command()
class _Reply(BaseFilter):
name = 'Filters.reply'
def filter(self, message):
return bool(message.reply_to_message)
reply = _Reply(name='Filters.reply')
reply = _Reply()
class _Audio(BaseFilter):
name = 'Filters.audio'
def filter(self, message):
return bool(message.audio)
audio = _Audio(name='Filters.audio')
audio = _Audio()
class _Document(BaseFilter):
name = 'Filters.document'
def filter(self, message):
return bool(message.document)
document = _Document(name='Filters.document')
document = _Document()
class _Photo(BaseFilter):
name = 'Filters.photo'
def filter(self, message):
return bool(message.photo)
photo = _Photo(name='Filters.photo')
photo = _Photo()
class _Sticker(BaseFilter):
name = 'Filters.sticker'
def filter(self, message):
return bool(message.sticker)
sticker = _Sticker(name='Filters.sticker')
sticker = _Sticker()
class _Video(BaseFilter):
name = 'Filters.video'
def filter(self, message):
return bool(message.video)
video = _Video(name='Filters.video')
video = _Video()
class _Voice(BaseFilter):
name = 'Filters.voice'
def filter(self, message):
return bool(message.voice)
voice = _Voice(name='Filters.voice')
voice = _Voice()
class _Contact(BaseFilter):
name = 'Filters.contact'
def filter(self, message):
return bool(message.contact)
contact = _Contact(name='Filters.contact')
contact = _Contact()
class _Location(BaseFilter):
name = 'Filters.location'
def filter(self, message):
return bool(message.location)
location = _Location(name='Filters.location')
location = _Location()
class _Venue(BaseFilter):
name = 'Filters.venue'
def filter(self, message):
return bool(message.venue)
venue = _Venue(name='Filters.venue')
venue = _Venue()
class _StatusUpdate(BaseFilter):
class _NewChatMembers(BaseFilter):
name = 'Filters.status_update.new_chat_members'
def filter(self, message):
return bool(message.new_chat_members)
new_chat_members = _NewChatMembers(name='Filters.status_update.new_chat_members')
new_chat_members = _NewChatMembers()
class _LeftChatMember(BaseFilter):
name = 'Filters.status_update.left_chat_member'
def filter(self, message):
return bool(message.left_chat_member)
left_chat_member = _LeftChatMember(name='Filters.status_update.left_chat_member')
left_chat_member = _LeftChatMember()
class _NewChatTitle(BaseFilter):
name = 'Filters.status_update.new_chat_title'
def filter(self, message):
return bool(message.new_chat_title)
new_chat_title = _NewChatTitle(name='Filters.status_update.new_chat_title')
new_chat_title = _NewChatTitle()
class _NewChatPhoto(BaseFilter):
name = 'Filters.status_update.new_chat_photo'
def filter(self, message):
return bool(message.new_chat_photo)
new_chat_photo = _NewChatPhoto(name='Filters.status_update.new_chat_photo')
new_chat_photo = _NewChatPhoto()
class _DeleteChatPhoto(BaseFilter):
name = 'Filters.status_update.delete_chat_photo'
def filter(self, message):
return bool(message.delete_chat_photo)
delete_chat_photo = _DeleteChatPhoto(name='Filters.status_update.delete_chat_photo')
delete_chat_photo = _DeleteChatPhoto()
class _ChatCreated(BaseFilter):
name = 'Filters.status_update.chat_created'
def filter(self, message):
return bool(message.group_chat_created or message.supergroup_chat_created or
message.channel_chat_created)
chat_created = _ChatCreated(name='Filters.status_update.chat_created')
chat_created = _ChatCreated()
class _Migrate(BaseFilter):
name = 'Filters.status_update.migrate'
def filter(self, message):
return bool(message.migrate_from_chat_id or message.migrate_to_chat_id)
migrate = _Migrate(name='Filters.status_update.migrate')
migrate = _Migrate()
class _PinnedMessage(BaseFilter):
name = 'Filters.status_update.pinned_message'
def filter(self, message):
return bool(message.pinned_message)
pinned_message = _PinnedMessage(name='Filters.status_update.pinned_message')
pinned_message = _PinnedMessage()
name = 'Filters.status_update'
def filter(self, message):
return bool(self.new_chat_members(message) or self.left_chat_member(message) or
@ -284,21 +307,23 @@ class Filters(object):
self.delete_chat_photo(message) or self.chat_created(message) or
self.migrate(message) or self.pinned_message(message))
status_update = _StatusUpdate(name='Filters.status_update')
status_update = _StatusUpdate()
class _Forwarded(BaseFilter):
name = 'Filters.forwarded'
def filter(self, message):
return bool(message.forward_date)
forwarded = _Forwarded(name='Filters.forwarded')
forwarded = _Forwarded()
class _Game(BaseFilter):
name = 'Filters.game'
def filter(self, message):
return bool(message.game)
game = _Game(name='Filters.game')
game = _Game()
class entity(BaseFilter):
"""Filters messages to only allow those which have a :class:`telegram.MessageEntity`
@ -313,39 +338,42 @@ class Filters(object):
def __init__(self, entity_type):
self.entity_type = entity_type
super(Filters.entity, self).__init__(name='Filters.entity({})'.format(
self.entity_type))
self.name = 'Filters.entity({})'.format(self.entity_type)
def filter(self, message):
return any([entity.type == self.entity_type for entity in message.entities])
class _Private(BaseFilter):
name = 'Filters.private'
def filter(self, message):
return message.chat.type == Chat.PRIVATE
private = _Private(name='Filters.private')
private = _Private()
class _Group(BaseFilter):
name = 'Filters.group'
def filter(self, message):
return message.chat.type in [Chat.GROUP, Chat.SUPERGROUP]
group = _Group(name='Filters.group')
group = _Group()
class _Invoice(BaseFilter):
name = 'Filters.invoice'
def filter(self, message):
return bool(message.invoice)
invoice = _Invoice(name='Filters.invoice')
invoice = _Invoice()
class _SuccessfulPayment(BaseFilter):
name = 'Filters.successful_payment'
def filter(self, message):
return bool(message.successful_payment)
successful_payment = _SuccessfulPayment(name='Filters.successful_payment')
successful_payment = _SuccessfulPayment()
class language(BaseFilter):
"""
@ -363,7 +391,7 @@ class Filters(object):
self.lang = [lang]
else:
self.lang = lang
super(Filters.language, self).__init__(name='Filters.language({})'.format(self.lang))
self.name = 'Filters.language({})'.format(self.lang)
def filter(self, message):
return message.from_user.language_code and any(