Add a caption_entity filter for filtering caption entities (#1068)

* Add a caption_entity filter for filtering caption entities

* remove unneeded list comprehensions
This commit is contained in:
Paul Larsen 2018-04-20 12:24:40 +01:00 committed by Jannes Höke
parent 39c679e519
commit b5196f00b2
2 changed files with 37 additions and 1 deletions

View file

@ -541,7 +541,28 @@ class Filters(object):
self.name = 'Filters.entity({})'.format(self.entity_type)
def filter(self, message):
return any([entity.type == self.entity_type for entity in message.entities])
return any(entity.type == self.entity_type for entity in message.entities)
class caption_entity(BaseFilter):
"""
Filters media messages to only allow those which have a :class:`telegram.MessageEntity`
where their `type` matches `entity_type`.
Examples:
Example ``MessageHandler(Filters.caption_entity("hashtag"), callback_method)``
Args:
entity_type: Caption Entity type to check for. All types can be found as constants
in :class:`telegram.MessageEntity`.
"""
def __init__(self, entity_type):
self.entity_type = entity_type
self.name = 'Filters.caption_entity({})'.format(self.entity_type)
def filter(self, message):
return any(entity.type == self.entity_type for entity in message.caption_entities)
class _Private(BaseFilter):
name = 'Filters.private'

View file

@ -310,6 +310,21 @@ class TestFilters(object):
second = MessageEntity.de_json(second, None)
message.entities = [message_entity, second]
assert Filters.entity(message_entity.type)(message)
assert not Filters.caption_entity(message_entity.type)(message)
def test_caption_entities_filter(self, message, message_entity):
message.caption_entities = [message_entity]
assert Filters.caption_entity(message_entity.type)(message)
message.caption_entities = []
assert not Filters.caption_entity(MessageEntity.MENTION)(message)
second = message_entity.to_dict()
second['type'] = 'bold'
second = MessageEntity.de_json(second, None)
message.caption_entities = [message_entity, second]
assert Filters.caption_entity(message_entity.type)(message)
assert not Filters.entity(message_entity.type)(message)
def test_private_filter(self, message):
assert Filters.private(message)