Add Parameter pattern to PreCheckoutQueryHandler and filters.SuccessfulPayment (#4005)

This commit is contained in:
aelkheir 2024-01-02 20:35:38 +03:00 committed by GitHub
parent 7fcfad41a5
commit f3479cd170
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 112 additions and 8 deletions

View file

@ -19,9 +19,16 @@
"""This module contains the PreCheckoutQueryHandler class.""" """This module contains the PreCheckoutQueryHandler class."""
import re
from typing import Optional, Pattern, TypeVar, Union
from telegram import Update from telegram import Update
from telegram._utils.defaultvalue import DEFAULT_TRUE
from telegram._utils.types import DVType
from telegram.ext._basehandler import BaseHandler from telegram.ext._basehandler import BaseHandler
from telegram.ext._utils.types import CCT from telegram.ext._utils.types import CCT, HandlerCallback
RT = TypeVar("RT")
class PreCheckoutQueryHandler(BaseHandler[Update, CCT]): class PreCheckoutQueryHandler(BaseHandler[Update, CCT]):
@ -48,14 +55,32 @@ class PreCheckoutQueryHandler(BaseHandler[Update, CCT]):
:meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`.
.. seealso:: :wiki:`Concurrency` .. seealso:: :wiki:`Concurrency`
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.
.. versionadded:: NEXT.VERSION
Attributes: Attributes:
callback (:term:`coroutine function`): The callback function for this handler. callback (:term:`coroutine function`): The callback function for this handler.
block (:obj:`bool`): Determines whether the callback will run in a blocking way.. block (:obj:`bool`): Determines whether the callback will run in a blocking way..
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.
.. versionadded:: NEXT.VERSION
""" """
__slots__ = () __slots__ = ("pattern",)
def __init__(
self,
callback: HandlerCallback[Update, CCT, RT],
block: DVType[bool] = DEFAULT_TRUE,
pattern: Optional[Union[str, Pattern[str]]] = None,
):
super().__init__(callback, block=block)
self.pattern: Optional[Pattern[str]] = re.compile(pattern) if pattern is not None else None
def check_update(self, update: object) -> bool: def check_update(self, update: object) -> bool:
"""Determines whether an update should be passed to this handler's :attr:`callback`. """Determines whether an update should be passed to this handler's :attr:`callback`.
@ -67,4 +92,11 @@ class PreCheckoutQueryHandler(BaseHandler[Update, CCT]):
:obj:`bool` :obj:`bool`
""" """
return isinstance(update, Update) and bool(update.pre_checkout_query) if isinstance(update, Update) and update.pre_checkout_query:
invoice_payload = update.pre_checkout_query.invoice_payload
if self.pattern:
if self.pattern.match(invoice_payload):
return True
else:
return True
return False

View file

@ -75,6 +75,7 @@ __all__ = (
"Sticker", "Sticker",
"STORY", "STORY",
"SUCCESSFUL_PAYMENT", "SUCCESSFUL_PAYMENT",
"SuccessfulPayment",
"SenderChat", "SenderChat",
"StatusUpdate", "StatusUpdate",
"TEXT", "TEXT",
@ -2265,14 +2266,45 @@ STORY = _Story(name="filters.STORY")
""" """
class _SuccessfulPayment(MessageFilter): class SuccessfulPayment(MessageFilter):
__slots__ = () """Successful Payment Messages. If a list of invoice payloads is passed, it filters
messages to only allow those whose `invoice_payload` is appearing in the given list.
Examples:
`MessageHandler(filters.SuccessfulPayment(['Custom-Payload']), callback_method)`
.. seealso::
:attr:`telegram.ext.filters.SUCCESSFUL_PAYMENT`
Args:
invoice_payloads (List[:obj:`str`] | Tuple[:obj:`str`], optional): Which
invoice payloads to allow. Only exact matches are allowed. If not
specified, will allow any invoice payload.
.. versionadded:: NEXT.VERSION
"""
__slots__ = ("invoice_payloads",)
def __init__(self, invoice_payloads: Optional[Union[List[str], Tuple[str, ...]]] = None):
self.invoice_payloads: Optional[Sequence[str]] = invoice_payloads
super().__init__(
name=f"filters.SuccessfulPayment({invoice_payloads})"
if invoice_payloads
else "filters.SUCCESSFUL_PAYMENT"
)
def filter(self, message: Message) -> bool: def filter(self, message: Message) -> bool:
if self.invoice_payloads is None:
return bool(message.successful_payment) return bool(message.successful_payment)
return (
payment.invoice_payload in self.invoice_payloads
if (payment := message.successful_payment)
else False
)
SUCCESSFUL_PAYMENT = _SuccessfulPayment(name="filters.SUCCESSFUL_PAYMENT") SUCCESSFUL_PAYMENT = SuccessfulPayment()
"""Messages that contain :attr:`telegram.Message.successful_payment`.""" """Messages that contain :attr:`telegram.Message.successful_payment`."""

View file

@ -31,6 +31,7 @@ from telegram import (
Message, Message,
MessageEntity, MessageEntity,
Sticker, Sticker,
SuccessfulPayment,
Update, Update,
User, User,
) )
@ -1877,6 +1878,24 @@ class TestFilters:
update.message.successful_payment = "test" update.message.successful_payment = "test"
assert filters.SUCCESSFUL_PAYMENT.check_update(update) assert filters.SUCCESSFUL_PAYMENT.check_update(update)
def test_filters_successful_payment_payloads(self, update):
assert not filters.SuccessfulPayment(("custom-payload",)).check_update(update)
assert not filters.SuccessfulPayment().check_update(update)
update.message.successful_payment = SuccessfulPayment(
"USD", 100, "custom-payload", "123", "123"
)
assert filters.SuccessfulPayment(("custom-payload",)).check_update(update)
assert filters.SuccessfulPayment().check_update(update)
assert not filters.SuccessfulPayment(["test1"]).check_update(update)
def test_filters_successful_payment_repr(self):
f = filters.SuccessfulPayment()
assert str(f) == "filters.SUCCESSFUL_PAYMENT"
f = filters.SuccessfulPayment(["payload1", "payload2"])
assert str(f) == "filters.SuccessfulPayment(['payload1', 'payload2'])"
def test_filters_passport_data(self, update): def test_filters_passport_data(self, update):
assert not filters.PASSPORT_DATA.check_update(update) assert not filters.PASSPORT_DATA.check_update(update)
update.message.passport_data = "test" update.message.passport_data = "test"

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser Public License # You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
import asyncio import asyncio
import re
import pytest import pytest
@ -69,12 +70,15 @@ def false_update(request):
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
def pre_checkout_query(): def pre_checkout_query():
return Update( update = Update(
1, 1,
pre_checkout_query=PreCheckoutQuery( pre_checkout_query=PreCheckoutQuery(
"id", User(1, "test user", False), "EUR", 223, "invoice_payload" "id", User(1, "test user", False), "EUR", 223, "invoice_payload"
), ),
) )
update._unfreeze()
update.pre_checkout_query._unfreeze()
return update
class TestPreCheckoutQueryHandler: class TestPreCheckoutQueryHandler:
@ -103,6 +107,23 @@ class TestPreCheckoutQueryHandler:
and isinstance(update.pre_checkout_query, PreCheckoutQuery) and isinstance(update.pre_checkout_query, PreCheckoutQuery)
) )
def test_with_pattern(self, pre_checkout_query):
handler = PreCheckoutQueryHandler(self.callback, pattern=".*voice.*")
assert handler.check_update(pre_checkout_query)
pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here"
assert not handler.check_update(pre_checkout_query)
def test_with_compiled_pattern(self, pre_checkout_query):
handler = PreCheckoutQueryHandler(self.callback, pattern=re.compile(r".*payload"))
pre_checkout_query.pre_checkout_query.invoice_payload = "invoice_payload"
assert handler.check_update(pre_checkout_query)
pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here"
assert not handler.check_update(pre_checkout_query)
def test_other_update_types(self, false_update): def test_other_update_types(self, false_update):
handler = PreCheckoutQueryHandler(self.callback) handler = PreCheckoutQueryHandler(self.callback)
assert not handler.check_update(false_update) assert not handler.check_update(false_update)