diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index 1942fa859..1cf51ea8e 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -53,6 +53,9 @@ class BaseFilter(object): (so remember to initialize your filter classes). """ + def __init__(self, *, name=None): + self.name = name + def __call__(self, message): return self.filter(message) @@ -65,6 +68,15 @@ class BaseFilter(object): def __invert__(self): return InvertedFilter(self) + def __str__(self): + # Do not rely on classes overwriting __init__ to set a name + # so we can keep backwards compatibility + if not hasattr(self, 'name') or self.name is None: + self.name = self.__class__.__name__ + return self.name + + __repr__ = __str__ + def filter(self, message): raise NotImplementedError @@ -77,13 +89,14 @@ class InvertedFilter(BaseFilter): """ def __init__(self, f): + super().__init__() self.f = f def filter(self, message): return not self.f(message) def __str__(self): - return "".format(self.f) + return "".format(self.f) __repr__ = __str__ @@ -98,6 +111,7 @@ class MergedFilter(BaseFilter): """ def __init__(self, base_filter, and_filter=None, or_filter=None): + super().__init__() self.base_filter = base_filter self.and_filter = and_filter self.or_filter = or_filter @@ -109,9 +123,8 @@ class MergedFilter(BaseFilter): return self.base_filter(message) or self.or_filter(message) def __str__(self): - return ("").format(self.base_filter, "and" if self.and_filter else "or", - self.and_filter or self.or_filter) + return "<{} {} {}>".format(self.base_filter, "and" if self.and_filter else "or", + self.and_filter or self.or_filter) __repr__ = __str__ @@ -126,91 +139,91 @@ class Filters(object): def filter(self, message): return True - all = _All() + all = _All(name='Filters.all') class _Text(BaseFilter): def filter(self, message): return bool(message.text and not message.text.startswith('/')) - text = _Text() + text = _Text(name='Filters.text') class _Command(BaseFilter): def filter(self, message): return bool(message.text and message.text.startswith('/')) + command = _Command(name='Filters.command') + class _Reply(BaseFilter): def filter(self, message): return bool(message.reply_to_message) - reply = _Reply() - - command = _Command() + reply = _Reply(name='Filters.reply') class _Audio(BaseFilter): def filter(self, message): return bool(message.audio) - audio = _Audio() + audio = _Audio(name='Filters.audio') class _Document(BaseFilter): def filter(self, message): return bool(message.document) - document = _Document() + document = _Document(name='Filters.document') class _Photo(BaseFilter): def filter(self, message): return bool(message.photo) - photo = _Photo() + photo = _Photo(name='Filters.photo') class _Sticker(BaseFilter): def filter(self, message): return bool(message.sticker) - sticker = _Sticker() + sticker = _Sticker(name='Filters.sticker') class _Video(BaseFilter): def filter(self, message): return bool(message.video) - video = _Video() + video = _Video(name='Filters.video') class _Voice(BaseFilter): def filter(self, message): return bool(message.voice) - voice = _Voice() + voice = _Voice(name='Filters.voice') class _Contact(BaseFilter): def filter(self, message): return bool(message.contact) - contact = _Contact() + contact = _Contact(name='Filters.contact') class _Location(BaseFilter): def filter(self, message): return bool(message.location) - location = _Location() + location = _Location(name='Filters.location') class _Venue(BaseFilter): def filter(self, message): return bool(message.venue) - venue = _Venue() + venue = _Venue(name='Filters.venue') class _StatusUpdate(BaseFilter): @@ -219,35 +232,35 @@ class Filters(object): def filter(self, message): return bool(message.new_chat_members) - new_chat_members = _NewChatMembers() + new_chat_members = _NewChatMembers(name='Filters.status_update.new_chat_members') class _LeftChatMember(BaseFilter): def filter(self, message): return bool(message.left_chat_member) - left_chat_member = _LeftChatMember() + left_chat_member = _LeftChatMember(name='Filters.status_update.left_chat_member') class _NewChatTitle(BaseFilter): def filter(self, message): return bool(message.new_chat_title) - new_chat_title = _NewChatTitle() + new_chat_title = _NewChatTitle(name='Filters.status_update.new_chat_title') class _NewChatPhoto(BaseFilter): def filter(self, message): return bool(message.new_chat_photo) - new_chat_photo = _NewChatPhoto() + new_chat_photo = _NewChatPhoto(name='Filters.status_update.new_chat_photo') class _DeleteChatPhoto(BaseFilter): def filter(self, message): return bool(message.delete_chat_photo) - delete_chat_photo = _DeleteChatPhoto() + delete_chat_photo = _DeleteChatPhoto(name='Filters.status_update.delete_chat_photo') class _ChatCreated(BaseFilter): @@ -255,21 +268,21 @@ class Filters(object): return bool(message.group_chat_created or message.supergroup_chat_created or message.channel_chat_created) - chat_created = _ChatCreated() + chat_created = _ChatCreated(name='Filters.status_update.chat_created') class _Migrate(BaseFilter): def filter(self, message): return bool(message.migrate_from_chat_id or message.migrate_to_chat_id) - migrate = _Migrate() + migrate = _Migrate(name='Filters.status_update.migrate') class _PinnedMessage(BaseFilter): def filter(self, message): return bool(message.pinned_message) - pinned_message = _PinnedMessage() + pinned_message = _PinnedMessage(name='Filters.status_update.pinned_message') def filter(self, message): return bool(self.new_chat_members(message) or self.left_chat_member(message) or @@ -277,21 +290,21 @@ class Filters(object): self.delete_chat_photo(message) or self.chat_created(message) or self.migrate(message) or self.pinned_message(message)) - status_update = _StatusUpdate() + status_update = _StatusUpdate(name='Filters.status_update') class _Forwarded(BaseFilter): def filter(self, message): return bool(message.forward_date) - forwarded = _Forwarded() + forwarded = _Forwarded(name='Filters.forwarded') class _Game(BaseFilter): def filter(self, message): return bool(message.game) - game = _Game() + game = _Game(name='Filters.game') class entity(BaseFilter): """Filters messages to only allow those which have a :class:`telegram.MessageEntity` @@ -306,6 +319,7 @@ class Filters(object): def __init__(self, entity_type): self.entity_type = entity_type + super().__init__(name='Filters.entity({})'.format(self.entity_type)) def filter(self, message): return any([entity.type == self.entity_type for entity in message.entities]) @@ -315,28 +329,28 @@ class Filters(object): def filter(self, message): return message.chat.type == Chat.PRIVATE - private = _Private() + private = _Private(name='Filters.private') class _Group(BaseFilter): def filter(self, message): return message.chat.type in [Chat.GROUP, Chat.SUPERGROUP] - group = _Group() + group = _Group(name='Filters.group') class _Invoice(BaseFilter): def filter(self, message): return bool(message.invoice) - invoice = _Invoice() + invoice = _Invoice(name='Filters.invoice') class _SuccessfulPayment(BaseFilter): def filter(self, message): return bool(message.successful_payment) - successful_payment = _SuccessfulPayment() + successful_payment = _SuccessfulPayment(name='Filters.successful_payment') class language(BaseFilter): """ @@ -354,6 +368,7 @@ class Filters(object): self.lang = [lang] else: self.lang = lang + super().__init__(name='Filters.language({})'.format(self.lang)) def filter(self, message): return message.from_user.language_code and any(