Add filters.Mention (#3941)

Co-authored-by: Javohir Elmurodov <elmurodovjavohir@gmail.com>
This commit is contained in:
Bibo-Joshi 2023-10-23 21:11:56 +02:00 committed by GitHub
parent 075f517458
commit 300ec920a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 131 additions and 1 deletions

View file

@ -66,6 +66,7 @@ __all__ = (
"LOCATION",
"Language",
"MessageFilter",
"Mention",
"PASSPORT_DATA",
"PHOTO",
"POLL",
@ -91,7 +92,6 @@ __all__ = (
"VOICE",
"ViaBot",
)
import mimetypes
import re
from abc import ABC, abstractmethod
@ -99,6 +99,7 @@ from typing import (
Collection,
Dict,
FrozenSet,
Iterable,
List,
Match,
NoReturn,
@ -1521,6 +1522,73 @@ LOCATION = _Location(name="filters.LOCATION")
"""Messages that contain :attr:`telegram.Message.location`."""
class Mention(MessageFilter):
"""Messages containing mentions of specified users or chats.
Examples:
.. code-block:: python
MessageHandler(filters.Mention("username"), callback)
MessageHandler(filters.Mention(["@username", 123456]), callback)
.. versionadded:: NEXT.VERSION
Args:
mentions (:obj:`int` | :obj:`str` | :class:`telegram.User` | Collection[:obj:`int` | \
:obj:`str` | :class:`telegram.User`]):
Specifies the users and chats to filter for. Messages that do not mention at least one
of the specified users or chats will not be handled. Leading ``'@'`` s in usernames
will be discarded.
"""
__slots__ = ("_mentions",)
def __init__(self, mentions: SCT[Union[int, str, TGUser]]):
super().__init__(name=f"filters.Mention({mentions})")
if isinstance(mentions, Iterable) and not isinstance(mentions, str):
self._mentions = {self._fix_mention_username(mention) for mention in mentions}
else:
self._mentions = {self._fix_mention_username(mentions)}
@staticmethod
def _fix_mention_username(mention: Union[int, str, TGUser]) -> Union[int, str, TGUser]:
if not isinstance(mention, str):
return mention
return mention.lstrip("@")
@classmethod
def _check_mention(cls, message: Message, mention: Union[int, str, TGUser]) -> bool:
if not message.entities:
return False
entity_texts = message.parse_entities(
types=[MessageEntity.MENTION, MessageEntity.TEXT_MENTION]
)
if isinstance(mention, TGUser):
return any(
mention.id == entity.user.id
or mention.username == entity.user.username
or mention.username == cls._fix_mention_username(entity_texts[entity])
for entity in message.entities
if entity.user
) or any(
mention.username == cls._fix_mention_username(entity_text)
for entity_text in entity_texts.values()
)
if isinstance(mention, int):
return bool(
any(mention == entity.user.id for entity in message.entities if entity.user)
)
return any(
mention == cls._fix_mention_username(entity_text)
for entity_text in entity_texts.values()
)
def filter(self, message: Message) -> bool:
return any(self._check_mention(message, mention) for mention in self._mentions)
class _PassportData(MessageFilter):
__slots__ = ()

View file

@ -2423,3 +2423,65 @@ class TestFilters:
),
)
assert filters.ATTACHMENT.check_update(up)
def test_filters_mention_no_entities(self, update):
update.message.text = "test"
assert not filters.Mention("@test").check_update(update)
assert not filters.Mention(123456).check_update(update)
assert not filters.Mention("123456").check_update(update)
assert not filters.Mention(User(1, "first_name", False)).check_update(update)
assert not filters.Mention(
["@test", 123456, "123456", User(1, "first_name", False)]
).check_update(update)
def test_filters_mention_type_mention(self, update):
update.message.text = "@test1 @test2 user"
update.message.entities = [
MessageEntity(MessageEntity.MENTION, 0, 6),
MessageEntity(MessageEntity.MENTION, 7, 6),
]
user_no_username = User(123456, "first_name", False)
user_wrong_username = User(123456, "first_name", False, username="wrong")
user_1 = User(111, "first_name", False, username="test1")
user_2 = User(222, "first_name", False, username="test2")
for username in ("@test1", "@test2"):
assert filters.Mention(username).check_update(update)
assert filters.Mention({username}).check_update(update)
for user in (user_1, user_2):
assert filters.Mention(user).check_update(update)
assert filters.Mention({user}).check_update(update)
assert not filters.Mention(
["@test3", 123, user_no_username, user_wrong_username]
).check_update(update)
def test_filters_mention_type_text_mention(self, update):
user_1 = User(111, "first_name", False, username="test1")
user_2 = User(222, "first_name", False, username="test2")
user_no_username = User(123456, "first_name", False)
user_wrong_username = User(123456, "first_name", False, username="wrong")
update.message.text = "test1 test2 user"
update.message.entities = [
MessageEntity(MessageEntity.TEXT_MENTION, 0, 5, user=user_1),
MessageEntity(MessageEntity.TEXT_MENTION, 6, 5, user=user_2),
]
for username in ("@test1", "@test2"):
assert filters.Mention(username).check_update(update)
assert filters.Mention({username}).check_update(update)
for user in (user_1, user_2):
assert filters.Mention(user).check_update(update)
assert filters.Mention({user}).check_update(update)
for user_id in (111, 222):
assert filters.Mention(user_id).check_update(update)
assert filters.Mention({user_id}).check_update(update)
assert not filters.Mention(
["@test3", 123, user_no_username, user_wrong_username]
).check_update(update)