From 616b0b55ef86453b5f6bb680439f1dfe9aed054a Mon Sep 17 00:00:00 2001 From: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 31 Oct 2023 16:27:30 +0100 Subject: [PATCH] Add `ApplicationBuilder.(get_updates_)socket_options` (#3943) --- telegram/_utils/types.py | 6 +++ telegram/ext/_applicationbuilder.py | 54 +++++++++++++++++++++- telegram/request/_httpxrequest.py | 10 +---- tests/ext/test_applicationbuilder.py | 67 ++++++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 9 deletions(-) diff --git a/telegram/_utils/types.py b/telegram/_utils/types.py index 670662f4c..046a72dfc 100644 --- a/telegram/_utils/types.py +++ b/telegram/_utils/types.py @@ -95,3 +95,9 @@ HTTPVersion = Literal["1.1", "2.0", "2"] CorrectOptionID = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] MarkdownVersion = Literal[1, 2] + +SocketOpt = Union[ + Tuple[int, int, int], + Tuple[int, int, Union[bytes, bytearray]], + Tuple[int, int, None, int], +] diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index 5a5ac0402..05806f233 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -23,6 +23,7 @@ from typing import ( TYPE_CHECKING, Any, Callable, + Collection, Coroutine, Dict, Generic, @@ -36,7 +37,7 @@ import httpx from telegram._bot import Bot from telegram._utils.defaultvalue import DEFAULT_FALSE, DEFAULT_NONE, DefaultValue -from telegram._utils.types import DVInput, DVType, FilePathInput, HTTPVersion, ODVInput +from telegram._utils.types import DVInput, DVType, FilePathInput, HTTPVersion, ODVInput, SocketOpt from telegram._utils.warnings import warn from telegram.ext._application import Application from telegram.ext._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor @@ -71,6 +72,7 @@ _BOT_CHECKS = [ ("get_updates_request", "get_updates_request instance"), ("connection_pool_size", "connection_pool_size"), ("proxy", "proxy"), + ("socket_options", "socket_options"), ("pool_timeout", "pool_timeout"), ("connect_timeout", "connect_timeout"), ("read_timeout", "read_timeout"), @@ -78,6 +80,7 @@ _BOT_CHECKS = [ ("http_version", "http_version"), ("get_updates_connection_pool_size", "get_updates_connection_pool_size"), ("get_updates_proxy", "get_updates_proxy"), + ("get_updates_socket_options", "get_updates_socket_options"), ("get_updates_pool_timeout", "get_updates_pool_timeout"), ("get_updates_connect_timeout", "get_updates_connect_timeout"), ("get_updates_read_timeout", "get_updates_read_timeout"), @@ -143,6 +146,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): "_get_updates_proxy", "_get_updates_read_timeout", "_get_updates_request", + "_get_updates_socket_options", "_get_updates_write_timeout", "_get_updates_http_version", "_job_queue", @@ -157,6 +161,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): "_rate_limiter", "_read_timeout", "_request", + "_socket_options", "_token", "_update_queue", "_updater", @@ -171,6 +176,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): self._base_file_url: DVType[str] = DefaultValue("https://api.telegram.org/file/bot") self._connection_pool_size: DVInput[int] = DEFAULT_NONE self._proxy: DVInput[Union[str, httpx.Proxy, httpx.URL]] = DEFAULT_NONE + self._socket_options: DVInput[Collection[SocketOpt]] = DEFAULT_NONE self._connect_timeout: ODVInput[float] = DEFAULT_NONE self._read_timeout: ODVInput[float] = DEFAULT_NONE self._write_timeout: ODVInput[float] = DEFAULT_NONE @@ -178,6 +184,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): self._request: DVInput[BaseRequest] = DEFAULT_NONE self._get_updates_connection_pool_size: DVInput[int] = DEFAULT_NONE self._get_updates_proxy: DVInput[Union[str, httpx.Proxy, httpx.URL]] = DEFAULT_NONE + self._get_updates_socket_options: DVInput[Collection[SocketOpt]] = DEFAULT_NONE self._get_updates_connect_timeout: ODVInput[float] = DEFAULT_NONE self._get_updates_read_timeout: ODVInput[float] = DEFAULT_NONE self._get_updates_write_timeout: ODVInput[float] = DEFAULT_NONE @@ -219,6 +226,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): return getattr(self, f"{prefix}request") proxy = DefaultValue.get_value(getattr(self, f"{prefix}proxy")) + socket_options = DefaultValue.get_value(getattr(self, f"{prefix}socket_options")) if get_updates: connection_pool_size = ( DefaultValue.get_value(getattr(self, f"{prefix}connection_pool_size")) or 1 @@ -245,6 +253,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): connection_pool_size=connection_pool_size, proxy=proxy, http_version=http_version, # type: ignore[arg-type] + socket_options=socket_options, **effective_timeouts, ) @@ -426,6 +435,9 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): if not isinstance(getattr(self, f"_{prefix}proxy"), DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format(name, "proxy")) + if not isinstance(getattr(self, f"_{prefix}socket_options"), DefaultValue): + raise RuntimeError(_TWO_ARGS_REQ.format(name, "socket_options")) + if not isinstance(getattr(self, f"_{prefix}http_version"), DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format(name, "http_version")) @@ -531,6 +543,25 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): self._proxy = proxy return self + def socket_options(self: BuilderType, socket_options: Collection[SocketOpt]) -> BuilderType: + """Sets the options for the :paramref:`~telegram.request.HTTPXRequest.socket_options` + parameter of :attr:`telegram.Bot.request`. Defaults to :obj:`None`. + + .. seealso:: :meth:`get_updates_socket_options` + + .. versionadded:: NEXT.VERSION + + Args: + socket_options (Collection[:obj:`tuple`], optional): Socket options. See + :paramref:`telegram.request.HTTPXRequest.socket_options` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ + self._request_param_check(name="socket_options", get_updates=False) + self._socket_options = socket_options + return self + def connect_timeout(self: BuilderType, connect_timeout: Optional[float]) -> BuilderType: """Sets the connection attempt timeout for the :paramref:`~telegram.request.HTTPXRequest.connect_timeout` parameter of @@ -726,6 +757,27 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): self._get_updates_proxy = get_updates_proxy return self + def get_updates_socket_options( + self: BuilderType, get_updates_socket_options: Collection[SocketOpt] + ) -> BuilderType: + """Sets the options for the :paramref:`~telegram.request.HTTPXRequest.socket_options` + parameter of :paramref:`telegram.Bot.get_updates_request`. Defaults to :obj:`None`. + + .. seealso:: :meth:`socket_options` + + .. versionadded:: NEXT.VERSION + + Args: + get_updates_socket_options (Collection[:obj:`tuple`], optional): Socket options. See + :paramref:`telegram.request.HTTPXRequest.socket_options` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ + self._request_param_check(name="socket_options", get_updates=True) + self._get_updates_socket_options = get_updates_socket_options + return self + def get_updates_connect_timeout( self: BuilderType, get_updates_connect_timeout: Optional[float] ) -> BuilderType: diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 1560b8981..b4f7fd50d 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -23,7 +23,7 @@ import httpx from telegram._utils.defaultvalue import DefaultValue from telegram._utils.logging import get_logger -from telegram._utils.types import HTTPVersion, ODVInput +from telegram._utils.types import HTTPVersion, ODVInput, SocketOpt from telegram._utils.warnings import warn from telegram.error import NetworkError, TimedOut from telegram.request._baserequest import BaseRequest @@ -37,12 +37,6 @@ from telegram.warnings import PTBDeprecationWarning _LOGGER = get_logger(__name__, "HTTPXRequest") -_SocketOpt = Union[ - Tuple[int, int, int], - Tuple[int, int, Union[bytes, bytearray]], - Tuple[int, int, None, int], -] - class HTTPXRequest(BaseRequest): """Implementation of :class:`~telegram.request.BaseRequest` using the library @@ -132,7 +126,7 @@ class HTTPXRequest(BaseRequest): connect_timeout: Optional[float] = 5.0, pool_timeout: Optional[float] = 1.0, http_version: HTTPVersion = "1.1", - socket_options: Optional[Collection[_SocketOpt]] = None, + socket_options: Optional[Collection[SocketOpt]] = None, proxy: Optional[Union[str, httpx.Proxy, httpx.URL]] = None, ): if proxy_url is not None and proxy is not None: diff --git a/tests/ext/test_applicationbuilder.py b/tests/ext/test_applicationbuilder.py index 0ac7ffca3..15e328d14 100644 --- a/tests/ext/test_applicationbuilder.py +++ b/tests/ext/test_applicationbuilder.py @@ -17,11 +17,13 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio +import inspect from dataclasses import dataclass import httpx import pytest +from telegram import Bot from telegram.ext import ( AIORateLimiter, Application, @@ -65,6 +67,34 @@ class TestApplicationBuilder: assert getattr(builder, attr, "err") != "err", f"got extra slot '{attr}'" assert len(mro_slots(builder)) == len(set(mro_slots(builder))), "duplicate slot" + @pytest.mark.parametrize("get_updates", [True, False]) + def test_all_methods_request(self, builder, get_updates): + arguments = inspect.signature(HTTPXRequest.__init__).parameters.keys() + prefix = "get_updates_" if get_updates else "" + for argument in arguments: + if argument == "self": + continue + assert hasattr(builder, prefix + argument), f"missing method {prefix}{argument}" + + @pytest.mark.parametrize("bot_class", [Bot, ExtBot]) + def test_all_methods_bot(self, builder, bot_class): + arguments = inspect.signature(bot_class.__init__).parameters.keys() + for argument in arguments: + if argument == "self": + continue + if argument == "private_key_password": + argument = "private_key" # noqa: PLW2901 + assert hasattr(builder, argument), f"missing method {argument}" + + def test_all_methods_application(self, builder): + arguments = inspect.signature(Application.__init__).parameters.keys() + for argument in arguments: + if argument == "self": + continue + if argument == "update_processor": + argument = "concurrent_updates" # noqa: PLW2901 + assert hasattr(builder, argument), f"missing method {argument}" + def test_job_queue_init_exception(self, monkeypatch): def init_raises_runtime_error(*args, **kwargs): raise RuntimeError("RuntimeError") @@ -172,6 +202,7 @@ class TestApplicationBuilder: "write_timeout", "proxy", "proxy_url", + "socket_options", "bot", "updater", "http_version", @@ -201,6 +232,7 @@ class TestApplicationBuilder: "get_updates_write_timeout", "get_updates_proxy", "get_updates_proxy_url", + "get_updates_socket_options", "get_updates_http_version", "bot", "updater", @@ -231,6 +263,7 @@ class TestApplicationBuilder: "get_updates_write_timeout", "get_updates_proxy_url", "get_updates_proxy", + "get_updates_socket_options", "get_updates_http_version", "connection_pool_size", "connect_timeout", @@ -239,6 +272,7 @@ class TestApplicationBuilder: "write_timeout", "proxy", "proxy_url", + "socket_options", "http_version", "bot", "update_queue", @@ -273,6 +307,7 @@ class TestApplicationBuilder: "get_updates_write_timeout", "get_updates_proxy", "get_updates_proxy_url", + "get_updates_socket_options", "get_updates_http_version", "connection_pool_size", "connect_timeout", @@ -281,6 +316,7 @@ class TestApplicationBuilder: "write_timeout", "proxy", "proxy_url", + "socket_options", "bot", "http_version", ] @@ -306,6 +342,7 @@ class TestApplicationBuilder: def test_all_bot_args_custom( self, builder, bot, monkeypatch, proxy_method, get_updates_proxy_method ): + # Only socket_options is tested in a standalone test, since that's easier defaults = Defaults() request = HTTPXRequest() get_updates_request = HTTPXRequest() @@ -379,6 +416,36 @@ class TestApplicationBuilder: assert client.http1 is True assert client.http2 is False + def test_custom_socket_options(self, builder, monkeypatch, bot): + httpx_request_kwargs = [] + httpx_request_init = HTTPXRequest.__init__ + + def init_transport(*args, **kwargs): + nonlocal httpx_request_kwargs + # This is called once for request and once for get_updates_request, so we make + # it a list + httpx_request_kwargs.append(kwargs.copy()) + httpx_request_init(*args, **kwargs) + + monkeypatch.setattr(HTTPXRequest, "__init__", init_transport) + + builder.token(bot.token).build() + assert httpx_request_kwargs[0].get("socket_options") is None + assert httpx_request_kwargs[1].get("socket_options") is None + + httpx_request_kwargs = [] + ApplicationBuilder().token(bot.token).socket_options(((1, 2, 3),)).connection_pool_size( + "request" + ).get_updates_socket_options(((4, 5, 6),)).get_updates_connection_pool_size( + "get_updates" + ).build() + + for kwargs in httpx_request_kwargs: + if kwargs.get("connection_pool_size") == "request": + assert kwargs.get("socket_options") == ((1, 2, 3),) + else: + assert kwargs.get("socket_options") == ((4, 5, 6),) + def test_custom_application_class(self, bot, builder): class CustomApplication(Application): def __init__(self, arg, **kwargs):