Add Parameter pattern to PreCheckoutQueryHandler and filters.SuccessfulPayment (#4005)

This commit is contained in:
aelkheir 2024-01-02 20:35:38 +03:00 committed by GitHub
parent 7fcfad41a5
commit f3479cd170
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 112 additions and 8 deletions

View file

@ -19,9 +19,16 @@
"""This module contains the PreCheckoutQueryHandler class."""
import re
from typing import Optional, Pattern, TypeVar, Union
from telegram import Update
from telegram._utils.defaultvalue import DEFAULT_TRUE
from telegram._utils.types import DVType
from telegram.ext._basehandler import BaseHandler
from telegram.ext._utils.types import CCT
from telegram.ext._utils.types import CCT, HandlerCallback
RT = TypeVar("RT")
class PreCheckoutQueryHandler(BaseHandler[Update, CCT]):
@ -48,14 +55,32 @@ class PreCheckoutQueryHandler(BaseHandler[Update, CCT]):
:meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`.
.. seealso:: :wiki:`Concurrency`
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.
.. versionadded:: NEXT.VERSION
Attributes:
callback (:term:`coroutine function`): The callback function for this handler.
block (:obj:`bool`): Determines whether the callback will run in a blocking way..
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.
.. versionadded:: NEXT.VERSION
"""
__slots__ = ()
__slots__ = ("pattern",)
def __init__(
self,
callback: HandlerCallback[Update, CCT, RT],
block: DVType[bool] = DEFAULT_TRUE,
pattern: Optional[Union[str, Pattern[str]]] = None,
):
super().__init__(callback, block=block)
self.pattern: Optional[Pattern[str]] = re.compile(pattern) if pattern is not None else None
def check_update(self, update: object) -> bool:
"""Determines whether an update should be passed to this handler's :attr:`callback`.
@ -67,4 +92,11 @@ class PreCheckoutQueryHandler(BaseHandler[Update, CCT]):
:obj:`bool`
"""
return isinstance(update, Update) and bool(update.pre_checkout_query)
if isinstance(update, Update) and update.pre_checkout_query:
invoice_payload = update.pre_checkout_query.invoice_payload
if self.pattern:
if self.pattern.match(invoice_payload):
return True
else:
return True
return False

View file

@ -75,6 +75,7 @@ __all__ = (
"Sticker",
"STORY",
"SUCCESSFUL_PAYMENT",
"SuccessfulPayment",
"SenderChat",
"StatusUpdate",
"TEXT",
@ -2265,14 +2266,45 @@ STORY = _Story(name="filters.STORY")
"""
class _SuccessfulPayment(MessageFilter):
__slots__ = ()
class SuccessfulPayment(MessageFilter):
"""Successful Payment Messages. If a list of invoice payloads is passed, it filters
messages to only allow those whose `invoice_payload` is appearing in the given list.
Examples:
`MessageHandler(filters.SuccessfulPayment(['Custom-Payload']), callback_method)`
.. seealso::
:attr:`telegram.ext.filters.SUCCESSFUL_PAYMENT`
Args:
invoice_payloads (List[:obj:`str`] | Tuple[:obj:`str`], optional): Which
invoice payloads to allow. Only exact matches are allowed. If not
specified, will allow any invoice payload.
.. versionadded:: NEXT.VERSION
"""
__slots__ = ("invoice_payloads",)
def __init__(self, invoice_payloads: Optional[Union[List[str], Tuple[str, ...]]] = None):
self.invoice_payloads: Optional[Sequence[str]] = invoice_payloads
super().__init__(
name=f"filters.SuccessfulPayment({invoice_payloads})"
if invoice_payloads
else "filters.SUCCESSFUL_PAYMENT"
)
def filter(self, message: Message) -> bool:
if self.invoice_payloads is None:
return bool(message.successful_payment)
return (
payment.invoice_payload in self.invoice_payloads
if (payment := message.successful_payment)
else False
)
SUCCESSFUL_PAYMENT = _SuccessfulPayment(name="filters.SUCCESSFUL_PAYMENT")
SUCCESSFUL_PAYMENT = SuccessfulPayment()
"""Messages that contain :attr:`telegram.Message.successful_payment`."""

View file

@ -31,6 +31,7 @@ from telegram import (
Message,
MessageEntity,
Sticker,
SuccessfulPayment,
Update,
User,
)
@ -1877,6 +1878,24 @@ class TestFilters:
update.message.successful_payment = "test"
assert filters.SUCCESSFUL_PAYMENT.check_update(update)
def test_filters_successful_payment_payloads(self, update):
assert not filters.SuccessfulPayment(("custom-payload",)).check_update(update)
assert not filters.SuccessfulPayment().check_update(update)
update.message.successful_payment = SuccessfulPayment(
"USD", 100, "custom-payload", "123", "123"
)
assert filters.SuccessfulPayment(("custom-payload",)).check_update(update)
assert filters.SuccessfulPayment().check_update(update)
assert not filters.SuccessfulPayment(["test1"]).check_update(update)
def test_filters_successful_payment_repr(self):
f = filters.SuccessfulPayment()
assert str(f) == "filters.SUCCESSFUL_PAYMENT"
f = filters.SuccessfulPayment(["payload1", "payload2"])
assert str(f) == "filters.SuccessfulPayment(['payload1', 'payload2'])"
def test_filters_passport_data(self, update):
assert not filters.PASSPORT_DATA.check_update(update)
update.message.passport_data = "test"

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import asyncio
import re
import pytest
@ -69,12 +70,15 @@ def false_update(request):
@pytest.fixture(scope="class")
def pre_checkout_query():
return Update(
update = Update(
1,
pre_checkout_query=PreCheckoutQuery(
"id", User(1, "test user", False), "EUR", 223, "invoice_payload"
),
)
update._unfreeze()
update.pre_checkout_query._unfreeze()
return update
class TestPreCheckoutQueryHandler:
@ -103,6 +107,23 @@ class TestPreCheckoutQueryHandler:
and isinstance(update.pre_checkout_query, PreCheckoutQuery)
)
def test_with_pattern(self, pre_checkout_query):
handler = PreCheckoutQueryHandler(self.callback, pattern=".*voice.*")
assert handler.check_update(pre_checkout_query)
pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here"
assert not handler.check_update(pre_checkout_query)
def test_with_compiled_pattern(self, pre_checkout_query):
handler = PreCheckoutQueryHandler(self.callback, pattern=re.compile(r".*payload"))
pre_checkout_query.pre_checkout_query.invoice_payload = "invoice_payload"
assert handler.check_update(pre_checkout_query)
pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here"
assert not handler.check_update(pre_checkout_query)
def test_other_update_types(self, false_update):
handler = PreCheckoutQueryHandler(self.callback)
assert not handler.check_update(false_update)