diff --git a/telegram/ext/_precheckoutqueryhandler.py b/telegram/ext/_precheckoutqueryhandler.py index 25843aa94..3c193cb73 100644 --- a/telegram/ext/_precheckoutqueryhandler.py +++ b/telegram/ext/_precheckoutqueryhandler.py @@ -19,9 +19,16 @@ """This module contains the PreCheckoutQueryHandler class.""" +import re +from typing import Optional, Pattern, TypeVar, Union + from telegram import Update +from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram._utils.types import DVType from telegram.ext._basehandler import BaseHandler -from telegram.ext._utils.types import CCT +from telegram.ext._utils.types import CCT, HandlerCallback + +RT = TypeVar("RT") class PreCheckoutQueryHandler(BaseHandler[Update, CCT]): @@ -48,14 +55,32 @@ class PreCheckoutQueryHandler(BaseHandler[Update, CCT]): :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. .. seealso:: :wiki:`Concurrency` + pattern (:obj:`str` | :func:`re.Pattern `, optional): Optional. Regex pattern + to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against. + + .. versionadded:: NEXT.VERSION Attributes: callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the callback will run in a blocking way.. + pattern (:obj:`str` | :func:`re.Pattern `, optional): Optional. Regex pattern + to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against. + + .. versionadded:: NEXT.VERSION """ - __slots__ = () + __slots__ = ("pattern",) + + def __init__( + self, + callback: HandlerCallback[Update, CCT, RT], + block: DVType[bool] = DEFAULT_TRUE, + pattern: Optional[Union[str, Pattern[str]]] = None, + ): + super().__init__(callback, block=block) + + self.pattern: Optional[Pattern[str]] = re.compile(pattern) if pattern is not None else None def check_update(self, update: object) -> bool: """Determines whether an update should be passed to this handler's :attr:`callback`. @@ -67,4 +92,11 @@ class PreCheckoutQueryHandler(BaseHandler[Update, CCT]): :obj:`bool` """ - return isinstance(update, Update) and bool(update.pre_checkout_query) + if isinstance(update, Update) and update.pre_checkout_query: + invoice_payload = update.pre_checkout_query.invoice_payload + if self.pattern: + if self.pattern.match(invoice_payload): + return True + else: + return True + return False diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index c1f9aee94..ae34e6b8f 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -75,6 +75,7 @@ __all__ = ( "Sticker", "STORY", "SUCCESSFUL_PAYMENT", + "SuccessfulPayment", "SenderChat", "StatusUpdate", "TEXT", @@ -2265,14 +2266,45 @@ STORY = _Story(name="filters.STORY") """ -class _SuccessfulPayment(MessageFilter): - __slots__ = () +class SuccessfulPayment(MessageFilter): + """Successful Payment Messages. If a list of invoice payloads is passed, it filters + messages to only allow those whose `invoice_payload` is appearing in the given list. + + Examples: + `MessageHandler(filters.SuccessfulPayment(['Custom-Payload']), callback_method)` + + .. seealso:: + :attr:`telegram.ext.filters.SUCCESSFUL_PAYMENT` + + Args: + invoice_payloads (List[:obj:`str`] | Tuple[:obj:`str`], optional): Which + invoice payloads to allow. Only exact matches are allowed. If not + specified, will allow any invoice payload. + + .. versionadded:: NEXT.VERSION + """ + + __slots__ = ("invoice_payloads",) + + def __init__(self, invoice_payloads: Optional[Union[List[str], Tuple[str, ...]]] = None): + self.invoice_payloads: Optional[Sequence[str]] = invoice_payloads + super().__init__( + name=f"filters.SuccessfulPayment({invoice_payloads})" + if invoice_payloads + else "filters.SUCCESSFUL_PAYMENT" + ) def filter(self, message: Message) -> bool: - return bool(message.successful_payment) + if self.invoice_payloads is None: + return bool(message.successful_payment) + return ( + payment.invoice_payload in self.invoice_payloads + if (payment := message.successful_payment) + else False + ) -SUCCESSFUL_PAYMENT = _SuccessfulPayment(name="filters.SUCCESSFUL_PAYMENT") +SUCCESSFUL_PAYMENT = SuccessfulPayment() """Messages that contain :attr:`telegram.Message.successful_payment`.""" diff --git a/tests/ext/test_filters.py b/tests/ext/test_filters.py index f0812db66..0ac7023bb 100644 --- a/tests/ext/test_filters.py +++ b/tests/ext/test_filters.py @@ -31,6 +31,7 @@ from telegram import ( Message, MessageEntity, Sticker, + SuccessfulPayment, Update, User, ) @@ -1877,6 +1878,24 @@ class TestFilters: update.message.successful_payment = "test" assert filters.SUCCESSFUL_PAYMENT.check_update(update) + def test_filters_successful_payment_payloads(self, update): + assert not filters.SuccessfulPayment(("custom-payload",)).check_update(update) + assert not filters.SuccessfulPayment().check_update(update) + + update.message.successful_payment = SuccessfulPayment( + "USD", 100, "custom-payload", "123", "123" + ) + assert filters.SuccessfulPayment(("custom-payload",)).check_update(update) + assert filters.SuccessfulPayment().check_update(update) + assert not filters.SuccessfulPayment(["test1"]).check_update(update) + + def test_filters_successful_payment_repr(self): + f = filters.SuccessfulPayment() + assert str(f) == "filters.SUCCESSFUL_PAYMENT" + + f = filters.SuccessfulPayment(["payload1", "payload2"]) + assert str(f) == "filters.SuccessfulPayment(['payload1', 'payload2'])" + def test_filters_passport_data(self, update): assert not filters.PASSPORT_DATA.check_update(update) update.message.passport_data = "test" diff --git a/tests/ext/test_precheckoutqueryhandler.py b/tests/ext/test_precheckoutqueryhandler.py index c22e058ff..37ba8ba1d 100644 --- a/tests/ext/test_precheckoutqueryhandler.py +++ b/tests/ext/test_precheckoutqueryhandler.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio +import re import pytest @@ -69,12 +70,15 @@ def false_update(request): @pytest.fixture(scope="class") def pre_checkout_query(): - return Update( + update = Update( 1, pre_checkout_query=PreCheckoutQuery( "id", User(1, "test user", False), "EUR", 223, "invoice_payload" ), ) + update._unfreeze() + update.pre_checkout_query._unfreeze() + return update class TestPreCheckoutQueryHandler: @@ -103,6 +107,23 @@ class TestPreCheckoutQueryHandler: and isinstance(update.pre_checkout_query, PreCheckoutQuery) ) + def test_with_pattern(self, pre_checkout_query): + handler = PreCheckoutQueryHandler(self.callback, pattern=".*voice.*") + + assert handler.check_update(pre_checkout_query) + + pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here" + assert not handler.check_update(pre_checkout_query) + + def test_with_compiled_pattern(self, pre_checkout_query): + handler = PreCheckoutQueryHandler(self.callback, pattern=re.compile(r".*payload")) + + pre_checkout_query.pre_checkout_query.invoice_payload = "invoice_payload" + assert handler.check_update(pre_checkout_query) + + pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here" + assert not handler.check_update(pre_checkout_query) + def test_other_update_types(self, false_update): handler = PreCheckoutQueryHandler(self.callback) assert not handler.check_update(false_update)