python-telegram-bot/tests/auxil/bot_method_checks.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

664 lines
28 KiB
Python
Raw Normal View History

#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2024
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""Provides functions to test both methods."""
import datetime as dtm
import functools
import inspect
import re
from collections.abc import Collection, Iterable
from typing import Any, Callable, Optional
import pytest
import telegram # for ForwardRef resolution
from telegram import (
Bot,
ChatPermissions,
File,
InlineQueryResultArticle,
InlineQueryResultCachedPhoto,
InputMediaPhoto,
InputTextMessageContent,
LinkPreviewOptions,
ReplyParameters,
TelegramObject,
)
from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue
from telegram.constants import InputMediaType
from telegram.ext import Defaults, ExtBot
from telegram.request import RequestData
2023-03-25 19:18:04 +01:00
from tests.auxil.envvars import TEST_WITH_OPT_DEPS
if TEST_WITH_OPT_DEPS:
import pytz
# Applied by check_shortcut_signature to the string form of a parameter (and of the individual
# arguments of its annotation) to detect forward references and extract the referenced class
# name, so that the class can be looked up and compared.
FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('(?P<class_name>\w+)'\)")
""" A pattern to find a class name in a ForwardRef typing annotation.
Class name (in a named group) is surrounded by parentheses and single quotes.
"""
def check_shortcut_signature(
    shortcut: Callable,
    bot_method: Callable,
    shortcut_kwargs: list[str],
    additional_kwargs: list[str],
    annotation_overrides: Optional[dict[str, tuple[Any, Any]]] = None,
) -> bool:
    """
    Checks that the signature of a shortcut matches the signature of the underlying bot method.

    The following properties are compared for every argument that the shortcut is expected to
    share with the bot method: presence, parameter kind, annotation and default value.
    Arguments listed in ``additional_kwargs`` must be keyword-only.

    Args:
        shortcut: The shortcut, e.g. :meth:`telegram.Message.reply_text`
        bot_method: The bot method, e.g. :meth:`telegram.Bot.send_message`
        shortcut_kwargs: The kwargs passed by the shortcut directly, e.g. ``chat_id``
        additional_kwargs: Additional kwargs of the shortcut that the bot method doesn't have, e.g.
            ``quote``.
        annotation_overrides: A dictionary of exceptions for the annotation comparison. The key is
            the name of the argument, the value is a tuple of the expected annotation and
            the default value. E.g. ``{'parse_mode': (str, 'None')}``.

    Returns:
        :obj:`bool`: :obj:`True` if the signatures match. Mismatches are never reported through
        the return value.

    Raises:
        Exception: If any of the checks described above fails. The message names the offending
            argument(s).
    """
    annotation_overrides = annotation_overrides or {}

    def resolve_class(class_name: str) -> Optional[type]:
        """Attempts to resolve a PTB class (telegram module only) from a ForwardRef.

        E.g. resolves <class 'telegram._files.sticker.StickerSet'> from "StickerSet".
        Returns a class on success, :obj:`None` if nothing could be resolved.
        """
        for module in telegram, telegram.request:
            cls = getattr(module, class_name, None)
            if cls is not None:
                return cls
        return None  # for ruff

    shortcut_sig = inspect.signature(shortcut)
    effective_shortcut_args = set(shortcut_sig.parameters.keys()).difference(additional_kwargs)
    effective_shortcut_args.discard("self")

    bot_sig = inspect.signature(bot_method)
    expected_args = set(bot_sig.parameters.keys()).difference(shortcut_kwargs)
    expected_args.discard("self")

    # Compare the argument *names* rather than just how many there are: two sets of equal size
    # can still differ, which would otherwise only surface as a bare KeyError further down.
    missing_args = expected_args - effective_shortcut_args
    if missing_args:
        raise Exception(
            f"Shortcut signature is missing {len(missing_args)} arguments "
            f"of the underlying Bot method: {missing_args}"
        )
    surplus_args = effective_shortcut_args - expected_args
    if surplus_args:
        raise Exception(
            f"Shortcut signature has {len(surplus_args)} additional arguments "
            f"that the Bot method doesn't have: {surplus_args}"
        )

    # TODO: Also check annotation of return type. Would currently be a hassle b/c typing doesn't
    # resolve `ForwardRef('Type')` to `Type`. For now we rely on MyPy, which probably allows the
    # shortcuts to return more specific types than the bot method, but it's only annotations after
    # all
    for kwarg in effective_shortcut_args:
        bot_param = bot_sig.parameters[kwarg]
        shortcut_param = shortcut_sig.parameters[kwarg]

        if shortcut_param.kind != bot_param.kind:
            raise Exception(f"Argument {kwarg} must be of kind {bot_param.kind}.")

        if kwarg in annotation_overrides:
            expected_annotation = annotation_overrides[kwarg][0]
            if shortcut_param.annotation != expected_annotation:
                raise Exception(
                    f"For argument {kwarg} I expected {expected_annotation}, "
                    f"but got {shortcut_param.annotation}"
                )
            continue

        if bot_param.annotation != shortcut_param.annotation:
            if FORWARD_REF_PATTERN.search(str(shortcut_param)):
                # If a shortcut signature contains a ForwardRef, the simple comparison of
                # annotations can fail. Try and resolve the .__args__, then compare them.
                for shortcut_arg, bot_arg in zip(
                    shortcut_param.annotation.__args__,
                    bot_param.annotation.__args__,
                ):
                    shortcut_arg_to_check = shortcut_arg  # for ruff
                    match = FORWARD_REF_PATTERN.search(str(shortcut_arg))
                    if match:
                        shortcut_arg_to_check = resolve_class(match.group("class_name"))
                    if shortcut_arg_to_check != bot_arg:
                        raise Exception(
                            f"For argument {kwarg} I expected "
                            f"{bot_param.annotation}, but "
                            f"got {shortcut_param.annotation}. "
                            f"Comparison of {shortcut_arg} and {bot_arg} failed."
                        )
            elif isinstance(bot_param.annotation, type):
                # The bot method is annotated with a plain class; accept the shortcut's
                # annotation if its string form equals that class' name.
                if bot_param.annotation.__name__ != str(shortcut_param.annotation):
                    raise Exception(
                        f"For argument {kwarg} I expected {bot_param.annotation}, "
                        f"but got {shortcut_param.annotation}"
                    )
            else:
                raise Exception(
                    f"For argument {kwarg} I expected {bot_param.annotation}, "
                    f"but got {shortcut_param.annotation}"
                )

    for arg in expected_args:
        shortcut_default = shortcut_sig.parameters[arg].default
        if arg in annotation_overrides:
            if shortcut_default == annotation_overrides[arg][1]:
                continue
            raise Exception(
                f"For argument {arg} I expected default {annotation_overrides[arg][1]}, "
                f"but got {shortcut_default}"
            )
        if not shortcut_default == bot_sig.parameters[arg].default:
            raise Exception(
                f"Default for argument {arg} does not match the default of the Bot method."
            )

    for kwarg in additional_kwargs:
        if kwarg == "reply_to_message_id":
            # special case for deprecated argument of Message.reply_*
            continue
        if not shortcut_sig.parameters[kwarg].kind == inspect.Parameter.KEYWORD_ONLY:
            raise Exception(f"Argument {kwarg} must be a keyword-only argument!")

    return True
async def check_shortcut_call(
    shortcut_method: Callable,
    bot: ExtBot,
    bot_method_name: str,
    skip_params: Optional[Iterable[str]] = None,
    shortcut_kwargs: Optional[Iterable[str]] = None,
) -> bool:
    """
    Checks that a shortcut passes all the existing arguments to the underlying bot method. Use as::

        assert await check_shortcut_call(message.reply_text, message.bot, 'send_message')

    The bot method is temporarily replaced on ``bot`` by a stand-in that inspects the keyword
    arguments it receives; the original method is put back afterwards, also if the check fails.

    Args:
        shortcut_method: The shortcut method, e.g. `message.reply_text`
        bot: The bot
        bot_method_name: The bot methods name, e.g. `'send_message'`
        skip_params: Parameters that are allowed to be missing, e.g. `['inline_message_id']`
            `rate_limit_args` will be skipped by default
        shortcut_kwargs: The kwargs passed by the shortcut directly, e.g. ``chat_id``

    Returns:
        :obj:`bool`: :obj:`True` if the shortcut call went through without a failed check.

    Raises:
        Exception: If the set of correctly received arguments differs from the expected one.
    """
    skip_params = set(skip_params) if skip_params else set()
    skip_params.add("rate_limit_args")
    shortcut_kwargs = set(shortcut_kwargs) if shortcut_kwargs else set()

    orig_bot_method = getattr(bot, bot_method_name)
    bot_signature = inspect.signature(orig_bot_method)
    expected_args = set(bot_signature.parameters.keys()) - {"self"} - skip_params
    # Parameters without a default; `empty` is a sentinel, so compare by identity
    positional_args = {
        name for name, param in bot_signature.parameters.items() if param.default is param.empty
    }
    # For these we only require presence, not a specific value
    ignored_args = positional_args | shortcut_kwargs

    shortcut_signature = inspect.signature(shortcut_method)
    # Every argument is passed its own name as value, so that the stand-in can recognize it.
    # auto_pagination: Special casing for InlineQuery.answer
    # quote: Don't test deprecated "quote" parameter of Message.reply_*
    kwargs = {
        name: name
        for name in shortcut_signature.parameters
        if name not in ["auto_pagination", "quote"]
    }
    if "reply_parameters" in kwargs:
        kwargs["reply_parameters"] = ReplyParameters(message_id=1)

    # We tested this for a long time, but Bot API 7.0 deprecated it in favor of
    # reply_parameters. In the transition phase, both exist in a mutually exclusive
    # way. Testing both cases would require a lot of additional code, so we just
    # ignore this parameter here until it is removed.
    kwargs.pop("reply_to_message_id", None)
    expected_args.discard("reply_to_message_id")

    async def make_assertion(**kw):
        # name == value makes sure that
        # a) we receive non-None input for all parameters
        # b) we receive the correct input for each kwarg
        received_kwargs = {
            name
            for name, value in kw.items()
            if name in ignored_args
            or (value == name or (name == "reply_parameters" and value.message_id == 1))
        }
        if not received_kwargs == expected_args:
            raise Exception(
                f"{orig_bot_method.__name__} did not receive correct value for the parameters "
                f"{expected_args - received_kwargs}"
            )

        if bot_method_name == "get_file":
            # This is here mainly for PassportFile.get_file, which calls .set_credentials on the
            # return value
            return File(file_id="result", file_unique_id="result")
        return True

    setattr(bot, bot_method_name, make_assertion)
    try:
        await shortcut_method(**kwargs)
    finally:
        # Always undo the monkeypatch; any exception simply propagates to the caller
        setattr(bot, bot_method_name, orig_bot_method)
    return True
def build_kwargs(
    signature: inspect.Signature, default_kwargs, manually_passed_value: Any = DEFAULT_NONE
):
    """Assembles the keyword arguments for calling a method with the given signature.

    Args:
        signature: The signature of the method that is going to be called.
        default_kwargs: Names of the arguments for which ``manually_passed_value`` is passed,
            unless that value is ``DEFAULT_NONE``.
        manually_passed_value: The value to pass for the arguments in ``default_kwargs``. It is
            also inserted into the nested objects that are built for some arguments.

    Returns:
        A dictionary mapping argument names to the values to pass.
    """

    def link_preview_options():
        return LinkPreviewOptions(is_disabled=manually_passed_value, url=manually_passed_value)

    def required_entries(name, param):
        # Entries to add for a parameter without default. Some special casing is needed.
        if name == "permissions":
            return {name: ChatPermissions()}
        if name in ("prices", "commands", "errors"):
            return {name: []}
        if name == "media":
            photo = InputMediaPhoto("media", parse_mode=manually_passed_value)
            wants_list = "list" in str(param.annotation).lower()
            return {name: [photo] if wants_list else photo}
        if name == "results":
            content = InputTextMessageContent(
                "text",
                parse_mode=manually_passed_value,
                link_preview_options=link_preview_options(),
            )
            article = InlineQueryResultArticle("id", "title", input_message_content=content)
            cached_photo = InlineQueryResultCachedPhoto(
                "id",
                "photo_file_id",
                parse_mode=manually_passed_value,
                input_message_content=content,
            )
            return {name: [article, cached_photo]}
        if name == "ok":
            # "ok" additionally brings an error message along
            return {"ok": False, "error_message": "error"}
        if name == "options":
            return {name: ["option1", "option2"]}
        return {name: True}

    built = {}
    for name, param in signature.parameters.items():
        # For required params we need to pass something
        if param.default is inspect.Parameter.empty:
            built.update(required_entries(name, param))
            continue

        # pass values for params that can have defaults only if we don't want to use the
        # standard default
        if name in default_kwargs:
            if manually_passed_value != DEFAULT_NONE:
                if name == "link_preview_options":
                    built[name] = link_preview_options()
                else:
                    built[name] = manually_passed_value
            continue

        # Some special casing for methods that have "exactly one of the optionals" type args
        if name in ("location", "contact", "venue", "inline_message_id"):
            built[name] = True
        elif name == "until_date":
            moment = dtm.datetime(2000, 1, 1, 0)
            if manually_passed_value not in [None, DEFAULT_NONE]:
                # Europe/Berlin
                built[name] = pytz.timezone("Europe/Berlin").localize(moment)
            else:
                # naive UTC
                built[name] = moment
        elif name == "reply_parameters":
            built[name] = telegram.ReplyParameters(
                message_id=1,
                allow_sending_without_reply=manually_passed_value,
                quote_parse_mode=manually_passed_value,
            )

    return built
def make_assertion_for_link_preview_options(
    expected_defaults_value, lpo, manual_value_expected, manually_passed_value
):
    """Fails the running test if the entries of ``lpo`` don't have the expected values.

    Args:
        expected_defaults_value: The value expected for entries that are filled from the
            defaults. Nothing is checked against it if it is falsy.
        lpo: The link preview options as mapping. Nothing is checked if it is falsy.
        manual_value_expected: Whether ``is_disabled`` and ``url`` are expected to equal
            ``manually_passed_value``.
        manually_passed_value: The value that was passed manually, or ``DEFAULT_NONE``.
    """
    # We always expect a value for link_preview_options, because we don't test the
    # case send_message(…, link_preview_options=None). Instead we focus on the more
    # complicated case of send_message(…, link_preview_options=LinkPreviewOptions(arg=None)).
    # Hence there is no dedicated check for the case that no value is expected.
    if not lpo:
        return

    def require(field, wanted):
        actual = lpo.get(field)
        if actual != wanted:
            pytest.fail(
                f"Got value {actual} for link_preview_options.{field}, "
                f"expected it to be {wanted}"
            )

    if manual_value_expected:
        require("is_disabled", manually_passed_value)
        require("url", manually_passed_value)

    if expected_defaults_value:
        require("show_above_text", expected_defaults_value)
        # The url can only come from the defaults if it wasn't passed manually
        if manually_passed_value is DEFAULT_NONE:
            require("url", expected_defaults_value)
async def make_assertion(
    url,
    request_data: RequestData,
    method_name: str,
    kwargs_need_default: list[str],
    return_value,
    manually_passed_value: Any = DEFAULT_NONE,
    expected_defaults_value: Any = DEFAULT_NONE,
    *args,
    **kwargs,
):
    """Inspects the parameters of a request and fails the running test on unexpected values.

    ``check_defaults_handling`` installs a partial of this function as ``request.post``, so
    that no request is actually made.

    Args:
        url: Accepted for compatibility with the replaced ``post``; not used.
        request_data: The request data. Its ``parameters`` are inspected.
        method_name: Name of the method under test.
        kwargs_need_default: Names of the arguments whose value is checked.
        return_value: What to return. A ``TelegramObject`` is returned as its ``to_dict``
            representation.
        manually_passed_value: The value that was passed manually; ``DEFAULT_NONE`` if nothing
            was passed.
        expected_defaults_value: The value expected where nothing was passed manually.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        For ``get_file``, ``get_small_file`` and ``get_big_file`` the dict representation of a
        ``File``, otherwise ``return_value`` (converted as described above).
    """
    data = request_data.parameters

    # Exactly one of the following three flags is True
    no_value_expected = (manually_passed_value is None) or (
        manually_passed_value is DEFAULT_NONE and expected_defaults_value is None
    )
    manual_value_expected = (manually_passed_value is not DEFAULT_NONE) and not no_value_expected
    default_value_expected = (not manual_value_expected) and (not no_value_expected)

    # Check reply_parameters - needs special handling b/c we merge this with the default
    # value for `allow_sending_without_reply`
    reply_parameters = data.pop("reply_parameters", None)
    if reply_parameters:
        for param in ["allow_sending_without_reply", "quote_parse_mode"]:
            if no_value_expected and param in reply_parameters:
                pytest.fail(f"Got value for reply_parameters.{param}, expected it to be absent")
            param_value = reply_parameters.get(param)
            if manual_value_expected and param_value != manually_passed_value:
                pytest.fail(
                    f"Got value {param_value} for reply_parameters.{param} "
                    f"instead of {manually_passed_value}"
                )
            elif default_value_expected and param_value != expected_defaults_value:
                pytest.fail(
                    f"Got value {param_value} for reply_parameters.{param} "
                    f"instead of {expected_defaults_value}"
                )

    # Check link_preview_options - needs special handling b/c we merge this with the default
    # values specified in `Defaults.link_preview_options`
    make_assertion_for_link_preview_options(
        expected_defaults_value,
        data.get("link_preview_options", None),
        manual_value_expected,
        manually_passed_value,
    )

    # Check regular arguments that need defaults
    for arg in kwargs_need_default:
        if arg == "link_preview_options":
            # already handled above
            continue
        # 'None' should not be passed along to Telegram
        if no_value_expected and arg in data:
            pytest.fail(f"Got value {data[arg]} for argument {arg}, expected it to be absent")
        value = data.get(arg, "`not passed at all`")
        if manual_value_expected and value != manually_passed_value:
            pytest.fail(f"Got value {value} for argument {arg} instead of {manually_passed_value}")
        elif default_value_expected and value != expected_defaults_value:
            pytest.fail(
                f"Got value {value} for argument {arg} instead of {expected_defaults_value}"
            )

    # Check InputMedia (parse_mode can have a default)
    def check_input_media(m: dict):
        parse_mode = m.get("parse_mode")
        if no_value_expected and parse_mode is not None:
            pytest.fail("InputMedia has non-None parse_mode, expected it to be absent")
        elif default_value_expected and parse_mode != expected_defaults_value:
            pytest.fail(
                f"Got value {parse_mode} for InputMedia.parse_mode instead "
                f"of {expected_defaults_value}"
            )
        elif manual_value_expected and parse_mode != manually_passed_value:
            pytest.fail(
                f"Got value {parse_mode} for InputMedia.parse_mode instead "
                f"of {manually_passed_value}"
            )

    media = data.pop("media", None)
    if media:
        # Either a single media item or a collection of them
        if isinstance(media, dict) and isinstance(media.get("type", None), InputMediaType):
            check_input_media(media)
        else:
            for m in media:
                check_input_media(m)

    # Check InlineQueryResults
    results = data.pop("results", [])
    for result in results:
        if no_value_expected and "parse_mode" in result:
            pytest.fail("ILQR has a parse mode, expected it to be absent")
        # Here we explicitly use that we only pass ILQRPhoto and ILQRArticle for testing
        # NOTE(review): `result` is accessed like a dict below, in which case this tests for a
        # key named "photo" - confirm that this is intended rather than a check of the
        # result's type.
        elif "photo" in result:
            parse_mode = result.get("parse_mode")
            if manual_value_expected and parse_mode != manually_passed_value:
                pytest.fail(
                    f"Got value {parse_mode} for ILQR.parse_mode instead of "
                    f"{manually_passed_value}"
                )
            elif default_value_expected and parse_mode != expected_defaults_value:
                pytest.fail(
                    f"Got value {parse_mode} for ILQR.parse_mode instead of "
                    f"{expected_defaults_value}"
                )

        # Here we explicitly use that we only pass InputTextMessageContent for testing
        # which has both parse_mode and link_preview_options
        imc = result.get("input_message_content")
        if not imc:
            continue
        if no_value_expected and "parse_mode" in imc:
            pytest.fail("ILQR.i_m_c has a parse_mode, expected it to be absent")
        # `imc` is read via `.get`, so the value must be looked up by key - attribute access
        # would raise instead of producing the failure message
        parse_mode = imc.get("parse_mode")
        if manual_value_expected and parse_mode != manually_passed_value:
            pytest.fail(
                f"Got value {parse_mode} for ILQR.i_m_c.parse_mode "
                f"instead of {manually_passed_value}"
            )
        elif default_value_expected and parse_mode != expected_defaults_value:
            pytest.fail(
                f"Got value {parse_mode} for ILQR.i_m_c.parse_mode "
                f"instead of {expected_defaults_value}"
            )

        make_assertion_for_link_preview_options(
            expected_defaults_value,
            imc.get("link_preview_options", None),
            manual_value_expected,
            manually_passed_value,
        )

    # Check datetime conversion
    until_date = data.pop("until_date", None)
    if until_date:
        if manual_value_expected and until_date != 946681200:
            pytest.fail("Non-naive until_date should have been interpreted as Europe/Berlin.")
        if not any((manually_passed_value, expected_defaults_value)) and until_date != 946684800:
            pytest.fail("Naive until_date should have been interpreted as UTC")
        if default_value_expected and until_date != 946702800:
            pytest.fail("Naive until_date should have been interpreted as America/New_York")

    if method_name in ["get_file", "get_small_file", "get_big_file"]:
        # This is here mainly for PassportFile.get_file, which calls .set_credentials on the
        # return value
        out = File(file_id="result", file_unique_id="result")
        return out.to_dict()
    # Otherwise return None by default, as TGObject.de_json/list(None) in [None, []]
    # That way we can check what gets passed to Request.post without having to actually
    # make a request
    # Some methods expect specific output, so we allow to customize that
    if isinstance(return_value, TelegramObject):
        return return_value.to_dict()
    return return_value
async def check_defaults_handling(
    method: Callable,
    bot: Bot,
    return_value=None,
    no_default_kwargs: Collection[str] = frozenset(),
) -> bool:
    """
    Checks that tg.ext.Defaults are handled correctly.

    While the check runs, ``post`` of the bot's request object and - for an ``ExtBot`` - the
    bot's defaults are replaced. Both are put back afterwards, also if the check fails.

    Args:
        method: The shortcut/bot_method
        bot: The bot. May be a telegram.Bot or a telegram.ext.ExtBot. In the former case, all
            default values will be converted to None.
        return_value: Optional. The return value of Bot._post that the method expects. Defaults to
            None. get_file is automatically handled. If this is a `TelegramObject`, Bot._post will
            return the `to_dict` representation of it.
        no_default_kwargs: Optional. A collection of keyword arguments that should not have default
            values. Defaults to an empty frozenset.

    Returns:
        :obj:`bool`: :obj:`True` if all checks passed.
    """
    raw_bot = not isinstance(bot, ExtBot)
    get_updates = method.__name__.lower().replace("_", "") == "getupdates"

    shortcut_signature = inspect.signature(method)
    kwargs_need_default = {
        kwarg
        for kwarg, value in shortcut_signature.parameters.items()
        if isinstance(value.default, DefaultValue)
        and not kwarg.endswith("_timeout")
        and kwarg not in no_default_kwargs
    }
    # We tested this for a long time, but Bot API 7.0 deprecated it in favor of
    # reply_parameters. In the transition phase, both exist in a mutually exclusive
    # way. Testing both cases would require a lot of additional code, so we for now are content
    # with the explicit tests that we have inplace for allow_sending_without_reply
    kwargs_need_default.discard("allow_sending_without_reply")

    if method.__name__.endswith("_media_group"):
        # the parse_mode is applied to the first media item, and we test this elsewhere
        kwargs_need_default.remove("parse_mode")

    defaults_no_custom_defaults = Defaults()
    defaults_kwargs = {
        kwarg: "custom_default" for kwarg in inspect.signature(Defaults).parameters
    }
    defaults_kwargs["tzinfo"] = pytz.timezone("America/New_York")
    defaults_kwargs.pop("disable_web_page_preview")  # mutually exclusive with link_preview_options
    defaults_kwargs.pop("quote")  # mutually exclusive with do_quote
    defaults_kwargs["link_preview_options"] = LinkPreviewOptions(
        url="custom_default", show_above_text="custom_default"
    )
    defaults_custom_defaults = Defaults(**defaults_kwargs)

    expected_return_values = [None, ()] if return_value is None else [return_value]
    if method.__name__ in ["get_file", "get_small_file", "get_big_file"]:
        expected_return_values = [File(file_id="result", file_unique_id="result")]

    request = bot._request[0] if get_updates else bot.request
    orig_post = request.post
    # Remember what the bot was configured with, so that it can be put back when we are done
    orig_defaults = None if raw_bot else bot._defaults
    try:
        if raw_bot:
            combinations = [(None, None)]
        else:
            combinations = [
                (None, defaults_no_custom_defaults),
                ("custom_default", defaults_custom_defaults),
            ]

        for expected_defaults_value, defaults in combinations:
            if not raw_bot:
                bot._defaults = defaults

            # 1: test that we get the correct default value, if we don't specify anything
            # 2: test that we get the manually passed non-None value
            # 3: test that we get the manually passed None value
            for manually_passed_value in (DEFAULT_NONE, "non-None-value", None):
                call_kwargs = build_kwargs(
                    shortcut_signature,
                    kwargs_need_default,
                    manually_passed_value=manually_passed_value,
                )
                request.post = functools.partial(
                    make_assertion,
                    manually_passed_value=manually_passed_value,
                    kwargs_need_default=kwargs_need_default,
                    method_name=method.__name__,
                    return_value=return_value,
                    expected_defaults_value=expected_defaults_value,
                )
                assert await method(**call_kwargs) in expected_return_values
    finally:
        # Any exception simply propagates; we only have to undo our modifications
        request.post = orig_post
        if not raw_bot:
            bot._defaults = orig_defaults

    return True