Add Bot.do_api_request (#4084)

This commit is contained in:
Bibo-Joshi 2024-02-07 22:35:09 +01:00 committed by GitHub
parent 7e9537ece2
commit 29866e2139
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 349 additions and 17 deletions

View file

@ -41,7 +41,8 @@ keyword_args = [
), ),
( (
" api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments" " api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments"
" to be passed to the Telegram API." " to be passed to the Telegram API. See :meth:`~telegram.Bot.do_api_request` for"
" limitations."
), ),
"", "",
] ]

View file

@ -93,11 +93,12 @@ from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue
from telegram._utils.files import is_local_file, parse_file_input from telegram._utils.files import is_local_file, parse_file_input
from telegram._utils.logging import get_logger from telegram._utils.logging import get_logger
from telegram._utils.repr import build_repr_with_selected_attrs from telegram._utils.repr import build_repr_with_selected_attrs
from telegram._utils.strings import to_camel_case
from telegram._utils.types import CorrectOptionID, FileInput, JSONDict, ODVInput, ReplyMarkup from telegram._utils.types import CorrectOptionID, FileInput, JSONDict, ODVInput, ReplyMarkup
from telegram._utils.warnings import warn from telegram._utils.warnings import warn
from telegram._webhookinfo import WebhookInfo from telegram._webhookinfo import WebhookInfo
from telegram.constants import InlineQueryLimit from telegram.constants import InlineQueryLimit
from telegram.error import InvalidToken from telegram.error import EndPointNotFound, InvalidToken
from telegram.request import BaseRequest, RequestData from telegram.request import BaseRequest, RequestData
from telegram.request._httpxrequest import HTTPXRequest from telegram.request._httpxrequest import HTTPXRequest
from telegram.request._requestparameter import RequestParameter from telegram.request._requestparameter import RequestParameter
@ -147,8 +148,8 @@ class Bot(TelegramObject, AsyncContextManager["Bot"]):
Note: Note:
* Most bot methods have the argument ``api_kwargs`` which allows passing arbitrary keywords * Most bot methods have the argument ``api_kwargs`` which allows passing arbitrary keywords
to the Telegram API. This can be used to access new features of the API before they are to the Telegram API. This can be used to access new features of the API before they are
incorporated into PTB. However, this is not guaranteed to work, i.e. it will fail for incorporated into PTB. The limitations to this argument are the same as the ones
passing files. described in :meth:`do_api_request`.
* Bots should not be serialized since if you for e.g. change the bots token, then your * Bots should not be serialized since if you for e.g. change the bots token, then your
serialized instance will not reflect that change. Trying to pickle a bot instance will serialized instance will not reflect that change. Trying to pickle a bot instance will
raise :exc:`pickle.PicklingError`. Trying to deepcopy a bot instance will raise raise :exc:`pickle.PicklingError`. Trying to deepcopy a bot instance will raise
@ -762,6 +763,101 @@ class Bot(TelegramObject, AsyncContextManager["Bot"]):
await asyncio.gather(self._request[0].shutdown(), self._request[1].shutdown()) await asyncio.gather(self._request[0].shutdown(), self._request[1].shutdown())
self._initialized = False self._initialized = False
@_log
async def do_api_request(
self,
endpoint: str,
api_kwargs: Optional[JSONDict] = None,
return_type: Optional[Type[TelegramObject]] = None,
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
write_timeout: ODVInput[float] = DEFAULT_NONE,
connect_timeout: ODVInput[float] = DEFAULT_NONE,
pool_timeout: ODVInput[float] = DEFAULT_NONE,
) -> Any:
"""Do a request to the Telegram API.
This method is here to make it easier to use new API methods that are not yet supported
by this library.
Hint:
Since PTB does not know which arguments are passed to this method, some caution is
necessary in terms of PTBs utility functionalities. In particular
* passing objects of any class defined in the :mod:`telegram` module is supported
* when uploading files, a :class:`telegram.InputFile` must be passed as the value for
the corresponding argument. Passing a file path or file-like object will not work.
File paths will work only in combination with :paramref:`~Bot.local_mode`.
* when uploading files, PTB can still correctly determine that
a special write timeout value should be used instead of the default
:paramref:`telegram.request.HTTPXRequest.write_timeout`.
* insertion of default values specified via :class:`telegram.ext.Defaults` will not
work (only relevant for :class:`telegram.ext.ExtBot`).
* The only exception is :class:`telegram.ext.Defaults.tzinfo`, which will be correctly
applied to :class:`datetime.datetime` objects.
.. versionadded:: NEXT.VERSION
Args:
endpoint (:obj:`str`): The API endpoint to use, e.g. ``getMe`` or ``get_me``.
api_kwargs (:obj:`dict`, optional): The keyword arguments to pass to the API call.
If not specified, no arguments are passed.
return_type (:class:`telegram.TelegramObject`, optional): If specified, the result of
the API call will be deserialized into an instance of this class or tuple of
instances of this class. If not specified, the raw result of the API call will be
returned.
Returns:
The result of the API call. If :paramref:`return_type` is not specified, this is a
:obj:`dict` or :obj:`bool`, otherwise an instance of :paramref:`return_type` or a
tuple of :paramref:`return_type`.
Raises:
:class:`telegram.error.TelegramError`
"""
if hasattr(self, endpoint):
self._warn(
(
f"Please use 'Bot.{endpoint}' instead of "
f"'Bot.do_api_request(\"{endpoint}\", ...)'"
),
PTBDeprecationWarning,
stacklevel=3,
)
camel_case_endpoint = to_camel_case(endpoint)
try:
result = await self._post(
camel_case_endpoint,
api_kwargs=api_kwargs,
read_timeout=read_timeout,
write_timeout=write_timeout,
connect_timeout=connect_timeout,
pool_timeout=pool_timeout,
)
except InvalidToken as exc:
# TG returns 404 Not found for
# 1) malformed tokens
# 2) correct tokens but non-existing method, e.g. api.tg.org/botTOKEN/unkonwnMethod
# 2) is relevant only for Bot.do_api_request, that's why we have special handling for
# that here rather than in BaseRequest._request_wrapper
if self._initialized:
raise EndPointNotFound(
f"Endpoint '{camel_case_endpoint}' not found in Bot API"
) from exc
raise InvalidToken(
"Either the bot token was rejected by Telegram or the endpoint "
f"'{camel_case_endpoint}' does not exist."
) from exc
if return_type is None or isinstance(result, bool):
return result
if isinstance(result, list):
return return_type.de_list(result, self)
return return_type.de_json(result, self)
@_log @_log
async def get_me( async def get_me(
self, self,

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2023
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains a helper functions related to string manipulation.
Warning:
Contents of this module are intended to be used internally by the library and *not* by the
user. Changes to this module are not considered breaking changes and may not be documented in
the changelog.
"""
def to_camel_case(snake_str: str) -> str:
"""Converts a snake_case string to camelCase.
Args:
snake_str (:obj:`str`): The string to convert.
Returns:
:obj:`str`: The converted string.
"""
components = snake_str.split("_")
return components[0] + "".join(x.title() for x in components[1:])

View file

@ -26,6 +26,7 @@ __all__ = (
"BadRequest", "BadRequest",
"ChatMigrated", "ChatMigrated",
"Conflict", "Conflict",
"EndPointNotFound",
"Forbidden", "Forbidden",
"InvalidToken", "InvalidToken",
"NetworkError", "NetworkError",
@ -133,6 +134,16 @@ class InvalidToken(TelegramError):
super().__init__("Invalid token" if message is None else message) super().__init__("Invalid token" if message is None else message)
class EndPointNotFound(TelegramError):
"""Raised when the requested endpoint is not found. Only relevant for
:meth:`telegram.Bot.do_api_request`.
.. versionadded:: NEXT.VERSION
"""
__slots__ = ()
class NetworkError(TelegramError): class NetworkError(TelegramError):
"""Base class for exceptions due to networking errors. """Base class for exceptions due to networking errors.

View file

@ -73,6 +73,7 @@ from telegram import (
SentWebAppMessage, SentWebAppMessage,
Sticker, Sticker,
StickerSet, StickerSet,
TelegramObject,
Update, Update,
User, User,
UserProfilePhotos, UserProfilePhotos,
@ -644,6 +645,28 @@ class ExtBot(Bot, Generic[RLARGS]):
return res return res
async def do_api_request(
self,
endpoint: str,
api_kwargs: Optional[JSONDict] = None,
return_type: Optional[Type[TelegramObject]] = None,
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
write_timeout: ODVInput[float] = DEFAULT_NONE,
connect_timeout: ODVInput[float] = DEFAULT_NONE,
pool_timeout: ODVInput[float] = DEFAULT_NONE,
rate_limit_args: Optional[RLARGS] = None,
) -> Any:
return await super().do_api_request(
endpoint=endpoint,
api_kwargs=self._merge_api_rl_kwargs(api_kwargs, rate_limit_args),
return_type=return_type,
read_timeout=read_timeout,
write_timeout=write_timeout,
connect_timeout=connect_timeout,
pool_timeout=pool_timeout,
)
async def stop_poll( async def stop_poll(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],

View file

@ -372,7 +372,7 @@ class BaseRequest(
# TG returns 404 Not found for # TG returns 404 Not found for
# 1) malformed tokens # 1) malformed tokens
# 2) correct tokens but non-existing method, e.g. api.tg.org/botTOKEN/unkonwnMethod # 2) correct tokens but non-existing method, e.g. api.tg.org/botTOKEN/unkonwnMethod
# We can basically rule out 2) since we don't let users make requests manually # 2) is relevant only for Bot.do_api_request, where we have special handing for it.
# TG returns 401 Unauthorized for correctly formatted tokens that are not valid # TG returns 401 Unauthorized for correctly formatted tokens that are not valid
raise InvalidToken(message) raise InvalidToken(message)
if code == HTTPStatus.BAD_REQUEST: # 400 if code == HTTPStatus.BAD_REQUEST: # 400

View file

@ -27,6 +27,8 @@ import socket
import time import time
from collections import defaultdict from collections import defaultdict
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO
from typing import Tuple
import httpx import httpx
import pytest import pytest
@ -50,6 +52,7 @@ from telegram import (
InlineQueryResultsButton, InlineQueryResultsButton,
InlineQueryResultVoice, InlineQueryResultVoice,
InputFile, InputFile,
InputMediaDocument,
InputMessageContent, InputMessageContent,
InputTextMessageContent, InputTextMessageContent,
LabeledPrice, LabeledPrice,
@ -69,6 +72,7 @@ from telegram import (
) )
from telegram._utils.datetime import UTC, from_timestamp, to_timestamp from telegram._utils.datetime import UTC, from_timestamp, to_timestamp
from telegram._utils.defaultvalue import DEFAULT_NONE from telegram._utils.defaultvalue import DEFAULT_NONE
from telegram._utils.strings import to_camel_case
from telegram.constants import ( from telegram.constants import (
ChatAction, ChatAction,
InlineQueryLimit, InlineQueryLimit,
@ -76,7 +80,7 @@ from telegram.constants import (
MenuButtonType, MenuButtonType,
ParseMode, ParseMode,
) )
from telegram.error import BadRequest, InvalidToken, NetworkError from telegram.error import BadRequest, EndPointNotFound, InvalidToken, NetworkError
from telegram.ext import ExtBot, InvalidCallbackData from telegram.ext import ExtBot, InvalidCallbackData
from telegram.helpers import escape_markdown from telegram.helpers import escape_markdown
from telegram.request import BaseRequest, HTTPXRequest, RequestData from telegram.request import BaseRequest, HTTPXRequest, RequestData
@ -90,14 +94,6 @@ from tests.auxil.pytest_classes import PytestBot, PytestExtBot, make_bot
from tests.auxil.slots import mro_slots from tests.auxil.slots import mro_slots
def to_camel_case(snake_str):
"""https://stackoverflow.com/a/19053800"""
components = snake_str.split("_")
# We capitalize the first letter of each component except the first one
# with the 'title' method and join them together.
return components[0] + "".join(x.title() for x in components[1:])
@pytest.fixture() @pytest.fixture()
async def message(bot, chat_id): # mostly used in tests for edit_message async def message(bot, chat_id): # mostly used in tests for edit_message
out = await bot.send_message( out = await bot.send_message(
@ -145,7 +141,7 @@ xfail = pytest.mark.xfail(
) )
def bot_methods(ext_bot=True, include_camel_case=False): def bot_methods(ext_bot=True, include_camel_case=False, include_do_api_request=False):
arg_values = [] arg_values = []
ids = [] ids = []
non_api_methods = [ non_api_methods = [
@ -160,6 +156,9 @@ def bot_methods(ext_bot=True, include_camel_case=False):
"shutdown", "shutdown",
"insert_callback_data", "insert_callback_data",
] ]
if not include_do_api_request:
non_api_methods.append("do_api_request")
classes = (Bot, ExtBot) if ext_bot else (Bot,) classes = (Bot, ExtBot) if ext_bot else (Bot,)
for cls in classes: for cls in classes:
for name, attribute in inspect.getmembers(cls, predicate=inspect.isfunction): for name, attribute in inspect.getmembers(cls, predicate=inspect.isfunction):
@ -420,13 +419,13 @@ class TestBotWithoutRequest:
assert camel_case_function is not False, f"{camel_case_name} not found" assert camel_case_function is not False, f"{camel_case_name} not found"
assert camel_case_function is bot_method, f"{camel_case_name} is not {bot_method}" assert camel_case_function is bot_method, f"{camel_case_name} is not {bot_method}"
@bot_methods() @bot_methods(include_do_api_request=True)
def test_coroutine_functions(self, bot_class, bot_method_name, bot_method): def test_coroutine_functions(self, bot_class, bot_method_name, bot_method):
"""Check that all bot methods are defined as async def ...""" """Check that all bot methods are defined as async def ..."""
meth = getattr(bot_method, "__wrapped__", bot_method) # to unwrap the @_log decorator meth = getattr(bot_method, "__wrapped__", bot_method) # to unwrap the @_log decorator
assert inspect.iscoroutinefunction(meth), f"{bot_method_name} must be a coroutine function" assert inspect.iscoroutinefunction(meth), f"{bot_method_name} must be a coroutine function"
@bot_methods() @bot_methods(include_do_api_request=True)
def test_api_kwargs_and_timeouts_present(self, bot_class, bot_method_name, bot_method): def test_api_kwargs_and_timeouts_present(self, bot_class, bot_method_name, bot_method):
"""Check that all bot methods have `api_kwargs` and timeout params.""" """Check that all bot methods have `api_kwargs` and timeout params."""
param_names = inspect.signature(bot_method).parameters.keys() param_names = inspect.signature(bot_method).parameters.keys()
@ -1795,6 +1794,75 @@ class TestBotWithoutRequest:
bot.get_my_name(), bot.get_my_name("en"), bot.get_my_name("de") bot.get_my_name(), bot.get_my_name("en"), bot.get_my_name("de")
) == 3 * [BotName(default_name)] ) == 3 * [BotName(default_name)]
async def test_do_api_request_camel_case_conversion(self, bot, monkeypatch):
async def make_assertion(url, request_data: RequestData, *args, **kwargs):
return url.endswith("camelCase")
monkeypatch.setattr(bot.request, "post", make_assertion)
assert await bot.do_api_request("camel_case")
async def test_do_api_request_media_write_timeout(self, bot, chat_id, monkeypatch):
test_flag = None
class CustomRequest(BaseRequest):
async def initialize(self_) -> None:
pass
async def shutdown(self_) -> None:
pass
async def do_request(self_, *args, **kwargs) -> Tuple[int, bytes]:
nonlocal test_flag
test_flag = (
kwargs.get("read_timeout"),
kwargs.get("connect_timeout"),
kwargs.get("write_timeout"),
kwargs.get("pool_timeout"),
)
return HTTPStatus.OK, b'{"ok": "True", "result": {}}'
custom_request = CustomRequest()
bot = Bot(bot.token, request=custom_request)
await bot.do_api_request(
"send_document",
api_kwargs={
"chat_id": chat_id,
"caption": "test_caption",
"document": InputFile(data_file("telegram.png").open("rb")),
},
)
assert test_flag == (
DEFAULT_NONE,
DEFAULT_NONE,
20,
DEFAULT_NONE,
)
async def test_do_api_request_default_timezone(self, tz_bot, monkeypatch):
until = dtm.datetime(2020, 1, 11, 16, 13)
until_timestamp = to_timestamp(until, tzinfo=tz_bot.defaults.tzinfo)
async def make_assertion(url, request_data: RequestData, *args, **kwargs):
data = request_data.parameters
chat_id = data["chat_id"] == 2
user_id = data["user_id"] == 32
until_date = data.get("until_date", until_timestamp) == until_timestamp
return chat_id and user_id and until_date
monkeypatch.setattr(tz_bot.request, "post", make_assertion)
assert await tz_bot.do_api_request(
"banChatMember", api_kwargs={"chat_id": 2, "user_id": 32}
)
assert await tz_bot.do_api_request(
"banChatMember", api_kwargs={"chat_id": 2, "user_id": 32, "until_date": until}
)
assert await tz_bot.do_api_request(
"banChatMember",
api_kwargs={"chat_id": 2, "user_id": 32, "until_date": until_timestamp},
)
class TestBotWithRequest: class TestBotWithRequest:
""" """
@ -3500,3 +3568,94 @@ class TestBotWithRequest:
bot.get_my_short_description("en"), bot.get_my_short_description("en"),
bot.get_my_short_description("de"), bot.get_my_short_description("de"),
) == 3 * [BotShortDescription("")] ) == 3 * [BotShortDescription("")]
@pytest.mark.parametrize("bot_class", [Bot, ExtBot])
async def test_do_api_request_warning_known_method(self, bot, bot_class):
with pytest.warns(PTBDeprecationWarning, match="Please use 'Bot.get_me'") as record:
await bot_class(bot.token).do_api_request("get_me")
assert record[0].filename == __file__, "Wrong stack level!"
async def test_do_api_request_unknown_method(self, bot):
with pytest.raises(EndPointNotFound, match="'unknownEndpoint' not found"):
await bot.do_api_request("unknown_endpoint")
async def test_do_api_request_invalid_token(self, bot):
# we do not initialize the bot here on purpose b/c that's the case were we actually
# do not know for sure if the token is invalid or the method was not found
with pytest.raises(
InvalidToken, match="token was rejected by Telegram or the endpoint 'getMe'"
):
await Bot("invalid_token").do_api_request("get_me")
# same test, but with a valid token bot and unknown endpoint
with pytest.raises(
InvalidToken, match="token was rejected by Telegram or the endpoint 'unknownEndpoint'"
):
await Bot(bot.token).do_api_request("unknown_endpoint")
@pytest.mark.parametrize("return_type", [Message, None])
async def test_do_api_request_basic_and_files(self, bot, chat_id, return_type):
result = await bot.do_api_request(
"send_document",
api_kwargs={
"chat_id": chat_id,
"caption": "test_caption",
"document": InputFile(data_file("telegram.png").open("rb")),
},
return_type=return_type,
)
if return_type is None:
assert isinstance(result, dict)
result = Message.de_json(result, bot)
assert isinstance(result, Message)
assert result.chat_id == int(chat_id)
assert result.caption == "test_caption"
out = BytesIO()
await (await result.document.get_file()).download_to_memory(out)
out.seek(0)
assert out.read() == data_file("telegram.png").open("rb").read()
assert result.document.file_name == "telegram.png"
@pytest.mark.parametrize("return_type", [Message, None])
async def test_do_api_request_list_return_type(self, bot, chat_id, return_type):
result = await bot.do_api_request(
"send_media_group",
api_kwargs={
"chat_id": chat_id,
"media": [
InputMediaDocument(
InputFile(
data_file("text_file.txt").open("rb"),
attach=True,
)
),
InputMediaDocument(
InputFile(
data_file("local_file.txt").open("rb"),
attach=True,
)
),
],
},
return_type=return_type,
)
if return_type is None:
assert isinstance(result, list)
for entry in result:
assert isinstance(entry, dict)
result = Message.de_list(result, bot)
for message, file_name in zip(result, ("text_file.txt", "local_file.txt")):
assert isinstance(message, Message)
assert message.chat_id == int(chat_id)
out = BytesIO()
await (await message.document.get_file()).download_to_memory(out)
out.seek(0)
assert out.read() == data_file(file_name).open("rb").read()
assert message.document.file_name == file_name
@pytest.mark.parametrize("return_type", [Message, None])
async def test_do_api_request_bool_return_type(self, bot, chat_id, return_type):
assert await bot.do_api_request("delete_my_commands", return_type=return_type) is True

View file

@ -25,6 +25,7 @@ from telegram.error import (
BadRequest, BadRequest,
ChatMigrated, ChatMigrated,
Conflict, Conflict,
EndPointNotFound,
Forbidden, Forbidden,
InvalidToken, InvalidToken,
NetworkError, NetworkError,
@ -113,6 +114,7 @@ class TestErrors:
(Conflict("test message"), ["message"]), (Conflict("test message"), ["message"]),
(PassportDecryptionError("test message"), ["message"]), (PassportDecryptionError("test message"), ["message"]),
(InvalidCallbackData("test data"), ["callback_data"]), (InvalidCallbackData("test data"), ["callback_data"]),
(EndPointNotFound("endPoint"), ["message"]),
], ],
) )
def test_errors_pickling(self, exception, attributes): def test_errors_pickling(self, exception, attributes):
@ -138,6 +140,7 @@ class TestErrors:
(Conflict("test message")), (Conflict("test message")),
(PassportDecryptionError("test message")), (PassportDecryptionError("test message")),
(InvalidCallbackData("test data")), (InvalidCallbackData("test data")),
(EndPointNotFound("test message")),
], ],
) )
def test_slot_behaviour(self, inst): def test_slot_behaviour(self, inst):
@ -170,6 +173,7 @@ class TestErrors:
Conflict, Conflict,
PassportDecryptionError, PassportDecryptionError,
InvalidCallbackData, InvalidCallbackData,
EndPointNotFound,
}, },
NetworkError: {BadRequest, TimedOut}, NetworkError: {BadRequest, TimedOut},
} }