mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-03-26 08:32:58 +01:00
Add XOR Filters and make Filters.name a Property (#2179)
* XOR Filters and make Filters.name a property * add XORFilter to __all__ * Change example
This commit is contained in:
parent
27b03edc59
commit
d1438a9b23
2 changed files with 208 additions and 4 deletions
|
@ -24,7 +24,19 @@ import warnings
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Dict, FrozenSet, List, Match, Optional, Pattern, Set, Tuple, Union, cast
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
FrozenSet,
|
||||||
|
List,
|
||||||
|
Match,
|
||||||
|
Optional,
|
||||||
|
Pattern,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
NoReturn,
|
||||||
|
)
|
||||||
|
|
||||||
from telegram import Chat, Message, MessageEntity, Update
|
from telegram import Chat, Message, MessageEntity, Update
|
||||||
|
|
||||||
|
@ -35,6 +47,7 @@ __all__ = [
|
||||||
'UpdateFilter',
|
'UpdateFilter',
|
||||||
'InvertedFilter',
|
'InvertedFilter',
|
||||||
'MergedFilter',
|
'MergedFilter',
|
||||||
|
'XORFilter',
|
||||||
]
|
]
|
||||||
|
|
||||||
from telegram.utils.deprecate import TelegramDeprecationWarning
|
from telegram.utils.deprecate import TelegramDeprecationWarning
|
||||||
|
@ -54,6 +67,10 @@ class BaseFilter(ABC):
|
||||||
|
|
||||||
>>> (Filters.audio | Filters.video)
|
>>> (Filters.audio | Filters.video)
|
||||||
|
|
||||||
|
Exclusive Or:
|
||||||
|
|
||||||
|
>>> (Filters.regex('To Be') ^ Filters.regex('Not To Be'))
|
||||||
|
|
||||||
Not:
|
Not:
|
||||||
|
|
||||||
>>> ~ Filters.command
|
>>> ~ Filters.command
|
||||||
|
@ -93,7 +110,7 @@ class BaseFilter(ABC):
|
||||||
(depends on the handler).
|
(depends on the handler).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = None
|
_name = None
|
||||||
data_filter = False
|
data_filter = False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -106,9 +123,20 @@ class BaseFilter(ABC):
|
||||||
def __or__(self, other: 'BaseFilter') -> 'BaseFilter':
|
def __or__(self, other: 'BaseFilter') -> 'BaseFilter':
|
||||||
return MergedFilter(self, or_filter=other)
|
return MergedFilter(self, or_filter=other)
|
||||||
|
|
||||||
|
def __xor__(self, other: 'BaseFilter') -> 'BaseFilter':
|
||||||
|
return XORFilter(self, other)
|
||||||
|
|
||||||
def __invert__(self) -> 'BaseFilter':
|
def __invert__(self) -> 'BaseFilter':
|
||||||
return InvertedFilter(self)
|
return InvertedFilter(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> Optional[str]:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@name.setter
|
||||||
|
def name(self, name: Optional[str]) -> None:
|
||||||
|
self._name = name
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
# We do this here instead of in a __init__ so filter don't have to call __init__ or super()
|
# We do this here instead of in a __init__ so filter don't have to call __init__ or super()
|
||||||
if self.name is None:
|
if self.name is None:
|
||||||
|
@ -193,9 +221,14 @@ class InvertedFilter(UpdateFilter):
|
||||||
def filter(self, update: Update) -> bool:
|
def filter(self, update: Update) -> bool:
|
||||||
return not bool(self.f(update))
|
return not bool(self.f(update))
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
return "<inverted {}>".format(self.f)
|
return "<inverted {}>".format(self.f)
|
||||||
|
|
||||||
|
@name.setter
|
||||||
|
def name(self, name: str) -> NoReturn:
|
||||||
|
raise RuntimeError('Cannot set name for InvertedFilter')
|
||||||
|
|
||||||
|
|
||||||
class MergedFilter(UpdateFilter):
|
class MergedFilter(UpdateFilter):
|
||||||
"""Represents a filter consisting of two other filters.
|
"""Represents a filter consisting of two other filters.
|
||||||
|
@ -269,11 +302,43 @@ class MergedFilter(UpdateFilter):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
return "<{} {} {}>".format(
|
return "<{} {} {}>".format(
|
||||||
self.base_filter, "and" if self.and_filter else "or", self.and_filter or self.or_filter
|
self.base_filter, "and" if self.and_filter else "or", self.and_filter or self.or_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@name.setter
|
||||||
|
def name(self, name: str) -> NoReturn:
|
||||||
|
raise RuntimeError('Cannot set name for MergedFilter')
|
||||||
|
|
||||||
|
|
||||||
|
class XORFilter(UpdateFilter):
|
||||||
|
"""Convenience filter acting as wrapper for :class:`MergedFilter` representing the an XOR gate
|
||||||
|
for two filters
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_filter: Filter 1 of the merged filter.
|
||||||
|
xor_filter: Filter 2 of the merged filter.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_filter: BaseFilter, xor_filter: BaseFilter):
|
||||||
|
self.base_filter = base_filter
|
||||||
|
self.xor_filter = xor_filter
|
||||||
|
self.merged_filter = (base_filter & ~xor_filter) | (~base_filter & xor_filter)
|
||||||
|
|
||||||
|
def filter(self, update: Update) -> Optional[Union[bool, Dict]]:
|
||||||
|
return self.merged_filter(update)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return f'<{self.base_filter} xor {self.xor_filter}>'
|
||||||
|
|
||||||
|
@name.setter
|
||||||
|
def name(self, name: str) -> NoReturn:
|
||||||
|
raise RuntimeError('Cannot set name for XORFilter')
|
||||||
|
|
||||||
|
|
||||||
class _DiceEmoji(MessageFilter):
|
class _DiceEmoji(MessageFilter):
|
||||||
def __init__(self, emoji: str = None, name: str = None):
|
def __init__(self, emoji: str = None, name: str = None):
|
||||||
|
@ -1355,6 +1420,14 @@ officedocument.wordprocessingml.document")``-
|
||||||
return self.allow_empty
|
return self.allow_empty
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return f'Filters.user({", ".join(str(s) for s in (self.usernames or self.user_ids))})'
|
||||||
|
|
||||||
|
@name.setter
|
||||||
|
def name(self, name: str) -> NoReturn:
|
||||||
|
raise RuntimeError('Cannot set name for Filters.user')
|
||||||
|
|
||||||
class via_bot(MessageFilter):
|
class via_bot(MessageFilter):
|
||||||
"""Filters messages to allow only those which are from specified via_bot ID(s) or
|
"""Filters messages to allow only those which are from specified via_bot ID(s) or
|
||||||
username(s).
|
username(s).
|
||||||
|
@ -1537,6 +1610,15 @@ officedocument.wordprocessingml.document")``-
|
||||||
return self.allow_empty
|
return self.allow_empty
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
entries = [str(s) for s in (self.usernames or self.bot_ids)]
|
||||||
|
return f'Filters.via_bot({", ".join(entries)})'
|
||||||
|
|
||||||
|
@name.setter
|
||||||
|
def name(self, name: str) -> NoReturn:
|
||||||
|
raise RuntimeError('Cannot set name for Filters.via_bot')
|
||||||
|
|
||||||
class chat(MessageFilter):
|
class chat(MessageFilter):
|
||||||
"""Filters messages to allow only those which are from a specified chat ID or username.
|
"""Filters messages to allow only those which are from a specified chat ID or username.
|
||||||
|
|
||||||
|
@ -1717,6 +1799,14 @@ officedocument.wordprocessingml.document")``-
|
||||||
return self.allow_empty
|
return self.allow_empty
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return f'Filters.chat({", ".join(str(s) for s in (self.usernames or self.chat_ids))})'
|
||||||
|
|
||||||
|
@name.setter
|
||||||
|
def name(self, name: str) -> NoReturn:
|
||||||
|
raise RuntimeError('Cannot set name for Filters.chat')
|
||||||
|
|
||||||
class _Invoice(MessageFilter):
|
class _Invoice(MessageFilter):
|
||||||
name = 'Filters.invoice'
|
name = 'Filters.invoice'
|
||||||
|
|
||||||
|
|
|
@ -1049,6 +1049,22 @@ class TestFilters:
|
||||||
update.message.from_user.username = user
|
update.message.from_user.username = user
|
||||||
assert not f(update)
|
assert not f(update)
|
||||||
|
|
||||||
|
def test_filters_user_repr(self):
|
||||||
|
f = Filters.user([1, 2])
|
||||||
|
assert str(f) == 'Filters.user(1, 2)'
|
||||||
|
f.remove_user_ids(1)
|
||||||
|
f.remove_user_ids(2)
|
||||||
|
assert str(f) == 'Filters.user()'
|
||||||
|
f.add_usernames('@foobar')
|
||||||
|
assert str(f) == 'Filters.user(foobar)'
|
||||||
|
f.add_usernames('@barfoo')
|
||||||
|
assert str(f).startswith('Filters.user(')
|
||||||
|
# we don't know th exact order
|
||||||
|
assert 'barfoo' in str(f) and 'foobar' in str(f)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match='Cannot set name'):
|
||||||
|
f.name = 'foo'
|
||||||
|
|
||||||
def test_filters_chat_init(self):
|
def test_filters_chat_init(self):
|
||||||
with pytest.raises(RuntimeError, match='in conjunction with'):
|
with pytest.raises(RuntimeError, match='in conjunction with'):
|
||||||
Filters.chat(chat_id=1, username='chat')
|
Filters.chat(chat_id=1, username='chat')
|
||||||
|
@ -1174,6 +1190,22 @@ class TestFilters:
|
||||||
update.message.chat.username = chat
|
update.message.chat.username = chat
|
||||||
assert not f(update)
|
assert not f(update)
|
||||||
|
|
||||||
|
def test_filters_chat_repr(self):
|
||||||
|
f = Filters.chat([1, 2])
|
||||||
|
assert str(f) == 'Filters.chat(1, 2)'
|
||||||
|
f.remove_chat_ids(1)
|
||||||
|
f.remove_chat_ids(2)
|
||||||
|
assert str(f) == 'Filters.chat()'
|
||||||
|
f.add_usernames('@foobar')
|
||||||
|
assert str(f) == 'Filters.chat(foobar)'
|
||||||
|
f.add_usernames('@barfoo')
|
||||||
|
assert str(f).startswith('Filters.chat(')
|
||||||
|
# we don't know th exact order
|
||||||
|
assert 'barfoo' in str(f) and 'foobar' in str(f)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match='Cannot set name'):
|
||||||
|
f.name = 'foo'
|
||||||
|
|
||||||
def test_filters_invoice(self, update):
|
def test_filters_invoice(self, update):
|
||||||
assert not Filters.invoice(update)
|
assert not Filters.invoice(update)
|
||||||
update.message.invoice = 'test'
|
update.message.invoice = 'test'
|
||||||
|
@ -1294,6 +1326,63 @@ class TestFilters:
|
||||||
'Filters.entity(mention)>>'
|
'Filters.entity(mention)>>'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_xor_filters(self, update):
|
||||||
|
update.message.text = 'test'
|
||||||
|
update.effective_user.id = 123
|
||||||
|
assert not (Filters.text ^ Filters.user(123))(update)
|
||||||
|
update.message.text = None
|
||||||
|
update.effective_user.id = 1234
|
||||||
|
assert not (Filters.text ^ Filters.user(123))(update)
|
||||||
|
update.message.text = 'test'
|
||||||
|
assert (Filters.text ^ Filters.user(123))(update)
|
||||||
|
update.message.text = None
|
||||||
|
update.effective_user.id = 123
|
||||||
|
assert (Filters.text ^ Filters.user(123))(update)
|
||||||
|
|
||||||
|
def test_xor_filters_repr(self, update):
|
||||||
|
assert str(Filters.text ^ Filters.user(123)) == '<Filters.text xor Filters.user(123)>'
|
||||||
|
with pytest.raises(RuntimeError, match='Cannot set name'):
|
||||||
|
(Filters.text ^ Filters.user(123)).name = 'foo'
|
||||||
|
|
||||||
|
def test_and_xor_filters(self, update):
|
||||||
|
update.message.text = 'test'
|
||||||
|
update.message.forward_date = datetime.datetime.utcnow()
|
||||||
|
assert (Filters.forwarded & (Filters.text ^ Filters.user(123)))(update)
|
||||||
|
update.message.text = None
|
||||||
|
update.effective_user.id = 123
|
||||||
|
assert (Filters.forwarded & (Filters.text ^ Filters.user(123)))(update)
|
||||||
|
update.message.text = 'test'
|
||||||
|
assert not (Filters.forwarded & (Filters.text ^ Filters.user(123)))(update)
|
||||||
|
update.message.forward_date = None
|
||||||
|
update.message.text = None
|
||||||
|
update.effective_user.id = 123
|
||||||
|
assert not (Filters.forwarded & (Filters.text ^ Filters.user(123)))(update)
|
||||||
|
update.message.text = 'test'
|
||||||
|
update.effective_user.id = 456
|
||||||
|
assert not (Filters.forwarded & (Filters.text ^ Filters.user(123)))(update)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
str(Filters.forwarded & (Filters.text ^ Filters.user(123)))
|
||||||
|
== '<Filters.forwarded and <Filters.text xor '
|
||||||
|
'Filters.user(123)>>'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_xor_regex_filters(self, update):
|
||||||
|
SRE_TYPE = type(re.match("", ""))
|
||||||
|
update.message.text = 'test'
|
||||||
|
update.message.forward_date = datetime.datetime.utcnow()
|
||||||
|
assert not (Filters.forwarded ^ Filters.regex('^test$'))(update)
|
||||||
|
update.message.forward_date = None
|
||||||
|
result = (Filters.forwarded ^ Filters.regex('^test$'))(update)
|
||||||
|
assert result
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
matches = result['matches']
|
||||||
|
assert isinstance(matches, list)
|
||||||
|
assert type(matches[0]) is SRE_TYPE
|
||||||
|
update.message.forward_date = datetime.datetime.utcnow()
|
||||||
|
update.message.text = None
|
||||||
|
assert (Filters.forwarded ^ Filters.regex('^test$'))(update) is True
|
||||||
|
|
||||||
def test_inverted_filters(self, update):
|
def test_inverted_filters(self, update):
|
||||||
update.message.text = '/test'
|
update.message.text = '/test'
|
||||||
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
|
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
|
||||||
|
@ -1304,6 +1393,11 @@ class TestFilters:
|
||||||
assert not Filters.command(update)
|
assert not Filters.command(update)
|
||||||
assert (~Filters.command)(update)
|
assert (~Filters.command)(update)
|
||||||
|
|
||||||
|
def test_inverted_filters_repr(self, update):
|
||||||
|
assert str(~Filters.text) == '<inverted Filters.text>'
|
||||||
|
with pytest.raises(RuntimeError, match='Cannot set name'):
|
||||||
|
(~Filters.text).name = 'foo'
|
||||||
|
|
||||||
def test_inverted_and_filters(self, update):
|
def test_inverted_and_filters(self, update):
|
||||||
update.message.text = '/test'
|
update.message.text = '/test'
|
||||||
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
|
update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)]
|
||||||
|
@ -1398,6 +1492,10 @@ class TestFilters:
|
||||||
update.message.entities = []
|
update.message.entities = []
|
||||||
(Filters.command & raising_filter)(update)
|
(Filters.command & raising_filter)(update)
|
||||||
|
|
||||||
|
def test_merged_filters_repr(self, update):
|
||||||
|
with pytest.raises(RuntimeError, match='Cannot set name'):
|
||||||
|
(Filters.text & Filters.photo).name = 'foo'
|
||||||
|
|
||||||
def test_merged_short_circuit_or(self, update, base_class):
|
def test_merged_short_circuit_or(self, update, base_class):
|
||||||
update.message.text = 'test'
|
update.message.text = 'test'
|
||||||
|
|
||||||
|
@ -1587,3 +1685,19 @@ class TestFilters:
|
||||||
for user in users:
|
for user in users:
|
||||||
update.message.via_bot.username = user
|
update.message.via_bot.username = user
|
||||||
assert not f(update)
|
assert not f(update)
|
||||||
|
|
||||||
|
def test_filters_via_bot_repr(self):
|
||||||
|
f = Filters.via_bot([1, 2])
|
||||||
|
assert str(f) == 'Filters.via_bot(1, 2)'
|
||||||
|
f.remove_bot_ids(1)
|
||||||
|
f.remove_bot_ids(2)
|
||||||
|
assert str(f) == 'Filters.via_bot()'
|
||||||
|
f.add_usernames('@foobar')
|
||||||
|
assert str(f) == 'Filters.via_bot(foobar)'
|
||||||
|
f.add_usernames('@barfoo')
|
||||||
|
assert str(f).startswith('Filters.via_bot(')
|
||||||
|
# we don't know th exact order
|
||||||
|
assert 'barfoo' in str(f) and 'foobar' in str(f)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match='Cannot set name'):
|
||||||
|
f.name = 'foo'
|
||||||
|
|
Loading…
Add table
Reference in a new issue