From b5196f00b23918f7fb6db784cef92ff8eda08afa Mon Sep 17 00:00:00 2001 From: Paul Larsen Date: Fri, 20 Apr 2018 12:24:40 +0100 Subject: [PATCH] Add a caption_entity filter for filtering caption entities (#1068) * Add a caption_entity filter for filtering caption entities * remove unneeded list comprehensions --- telegram/ext/filters.py | 23 ++++++++++++++++++++++- tests/test_filters.py | 15 +++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index 3539914eb..4b3ade78d 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -541,7 +541,28 @@ class Filters(object): self.name = 'Filters.entity({})'.format(self.entity_type) def filter(self, message): - return any([entity.type == self.entity_type for entity in message.entities]) + return any(entity.type == self.entity_type for entity in message.entities) + + class caption_entity(BaseFilter): + """ + Filters media messages to only allow those which have a :class:`telegram.MessageEntity` + where their `type` matches `entity_type`. + + Examples: + Example ``MessageHandler(Filters.caption_entity("hashtag"), callback_method)`` + + Args: + entity_type: Caption Entity type to check for. All types can be found as constants + in :class:`telegram.MessageEntity`. + + """ + + def __init__(self, entity_type): + self.entity_type = entity_type + self.name = 'Filters.caption_entity({})'.format(self.entity_type) + + def filter(self, message): + return any(entity.type == self.entity_type for entity in message.caption_entities) class _Private(BaseFilter): name = 'Filters.private' diff --git a/tests/test_filters.py b/tests/test_filters.py index f6be41045..fab1f53c0 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -310,6 +310,21 @@ class TestFilters(object): second = MessageEntity.de_json(second, None) message.entities = [message_entity, second] assert Filters.entity(message_entity.type)(message) + assert not Filters.caption_entity(message_entity.type)(message) + + def test_caption_entities_filter(self, message, message_entity): + message.caption_entities = [message_entity] + assert Filters.caption_entity(message_entity.type)(message) + + message.caption_entities = [] + assert not Filters.caption_entity(MessageEntity.MENTION)(message) + + second = message_entity.to_dict() + second['type'] = 'bold' + second = MessageEntity.de_json(second, None) + message.caption_entities = [message_entity, second] + assert Filters.caption_entity(message_entity.type)(message) + assert not Filters.entity(message_entity.type)(message) def test_private_filter(self, message): assert Filters.private(message)