mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-22 14:35:00 +01:00
Add filters.Mention
(#3941)
Co-authored-by: Javohir Elmurodov <elmurodovjavohir@gmail.com>
This commit is contained in:
parent
075f517458
commit
300ec920a1
2 changed files with 131 additions and 1 deletions
|
@ -66,6 +66,7 @@ __all__ = (
|
|||
"LOCATION",
|
||||
"Language",
|
||||
"MessageFilter",
|
||||
"Mention",
|
||||
"PASSPORT_DATA",
|
||||
"PHOTO",
|
||||
"POLL",
|
||||
|
@ -91,7 +92,6 @@ __all__ = (
|
|||
"VOICE",
|
||||
"ViaBot",
|
||||
)
|
||||
|
||||
import mimetypes
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
|
@ -99,6 +99,7 @@ from typing import (
|
|||
Collection,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
Match,
|
||||
NoReturn,
|
||||
|
@ -1521,6 +1522,73 @@ LOCATION = _Location(name="filters.LOCATION")
|
|||
"""Messages that contain :attr:`telegram.Message.location`."""
|
||||
|
||||
|
||||
class Mention(MessageFilter):
|
||||
"""Messages containing mentions of specified users or chats.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
MessageHandler(filters.Mention("username"), callback)
|
||||
MessageHandler(filters.Mention(["@username", 123456]), callback)
|
||||
|
||||
.. versionadded:: NEXT.VERSION
|
||||
|
||||
Args:
|
||||
mentions (:obj:`int` | :obj:`str` | :class:`telegram.User` | Collection[:obj:`int` | \
|
||||
:obj:`str` | :class:`telegram.User`]):
|
||||
Specifies the users and chats to filter for. Messages that do not mention at least one
|
||||
of the specified users or chats will not be handled. Leading ``'@'`` s in usernames
|
||||
will be discarded.
|
||||
"""
|
||||
|
||||
__slots__ = ("_mentions",)
|
||||
|
||||
def __init__(self, mentions: SCT[Union[int, str, TGUser]]):
|
||||
super().__init__(name=f"filters.Mention({mentions})")
|
||||
if isinstance(mentions, Iterable) and not isinstance(mentions, str):
|
||||
self._mentions = {self._fix_mention_username(mention) for mention in mentions}
|
||||
else:
|
||||
self._mentions = {self._fix_mention_username(mentions)}
|
||||
|
||||
@staticmethod
|
||||
def _fix_mention_username(mention: Union[int, str, TGUser]) -> Union[int, str, TGUser]:
|
||||
if not isinstance(mention, str):
|
||||
return mention
|
||||
return mention.lstrip("@")
|
||||
|
||||
@classmethod
|
||||
def _check_mention(cls, message: Message, mention: Union[int, str, TGUser]) -> bool:
|
||||
if not message.entities:
|
||||
return False
|
||||
|
||||
entity_texts = message.parse_entities(
|
||||
types=[MessageEntity.MENTION, MessageEntity.TEXT_MENTION]
|
||||
)
|
||||
|
||||
if isinstance(mention, TGUser):
|
||||
return any(
|
||||
mention.id == entity.user.id
|
||||
or mention.username == entity.user.username
|
||||
or mention.username == cls._fix_mention_username(entity_texts[entity])
|
||||
for entity in message.entities
|
||||
if entity.user
|
||||
) or any(
|
||||
mention.username == cls._fix_mention_username(entity_text)
|
||||
for entity_text in entity_texts.values()
|
||||
)
|
||||
if isinstance(mention, int):
|
||||
return bool(
|
||||
any(mention == entity.user.id for entity in message.entities if entity.user)
|
||||
)
|
||||
return any(
|
||||
mention == cls._fix_mention_username(entity_text)
|
||||
for entity_text in entity_texts.values()
|
||||
)
|
||||
|
||||
def filter(self, message: Message) -> bool:
|
||||
return any(self._check_mention(message, mention) for mention in self._mentions)
|
||||
|
||||
|
||||
class _PassportData(MessageFilter):
|
||||
__slots__ = ()
|
||||
|
||||
|
|
|
@ -2423,3 +2423,65 @@ class TestFilters:
|
|||
),
|
||||
)
|
||||
assert filters.ATTACHMENT.check_update(up)
|
||||
|
||||
def test_filters_mention_no_entities(self, update):
|
||||
update.message.text = "test"
|
||||
assert not filters.Mention("@test").check_update(update)
|
||||
assert not filters.Mention(123456).check_update(update)
|
||||
assert not filters.Mention("123456").check_update(update)
|
||||
assert not filters.Mention(User(1, "first_name", False)).check_update(update)
|
||||
assert not filters.Mention(
|
||||
["@test", 123456, "123456", User(1, "first_name", False)]
|
||||
).check_update(update)
|
||||
|
||||
def test_filters_mention_type_mention(self, update):
|
||||
update.message.text = "@test1 @test2 user"
|
||||
update.message.entities = [
|
||||
MessageEntity(MessageEntity.MENTION, 0, 6),
|
||||
MessageEntity(MessageEntity.MENTION, 7, 6),
|
||||
]
|
||||
|
||||
user_no_username = User(123456, "first_name", False)
|
||||
user_wrong_username = User(123456, "first_name", False, username="wrong")
|
||||
user_1 = User(111, "first_name", False, username="test1")
|
||||
user_2 = User(222, "first_name", False, username="test2")
|
||||
|
||||
for username in ("@test1", "@test2"):
|
||||
assert filters.Mention(username).check_update(update)
|
||||
assert filters.Mention({username}).check_update(update)
|
||||
|
||||
for user in (user_1, user_2):
|
||||
assert filters.Mention(user).check_update(update)
|
||||
assert filters.Mention({user}).check_update(update)
|
||||
|
||||
assert not filters.Mention(
|
||||
["@test3", 123, user_no_username, user_wrong_username]
|
||||
).check_update(update)
|
||||
|
||||
def test_filters_mention_type_text_mention(self, update):
|
||||
user_1 = User(111, "first_name", False, username="test1")
|
||||
user_2 = User(222, "first_name", False, username="test2")
|
||||
user_no_username = User(123456, "first_name", False)
|
||||
user_wrong_username = User(123456, "first_name", False, username="wrong")
|
||||
|
||||
update.message.text = "test1 test2 user"
|
||||
update.message.entities = [
|
||||
MessageEntity(MessageEntity.TEXT_MENTION, 0, 5, user=user_1),
|
||||
MessageEntity(MessageEntity.TEXT_MENTION, 6, 5, user=user_2),
|
||||
]
|
||||
|
||||
for username in ("@test1", "@test2"):
|
||||
assert filters.Mention(username).check_update(update)
|
||||
assert filters.Mention({username}).check_update(update)
|
||||
|
||||
for user in (user_1, user_2):
|
||||
assert filters.Mention(user).check_update(update)
|
||||
assert filters.Mention({user}).check_update(update)
|
||||
|
||||
for user_id in (111, 222):
|
||||
assert filters.Mention(user_id).check_update(update)
|
||||
assert filters.Mention({user_id}).check_update(update)
|
||||
|
||||
assert not filters.Mention(
|
||||
["@test3", 123, user_no_username, user_wrong_username]
|
||||
).check_update(update)
|
||||
|
|
Loading…
Reference in a new issue