Add ApplicationBuilder.(get_updates_)socket_options (#3943)

This commit is contained in:
Bibo-Joshi 2023-10-31 16:27:30 +01:00 committed by GitHub
parent c71612ffae
commit 616b0b55ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 128 additions and 9 deletions

View file

@ -95,3 +95,9 @@ HTTPVersion = Literal["1.1", "2.0", "2"]
CorrectOptionID = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
MarkdownVersion = Literal[1, 2]
SocketOpt = Union[
Tuple[int, int, int],
Tuple[int, int, Union[bytes, bytearray]],
Tuple[int, int, None, int],
]

View file

@ -23,6 +23,7 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Coroutine,
Dict,
Generic,
@ -36,7 +37,7 @@ import httpx
from telegram._bot import Bot
from telegram._utils.defaultvalue import DEFAULT_FALSE, DEFAULT_NONE, DefaultValue
from telegram._utils.types import DVInput, DVType, FilePathInput, HTTPVersion, ODVInput
from telegram._utils.types import DVInput, DVType, FilePathInput, HTTPVersion, ODVInput, SocketOpt
from telegram._utils.warnings import warn
from telegram.ext._application import Application
from telegram.ext._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor
@ -71,6 +72,7 @@ _BOT_CHECKS = [
("get_updates_request", "get_updates_request instance"),
("connection_pool_size", "connection_pool_size"),
("proxy", "proxy"),
("socket_options", "socket_options"),
("pool_timeout", "pool_timeout"),
("connect_timeout", "connect_timeout"),
("read_timeout", "read_timeout"),
@ -78,6 +80,7 @@ _BOT_CHECKS = [
("http_version", "http_version"),
("get_updates_connection_pool_size", "get_updates_connection_pool_size"),
("get_updates_proxy", "get_updates_proxy"),
("get_updates_socket_options", "get_updates_socket_options"),
("get_updates_pool_timeout", "get_updates_pool_timeout"),
("get_updates_connect_timeout", "get_updates_connect_timeout"),
("get_updates_read_timeout", "get_updates_read_timeout"),
@ -143,6 +146,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_get_updates_proxy",
"_get_updates_read_timeout",
"_get_updates_request",
"_get_updates_socket_options",
"_get_updates_write_timeout",
"_get_updates_http_version",
"_job_queue",
@ -157,6 +161,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_rate_limiter",
"_read_timeout",
"_request",
"_socket_options",
"_token",
"_update_queue",
"_updater",
@ -171,6 +176,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
self._base_file_url: DVType[str] = DefaultValue("https://api.telegram.org/file/bot")
self._connection_pool_size: DVInput[int] = DEFAULT_NONE
self._proxy: DVInput[Union[str, httpx.Proxy, httpx.URL]] = DEFAULT_NONE
self._socket_options: DVInput[Collection[SocketOpt]] = DEFAULT_NONE
self._connect_timeout: ODVInput[float] = DEFAULT_NONE
self._read_timeout: ODVInput[float] = DEFAULT_NONE
self._write_timeout: ODVInput[float] = DEFAULT_NONE
@ -178,6 +184,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
self._request: DVInput[BaseRequest] = DEFAULT_NONE
self._get_updates_connection_pool_size: DVInput[int] = DEFAULT_NONE
self._get_updates_proxy: DVInput[Union[str, httpx.Proxy, httpx.URL]] = DEFAULT_NONE
self._get_updates_socket_options: DVInput[Collection[SocketOpt]] = DEFAULT_NONE
self._get_updates_connect_timeout: ODVInput[float] = DEFAULT_NONE
self._get_updates_read_timeout: ODVInput[float] = DEFAULT_NONE
self._get_updates_write_timeout: ODVInput[float] = DEFAULT_NONE
@ -219,6 +226,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
return getattr(self, f"{prefix}request")
proxy = DefaultValue.get_value(getattr(self, f"{prefix}proxy"))
socket_options = DefaultValue.get_value(getattr(self, f"{prefix}socket_options"))
if get_updates:
connection_pool_size = (
DefaultValue.get_value(getattr(self, f"{prefix}connection_pool_size")) or 1
@ -245,6 +253,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
connection_pool_size=connection_pool_size,
proxy=proxy,
http_version=http_version, # type: ignore[arg-type]
socket_options=socket_options,
**effective_timeouts,
)
@ -426,6 +435,9 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
if not isinstance(getattr(self, f"_{prefix}proxy"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, "proxy"))
if not isinstance(getattr(self, f"_{prefix}socket_options"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, "socket_options"))
if not isinstance(getattr(self, f"_{prefix}http_version"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, "http_version"))
@ -531,6 +543,25 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
self._proxy = proxy
return self
def socket_options(self: BuilderType, socket_options: Collection[SocketOpt]) -> BuilderType:
"""Sets the options for the :paramref:`~telegram.request.HTTPXRequest.socket_options`
parameter of :attr:`telegram.Bot.request`. Defaults to :obj:`None`.
.. seealso:: :meth:`get_updates_socket_options`
.. versionadded:: NEXT.VERSION
Args:
socket_options (Collection[:obj:`tuple`], optional): Socket options. See
:paramref:`telegram.request.HTTPXRequest.socket_options` for more information.
Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._request_param_check(name="socket_options", get_updates=False)
self._socket_options = socket_options
return self
def connect_timeout(self: BuilderType, connect_timeout: Optional[float]) -> BuilderType:
"""Sets the connection attempt timeout for the
:paramref:`~telegram.request.HTTPXRequest.connect_timeout` parameter of
@ -726,6 +757,27 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
self._get_updates_proxy = get_updates_proxy
return self
def get_updates_socket_options(
self: BuilderType, get_updates_socket_options: Collection[SocketOpt]
) -> BuilderType:
"""Sets the options for the :paramref:`~telegram.request.HTTPXRequest.socket_options`
parameter of :paramref:`telegram.Bot.get_updates_request`. Defaults to :obj:`None`.
.. seealso:: :meth:`socket_options`
.. versionadded:: NEXT.VERSION
Args:
get_updates_socket_options (Collection[:obj:`tuple`], optional): Socket options. See
:paramref:`telegram.request.HTTPXRequest.socket_options` for more information.
Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._request_param_check(name="socket_options", get_updates=True)
self._get_updates_socket_options = get_updates_socket_options
return self
def get_updates_connect_timeout(
self: BuilderType, get_updates_connect_timeout: Optional[float]
) -> BuilderType:

View file

@ -23,7 +23,7 @@ import httpx
from telegram._utils.defaultvalue import DefaultValue
from telegram._utils.logging import get_logger
from telegram._utils.types import HTTPVersion, ODVInput
from telegram._utils.types import HTTPVersion, ODVInput, SocketOpt
from telegram._utils.warnings import warn
from telegram.error import NetworkError, TimedOut
from telegram.request._baserequest import BaseRequest
@ -37,12 +37,6 @@ from telegram.warnings import PTBDeprecationWarning
_LOGGER = get_logger(__name__, "HTTPXRequest")
_SocketOpt = Union[
Tuple[int, int, int],
Tuple[int, int, Union[bytes, bytearray]],
Tuple[int, int, None, int],
]
class HTTPXRequest(BaseRequest):
"""Implementation of :class:`~telegram.request.BaseRequest` using the library
@ -132,7 +126,7 @@ class HTTPXRequest(BaseRequest):
connect_timeout: Optional[float] = 5.0,
pool_timeout: Optional[float] = 1.0,
http_version: HTTPVersion = "1.1",
socket_options: Optional[Collection[_SocketOpt]] = None,
socket_options: Optional[Collection[SocketOpt]] = None,
proxy: Optional[Union[str, httpx.Proxy, httpx.URL]] = None,
):
if proxy_url is not None and proxy is not None:

View file

@ -17,11 +17,13 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import asyncio
import inspect
from dataclasses import dataclass
import httpx
import pytest
from telegram import Bot
from telegram.ext import (
AIORateLimiter,
Application,
@ -65,6 +67,34 @@ class TestApplicationBuilder:
assert getattr(builder, attr, "err") != "err", f"got extra slot '{attr}'"
assert len(mro_slots(builder)) == len(set(mro_slots(builder))), "duplicate slot"
@pytest.mark.parametrize("get_updates", [True, False])
def test_all_methods_request(self, builder, get_updates):
arguments = inspect.signature(HTTPXRequest.__init__).parameters.keys()
prefix = "get_updates_" if get_updates else ""
for argument in arguments:
if argument == "self":
continue
assert hasattr(builder, prefix + argument), f"missing method {prefix}{argument}"
@pytest.mark.parametrize("bot_class", [Bot, ExtBot])
def test_all_methods_bot(self, builder, bot_class):
arguments = inspect.signature(bot_class.__init__).parameters.keys()
for argument in arguments:
if argument == "self":
continue
if argument == "private_key_password":
argument = "private_key" # noqa: PLW2901
assert hasattr(builder, argument), f"missing method {argument}"
def test_all_methods_application(self, builder):
arguments = inspect.signature(Application.__init__).parameters.keys()
for argument in arguments:
if argument == "self":
continue
if argument == "update_processor":
argument = "concurrent_updates" # noqa: PLW2901
assert hasattr(builder, argument), f"missing method {argument}"
def test_job_queue_init_exception(self, monkeypatch):
def init_raises_runtime_error(*args, **kwargs):
raise RuntimeError("RuntimeError")
@ -172,6 +202,7 @@ class TestApplicationBuilder:
"write_timeout",
"proxy",
"proxy_url",
"socket_options",
"bot",
"updater",
"http_version",
@ -201,6 +232,7 @@ class TestApplicationBuilder:
"get_updates_write_timeout",
"get_updates_proxy",
"get_updates_proxy_url",
"get_updates_socket_options",
"get_updates_http_version",
"bot",
"updater",
@ -231,6 +263,7 @@ class TestApplicationBuilder:
"get_updates_write_timeout",
"get_updates_proxy_url",
"get_updates_proxy",
"get_updates_socket_options",
"get_updates_http_version",
"connection_pool_size",
"connect_timeout",
@ -239,6 +272,7 @@ class TestApplicationBuilder:
"write_timeout",
"proxy",
"proxy_url",
"socket_options",
"http_version",
"bot",
"update_queue",
@ -273,6 +307,7 @@ class TestApplicationBuilder:
"get_updates_write_timeout",
"get_updates_proxy",
"get_updates_proxy_url",
"get_updates_socket_options",
"get_updates_http_version",
"connection_pool_size",
"connect_timeout",
@ -281,6 +316,7 @@ class TestApplicationBuilder:
"write_timeout",
"proxy",
"proxy_url",
"socket_options",
"bot",
"http_version",
]
@ -306,6 +342,7 @@ class TestApplicationBuilder:
def test_all_bot_args_custom(
self, builder, bot, monkeypatch, proxy_method, get_updates_proxy_method
):
# Only socket_options is tested in a standalone test, since that's easier
defaults = Defaults()
request = HTTPXRequest()
get_updates_request = HTTPXRequest()
@ -379,6 +416,36 @@ class TestApplicationBuilder:
assert client.http1 is True
assert client.http2 is False
def test_custom_socket_options(self, builder, monkeypatch, bot):
httpx_request_kwargs = []
httpx_request_init = HTTPXRequest.__init__
def init_transport(*args, **kwargs):
nonlocal httpx_request_kwargs
# This is called once for request and once for get_updates_request, so we make
# it a list
httpx_request_kwargs.append(kwargs.copy())
httpx_request_init(*args, **kwargs)
monkeypatch.setattr(HTTPXRequest, "__init__", init_transport)
builder.token(bot.token).build()
assert httpx_request_kwargs[0].get("socket_options") is None
assert httpx_request_kwargs[1].get("socket_options") is None
httpx_request_kwargs = []
ApplicationBuilder().token(bot.token).socket_options(((1, 2, 3),)).connection_pool_size(
"request"
).get_updates_socket_options(((4, 5, 6),)).get_updates_connection_pool_size(
"get_updates"
).build()
for kwargs in httpx_request_kwargs:
if kwargs.get("connection_pool_size") == "request":
assert kwargs.get("socket_options") == ((1, 2, 3),)
else:
assert kwargs.get("socket_options") == ((4, 5, 6),)
def test_custom_application_class(self, bot, builder):
class CustomApplication(Application):
def __init__(self, arg, **kwargs):