Allow Input of Type Sticker for Several Methods (#4616)

This commit is contained in:
Bibo-Joshi 2024-12-29 20:16:46 +01:00 committed by GitHub
parent df20e49db1
commit a6cd9c5292
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 190 additions and 55 deletions

View file

@ -181,8 +181,8 @@ markers = [
"req",
]
asyncio_mode = "auto"
log_format = "%(funcName)s - Line %(lineno)d - %(message)s"
# log_level = "DEBUG" # uncomment to see DEBUG logs
log_cli_format = "%(funcName)s - Line %(lineno)d - %(message)s"
# log_cli_level = "DEBUG" # uncomment to see DEBUG logs
# MYPY:
[tool.mypy]

View file

@ -6622,7 +6622,7 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
async def set_sticker_position_in_set(
self,
sticker: str,
sticker: Union[str, "Sticker"],
position: int,
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
@ -6634,7 +6634,11 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
"""Use this method to move a sticker in a set created by the bot to a specific position.
Args:
sticker (:obj:`str`): File identifier of the sticker.
sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the sticker or
the sticker object.
.. versionchanged:: NEXT.VERSION
Accepts also :class:`telegram.Sticker` instances.
position (:obj:`int`): New sticker position in the set, zero-based.
Returns:
@ -6644,7 +6648,10 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
:class:`telegram.error.TelegramError`
"""
data: JSONDict = {"sticker": sticker, "position": position}
data: JSONDict = {
"sticker": sticker if isinstance(sticker, str) else sticker.file_id,
"position": position,
}
return await self._post(
"setStickerPositionInSet",
data,
@ -6749,7 +6756,7 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
async def delete_sticker_from_set(
self,
sticker: str,
sticker: Union[str, "Sticker"],
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
write_timeout: ODVInput[float] = DEFAULT_NONE,
@ -6760,7 +6767,11 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
"""Use this method to delete a sticker from a set created by the bot.
Args:
sticker (:obj:`str`): File identifier of the sticker.
sticker (:obj:`str` | :class:`telegram.Sticker`): File identifier of the sticker or
the sticker object.
.. versionchanged:: NEXT.VERSION
Accepts also :class:`telegram.Sticker` instances.
Returns:
:obj:`bool`: On success, :obj:`True` is returned.
@ -6769,7 +6780,7 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
:class:`telegram.error.TelegramError`
"""
data: JSONDict = {"sticker": sticker}
data: JSONDict = {"sticker": sticker if isinstance(sticker, str) else sticker.file_id}
return await self._post(
"deleteStickerFromSet",
data,
@ -6937,7 +6948,7 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
async def set_sticker_emoji_list(
self,
sticker: str,
sticker: Union[str, "Sticker"],
emoji_list: Sequence[str],
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
@ -6953,7 +6964,11 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
.. versionadded:: 20.2
Args:
sticker (:obj:`str`): File identifier of the sticker.
sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the sticker or
the sticker object.
.. versionchanged:: NEXT.VERSION
Accepts also :class:`telegram.Sticker` instances.
emoji_list (Sequence[:obj:`str`]): A sequence of
:tg-const:`telegram.constants.StickerLimit.MIN_STICKER_EMOJI`-
:tg-const:`telegram.constants.StickerLimit.MAX_STICKER_EMOJI` emoji associated with
@ -6965,7 +6980,10 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
Raises:
:class:`telegram.error.TelegramError`
"""
data: JSONDict = {"sticker": sticker, "emoji_list": emoji_list}
data: JSONDict = {
"sticker": sticker if isinstance(sticker, str) else sticker.file_id,
"emoji_list": emoji_list,
}
return await self._post(
"setStickerEmojiList",
data,
@ -6978,7 +6996,7 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
async def set_sticker_keywords(
self,
sticker: str,
sticker: Union[str, "Sticker"],
keywords: Optional[Sequence[str]] = None,
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
@ -6994,7 +7012,11 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
.. versionadded:: 20.2
Args:
sticker (:obj:`str`): File identifier of the sticker.
sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the sticker or
the sticker object.
.. versionchanged:: NEXT.VERSION
Accepts also :class:`telegram.Sticker` instances.
keywords (Sequence[:obj:`str`]): A sequence of
0-:tg-const:`telegram.constants.StickerLimit.MAX_SEARCH_KEYWORDS` search keywords
for the sticker with total length up to
@ -7006,7 +7028,10 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
Raises:
:class:`telegram.error.TelegramError`
"""
data: JSONDict = {"sticker": sticker, "keywords": keywords}
data: JSONDict = {
"sticker": sticker if isinstance(sticker, str) else sticker.file_id,
"keywords": keywords,
}
return await self._post(
"setStickerKeywords",
data,
@ -7019,7 +7044,7 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
async def set_sticker_mask_position(
self,
sticker: str,
sticker: Union[str, "Sticker"],
mask_position: Optional[MaskPosition] = None,
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
@ -7035,7 +7060,11 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
.. versionadded:: 20.2
Args:
sticker (:obj:`str`): File identifier of the sticker.
sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the sticker or
the sticker object.
.. versionchanged:: NEXT.VERSION
Accepts also :class:`telegram.Sticker` instances.
mask_position (:class:`telegram.MaskPosition`, optional): A object with the position
where the mask should be placed on faces. Omit the parameter to remove the mask
position.
@ -7046,7 +7075,10 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
Raises:
:class:`telegram.error.TelegramError`
"""
data: JSONDict = {"sticker": sticker, "mask_position": mask_position}
data: JSONDict = {
"sticker": sticker if isinstance(sticker, str) else sticker.file_id,
"mask_position": mask_position,
}
return await self._post(
"setStickerMaskPosition",
data,
@ -9248,7 +9280,7 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
self,
user_id: int,
name: str,
old_sticker: str,
old_sticker: Union[str, "Sticker"],
sticker: "InputSticker",
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
@ -9266,7 +9298,11 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
Args:
user_id (:obj:`int`): User identifier of the sticker set owner.
name (:obj:`str`): Sticker set name.
old_sticker (:obj:`str`): File identifier of the replaced sticker.
old_sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the replaced
sticker or the sticker object itself.
.. versionchanged:: NEXT.VERSION
Accepts also :class:`telegram.Sticker` instances.
sticker (:class:`telegram.InputSticker`): An object with information about the added
sticker. If exactly the same sticker had already been added to the set, then the
set remains unchanged.
@ -9280,7 +9316,7 @@ CUSTOM_EMOJI_IDENTIFIER_LIMIT` custom emoji identifiers can be specified.
data: JSONDict = {
"user_id": user_id,
"name": name,
"old_sticker": old_sticker,
"old_sticker": old_sticker if isinstance(old_sticker, str) else old_sticker.file_id,
"sticker": sticker,
}

View file

@ -1426,7 +1426,7 @@ class ExtBot(Bot, Generic[RLARGS]):
async def delete_sticker_from_set(
self,
sticker: str,
sticker: Union[str, "Sticker"],
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
write_timeout: ODVInput[float] = DEFAULT_NONE,
@ -3660,7 +3660,7 @@ class ExtBot(Bot, Generic[RLARGS]):
async def set_sticker_position_in_set(
self,
sticker: str,
sticker: Union[str, "Sticker"],
position: int,
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
@ -4114,7 +4114,7 @@ class ExtBot(Bot, Generic[RLARGS]):
async def set_sticker_emoji_list(
self,
sticker: str,
sticker: Union[str, "Sticker"],
emoji_list: Sequence[str],
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
@ -4136,7 +4136,7 @@ class ExtBot(Bot, Generic[RLARGS]):
async def set_sticker_keywords(
self,
sticker: str,
sticker: Union[str, "Sticker"],
keywords: Optional[Sequence[str]] = None,
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
@ -4158,7 +4158,7 @@ class ExtBot(Bot, Generic[RLARGS]):
async def set_sticker_mask_position(
self,
sticker: str,
sticker: Union[str, "Sticker"],
mask_position: Optional[MaskPosition] = None,
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
@ -4250,7 +4250,7 @@ class ExtBot(Bot, Generic[RLARGS]):
self,
user_id: int,
name: str,
old_sticker: str,
old_sticker: Union[str, "Sticker"],
sticker: "InputSticker",
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,

View file

@ -699,6 +699,54 @@ class TestStickerSetWithoutRequest(StickerSetTestBase):
monkeypatch.setattr(sticker.get_bot(), "get_file", make_assertion)
assert await sticker.get_file()
async def test_delete_sticker_from_set_sticker_input(self, offline_bot, sticker, monkeypatch):
async def make_assertion(url, request_data: RequestData, *args, **kwargs):
return request_data.json_parameters["sticker"] == sticker.file_id
monkeypatch.setattr(offline_bot.request, "post", make_assertion)
assert await offline_bot.delete_sticker_from_set(sticker)
async def test_replace_sticker_in_set_sticker_input(self, offline_bot, sticker, monkeypatch):
async def make_assertion(url, request_data: RequestData, *args, **kwargs):
return request_data.json_parameters["old_sticker"] == sticker.file_id
monkeypatch.setattr(offline_bot.request, "post", make_assertion)
assert await offline_bot.replace_sticker_in_set(
user_id=1, name="name", sticker="sticker", old_sticker=sticker
)
async def test_set_sticker_emoji_list_sticker_input(self, offline_bot, sticker, monkeypatch):
async def make_assertion(url, request_data: RequestData, *args, **kwargs):
return request_data.json_parameters["sticker"] == sticker.file_id
monkeypatch.setattr(offline_bot.request, "post", make_assertion)
assert await offline_bot.set_sticker_emoji_list(sticker, ["emoji"])
async def test_set_sticker_mask_position_sticker_input(
self, offline_bot, sticker, monkeypatch
):
async def make_assertion(url, request_data: RequestData, *args, **kwargs):
return request_data.json_parameters["sticker"] == sticker.file_id
monkeypatch.setattr(offline_bot.request, "post", make_assertion)
assert await offline_bot.set_sticker_mask_position(sticker, MaskPosition("eyes", 1, 2, 3))
async def test_set_sticker_position_in_set_sticker_input(
self, offline_bot, sticker, monkeypatch
):
async def make_assertion(url, request_data: RequestData, *args, **kwargs):
return request_data.json_parameters["sticker"] == sticker.file_id
monkeypatch.setattr(offline_bot.request, "post", make_assertion)
assert await offline_bot.set_sticker_position_in_set(sticker, 1)
async def test_set_sticker_keywords_sticker_input(self, offline_bot, sticker, monkeypatch):
async def make_assertion(url, request_data: RequestData, *args, **kwargs):
return request_data.json_parameters["sticker"] == sticker.file_id
monkeypatch.setattr(offline_bot.request, "post", make_assertion)
assert await offline_bot.set_sticker_keywords(sticker, ["keyword"])
@pytest.mark.xdist_group("stickerset")
class TestStickerSetWithRequest:

View file

@ -37,6 +37,7 @@ from telegram import (
InputTextMessageContent,
LinkPreviewOptions,
ReplyParameters,
Sticker,
TelegramObject,
)
from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue
@ -317,6 +318,16 @@ def build_kwargs(
kws["error_message"] = "error"
elif name == "options":
kws[name] = ["option1", "option2"]
elif name in ("sticker", "old_sticker"):
kws[name] = Sticker(
file_id="file_id",
file_unique_id="file_unique_id",
width=1,
height=1,
is_animated=False,
is_video=False,
type="regular",
)
else:
kws[name] = True

View file

@ -38,6 +38,7 @@ from tests.test_official.helpers import (
_get_params_base,
_unionizer,
cached_type_hints,
extract_mappings,
resolve_forward_refs_in_type,
wrap_with_none,
)
@ -144,7 +145,7 @@ def check_param_type(
)
# CHECKING:
# Each branch manipulates the `mapped_type` (except for 4) ) to match the `ptb_annotation`.
# Each branch manipulates the `mapped_type` (except for 5) ) to match the `ptb_annotation`.
# 1) HANDLING ARRAY TYPES:
# Now let's do the checking, starting with "Array of ..." types.
@ -174,9 +175,11 @@ def check_param_type(
# 2) HANDLING OTHER TYPES:
# Special case for send_* methods where we accept more types than the official API:
elif ptb_param.name in PTCE.ADDITIONAL_TYPES and obj.__name__.startswith("send"):
log("Checking that `%s` has an additional argument!\n", ptb_param.name)
mapped_type = mapped_type | PTCE.ADDITIONAL_TYPES[ptb_param.name]
elif additional_types := extract_mappings(PTCE.ADDITIONAL_TYPES, obj, ptb_param.name):
log("Checking that `%s` accepts additional types for some parameters!\n", obj.__name__)
for at in additional_types:
log("Checking that `%s` is an additional type for `%s`!\n", at, ptb_param.name)
mapped_type = mapped_type | at
# 3) HANDLING DATETIMES:
elif (
@ -205,10 +208,9 @@ def check_param_type(
# 5) COMPLEX TYPES:
# Some types are too complicated, so we replace our annotation with a simpler type:
elif any(ptb_param.name in key for key in PTCE.COMPLEX_TYPES):
elif overrides := extract_mappings(PTCE.COMPLEX_TYPES, obj, ptb_param.name):
exception_type = overrides[0]
log("Converting `%s` to a simpler type!\n", ptb_param.name)
for (param_name, is_expected_class), exception_type in PTCE.COMPLEX_TYPES.items():
if ptb_param.name == param_name and is_class is is_expected_class:
ptb_annotation = wrap_with_none(tg_parameter, exception_type, obj)
# 6) HANDLING DEFAULTS PARAMETERS:

View file

@ -35,9 +35,11 @@ GLOBALLY_IGNORED_PARAMETERS = {
class ParamTypeCheckingExceptions:
# Types for certain parameters accepted by PTB but not in the official API
# structure: method/class_name/regex: {param_name/regex: type}
ADDITIONAL_TYPES = {
"photo": PhotoSize,
"video": Video,
r"send_\w*": {
"photo$": PhotoSize,
"video$": Video,
"video_note": VideoNote,
"audio": Audio,
"document": Document,
@ -45,6 +47,13 @@ class ParamTypeCheckingExceptions:
"voice": Voice,
"sticker": Sticker,
"gift_id": Gift,
},
"(delete|set)_sticker.*": {
"sticker$": Sticker,
},
"replace_sticker_in_set": {
"old_sticker$": Sticker,
},
}
# TODO: Look into merging this with COMPLEX_TYPES
@ -61,19 +70,29 @@ class ParamTypeCheckingExceptions:
}
# Special cases for other parameters that accept more types than the official API, and are
# too complex to compare/predict with official API:
# too complex to compare/predict with official API
# structure: class/method_name: {param_name: reduced form of annotation}
COMPLEX_TYPES = (
{ # (param_name, is_class (i.e appears in a class?)): reduced form of annotation
("correct_option_id", False): int, # actual: Literal
("file_id", False): str, # actual: Union[str, objs_with_file_id_attr]
("invite_link", False): str, # actual: Union[str, ChatInviteLink]
("provider_data", False): str, # actual: Union[str, obj]
("callback_data", True): str, # actual: Union[str, obj]
("media", True): str, # actual: Union[str, InputMedia*, FileInput]
(
"data",
True,
): str, # actual: Union[IdDocumentData, PersonalDetails, ResidentialAddress]
"send_poll": {"correct_option_id": int}, # actual: Literal
"get_file": {
"file_id": str, # actual: Union[str, objs_with_file_id_attr]
},
r"\w+invite_link": {
"invite_link": str, # actual: Union[str, ChatInviteLink]
},
"send_invoice|create_invoice_link": {
"provider_data": str, # actual: Union[str, obj]
},
"InlineKeyboardButton": {
"callback_data": str, # actual: Union[str, obj]
},
"Input(Paid)?Media.*": {
"media": str, # actual: Union[str, InputMedia*, FileInput]
},
"EncryptedPassportElement": {
"data": str, # actual: Union[IdDocumentData, PersonalDetails, ResidentialAddress]
},
}
)

View file

@ -21,7 +21,7 @@
import functools
import re
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, _eval_type, get_type_hints
from typing import TYPE_CHECKING, Any, Optional, TypeVar, _eval_type, get_type_hints
from bs4 import PageElement, Tag
@ -110,3 +110,22 @@ def cached_type_hints(obj: Any, is_class: bool) -> dict[str, Any]:
def resolve_forward_refs_in_type(obj: type) -> type:
"""Resolves forward references in a type hint."""
return _eval_type(obj, localns=tg_objects, globalns=None)
T = TypeVar("T")
def extract_mappings(
exceptions: dict[str, dict[str, T]], obj: object, param_name: str
) -> Optional[list[T]]:
mappings = (
mapping for pattern, mapping in exceptions.items() if (re.match(pattern, obj.__name__))
)
out = [
value
for mapping in mappings
for key, value in mapping.items()
if re.match(key, param_name)
]
return None or out