diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index a8fef2e0a..0e15cd0cb 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -66,6 +66,7 @@ __all__ = ( "LOCATION", "Language", "MessageFilter", + "Mention", "PASSPORT_DATA", "PHOTO", "POLL", @@ -91,7 +92,6 @@ __all__ = ( "VOICE", "ViaBot", ) - import mimetypes import re from abc import ABC, abstractmethod @@ -99,6 +99,7 @@ from typing import ( Collection, Dict, FrozenSet, + Iterable, List, Match, NoReturn, @@ -1521,6 +1522,73 @@ LOCATION = _Location(name="filters.LOCATION") """Messages that contain :attr:`telegram.Message.location`.""" +class Mention(MessageFilter): + """Messages containing mentions of specified users or chats. + + Examples: + .. code-block:: python + + MessageHandler(filters.Mention("username"), callback) + MessageHandler(filters.Mention(["@username", 123456]), callback) + + .. versionadded:: NEXT.VERSION + + Args: + mentions (:obj:`int` | :obj:`str` | :class:`telegram.User` | Collection[:obj:`int` | \ + :obj:`str` | :class:`telegram.User`]): + Specifies the users and chats to filter for. Messages that do not mention at least one + of the specified users or chats will not be handled. Leading ``'@'`` s in usernames + will be discarded. + """ + + __slots__ = ("_mentions",) + + def __init__(self, mentions: SCT[Union[int, str, TGUser]]): + super().__init__(name=f"filters.Mention({mentions})") + if isinstance(mentions, Iterable) and not isinstance(mentions, str): + self._mentions = {self._fix_mention_username(mention) for mention in mentions} + else: + self._mentions = {self._fix_mention_username(mentions)} + + @staticmethod + def _fix_mention_username(mention: Union[int, str, TGUser]) -> Union[int, str, TGUser]: + if not isinstance(mention, str): + return mention + return mention.lstrip("@") + + @classmethod + def _check_mention(cls, message: Message, mention: Union[int, str, TGUser]) -> bool: + if not message.entities: + return False + + entity_texts = message.parse_entities( + types=[MessageEntity.MENTION, MessageEntity.TEXT_MENTION] + ) + + if isinstance(mention, TGUser): + return any( + mention.id == entity.user.id + or mention.username == entity.user.username + or mention.username == cls._fix_mention_username(entity_texts[entity]) + for entity in message.entities + if entity.user + ) or any( + mention.username == cls._fix_mention_username(entity_text) + for entity_text in entity_texts.values() + ) + if isinstance(mention, int): + return bool( + any(mention == entity.user.id for entity in message.entities if entity.user) + ) + return any( + mention == cls._fix_mention_username(entity_text) + for entity_text in entity_texts.values() + ) + + def filter(self, message: Message) -> bool: + return any(self._check_mention(message, mention) for mention in self._mentions) + + class _PassportData(MessageFilter): __slots__ = () diff --git a/tests/ext/test_filters.py b/tests/ext/test_filters.py index 760c41c56..f0812db66 100644 --- a/tests/ext/test_filters.py +++ b/tests/ext/test_filters.py @@ -2423,3 +2423,65 @@ class TestFilters: ), ) assert filters.ATTACHMENT.check_update(up) + + def test_filters_mention_no_entities(self, update): + update.message.text = "test" + assert not filters.Mention("@test").check_update(update) + assert not filters.Mention(123456).check_update(update) + assert not filters.Mention("123456").check_update(update) + assert not filters.Mention(User(1, "first_name", False)).check_update(update) + assert not filters.Mention( + ["@test", 123456, "123456", User(1, "first_name", False)] + ).check_update(update) + + def test_filters_mention_type_mention(self, update): + update.message.text = "@test1 @test2 user" + update.message.entities = [ + MessageEntity(MessageEntity.MENTION, 0, 6), + MessageEntity(MessageEntity.MENTION, 7, 6), + ] + + user_no_username = User(123456, "first_name", False) + user_wrong_username = User(123456, "first_name", False, username="wrong") + user_1 = User(111, "first_name", False, username="test1") + user_2 = User(222, "first_name", False, username="test2") + + for username in ("@test1", "@test2"): + assert filters.Mention(username).check_update(update) + assert filters.Mention({username}).check_update(update) + + for user in (user_1, user_2): + assert filters.Mention(user).check_update(update) + assert filters.Mention({user}).check_update(update) + + assert not filters.Mention( + ["@test3", 123, user_no_username, user_wrong_username] + ).check_update(update) + + def test_filters_mention_type_text_mention(self, update): + user_1 = User(111, "first_name", False, username="test1") + user_2 = User(222, "first_name", False, username="test2") + user_no_username = User(123456, "first_name", False) + user_wrong_username = User(123456, "first_name", False, username="wrong") + + update.message.text = "test1 test2 user" + update.message.entities = [ + MessageEntity(MessageEntity.TEXT_MENTION, 0, 5, user=user_1), + MessageEntity(MessageEntity.TEXT_MENTION, 6, 5, user=user_2), + ] + + for username in ("@test1", "@test2"): + assert filters.Mention(username).check_update(update) + assert filters.Mention({username}).check_update(update) + + for user in (user_1, user_2): + assert filters.Mention(user).check_update(update) + assert filters.Mention({user}).check_update(update) + + for user_id in (111, 222): + assert filters.Mention(user_id).check_update(update) + assert filters.Mention({user_id}).check_update(update) + + assert not filters.Mention( + ["@test3", 123, user_no_username, user_wrong_username] + ).check_update(update)