From 2345bfbb5337f6e5d63b9d56bbf3fa75646c1b35 Mon Sep 17 00:00:00 2001 From: Poolitzer Date: Thu, 14 Dec 2023 21:37:00 +0100 Subject: [PATCH] Add Support for Unix Sockets to `Updater.start_webhook` (#3986) Co-authored-by: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com> --- telegram/_utils/defaultvalue.py | 16 ++++ telegram/ext/_application.py | 24 ++++- telegram/ext/_updater.py | 45 +++++++-- telegram/ext/_utils/webhookhandler.py | 27 +++++- tests/auxil/networking.py | 8 +- tests/ext/test_updater.py | 131 +++++++++++++++++++++----- 6 files changed, 213 insertions(+), 38 deletions(-) diff --git a/telegram/_utils/defaultvalue.py b/telegram/_utils/defaultvalue.py index f2fb80b76..965507341 100644 --- a/telegram/_utils/defaultvalue.py +++ b/telegram/_utils/defaultvalue.py @@ -132,3 +132,19 @@ DEFAULT_TRUE: DefaultValue[bool] = DefaultValue(True) .. versionadded:: 20.0 """ + + +DEFAULT_20: DefaultValue[int] = DefaultValue(20) +""":class:`DefaultValue`: Default :obj:`20`""" + +DEFAULT_IP: DefaultValue[str] = DefaultValue("127.0.0.1") +""":class:`DefaultValue`: Default :obj:`127.0.0.1` + +.. versionadded:: NEXT.VERSION +""" + +DEFAULT_80: DefaultValue[int] = DefaultValue(80) +""":class:`DefaultValue`: Default :obj:`80` + +.. versionadded:: NEXT.VERSION +""" diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index c5834f9c8..52ce03321 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -52,7 +52,13 @@ from typing import ( ) from telegram._update import Update -from telegram._utils.defaultvalue import DEFAULT_NONE, DEFAULT_TRUE, DefaultValue +from telegram._utils.defaultvalue import ( + DEFAULT_80, + DEFAULT_IP, + DEFAULT_NONE, + DEFAULT_TRUE, + DefaultValue, +) from telegram._utils.logging import get_logger from telegram._utils.repr import build_repr_with_selected_attrs from telegram._utils.types import SCT, DVType, ODVInput @@ -834,8 +840,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica def run_webhook( self, - listen: str = "127.0.0.1", - port: int = 80, + listen: DVType[str] = DEFAULT_IP, + port: DVType[int] = DEFAULT_80, url_path: str = "", cert: Optional[Union[str, Path]] = None, key: Optional[Union[str, Path]] = None, @@ -848,6 +854,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica close_loop: bool = True, stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE, secret_token: Optional[str] = None, + unix: Optional[Union[str, Path]] = None, ) -> None: """Convenience method that takes care of initializing and starting the app, listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and @@ -940,6 +947,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica header isn't set or it is set to a wrong token. .. versionadded:: 20.0 + unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path + does not need to exist, in which case the file will be created. + + Caution: + This parameter is a replacement for the default TCP bind. Therefore, it is + mutually exclusive with :paramref:`listen` and :paramref:`port`. When using + this param, you must also run a reverse proxy to the unix socket and set the + appropriate :paramref:`webhook_url`. + + .. versionadded:: NEXT.VERSION """ if not self.updater: raise RuntimeError( @@ -960,6 +977,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica ip_address=ip_address, max_connections=max_connections, secret_token=secret_token, + unix=unix, ), close_loop=close_loop, stop_signals=stop_signals, diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 29598f3c0..ea0b9519b 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -35,10 +35,10 @@ from typing import ( Union, ) -from telegram._utils.defaultvalue import DEFAULT_NONE +from telegram._utils.defaultvalue import DEFAULT_80, DEFAULT_IP, DEFAULT_NONE, DefaultValue from telegram._utils.logging import get_logger from telegram._utils.repr import build_repr_with_selected_attrs -from telegram._utils.types import ODVInput +from telegram._utils.types import DVType, ODVInput from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut try: @@ -456,8 +456,8 @@ class Updater(AsyncContextManager["Updater"]): async def start_webhook( self, - listen: str = "127.0.0.1", - port: int = 80, + listen: DVType[str] = DEFAULT_IP, + port: DVType[int] = DEFAULT_80, url_path: str = "", cert: Optional[Union[str, Path]] = None, key: Optional[Union[str, Path]] = None, @@ -468,6 +468,7 @@ class Updater(AsyncContextManager["Updater"]): ip_address: Optional[str] = None, max_connections: int = 40, secret_token: Optional[str] = None, + unix: Optional[Union[str, Path]] = None, ) -> "asyncio.Queue[object]": """ Starts a small http server to listen for updates via webhook. If :paramref:`cert` @@ -536,6 +537,16 @@ class Updater(AsyncContextManager["Updater"]): header isn't set or it is set to a wrong token. .. versionadded:: 20.0 + unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path + does not need to exist, in which case the file will be created. + + Caution: + This parameter is a replacement for the default TCP bind. Therefore, it is + mutually exclusive with :paramref:`listen` and :paramref:`port`. When using + this param, you must also run a reverse proxy to the unix socket and set the + appropriate :paramref:`webhook_url`. + + .. versionadded:: NEXT.VERSION Returns: :class:`queue.Queue`: The update queue that can be filled from the main thread. @@ -547,6 +558,21 @@ class Updater(AsyncContextManager["Updater"]): "To use `start_webhook`, PTB must be installed via `pip install " '"python-telegram-bot[webhooks]"`.' ) + # unix has special requirements what must and mustn't be set when using it + if unix: + error_msg = ( + "You can not pass unix and {0}, only use one. Unix if you want to " + "initialize a unix socket, or {0} for a standard TCP server." + ) + if not isinstance(listen, DefaultValue): + raise RuntimeError(error_msg.format("listen")) + if not isinstance(port, DefaultValue): + raise RuntimeError(error_msg.format("port")) + if not webhook_url: + raise RuntimeError( + "Since you set unix, you also need to set the URL to the webhook " + "of the proxy you run in front of the unix socket." + ) async with self.__lock: if self.running: @@ -561,8 +587,8 @@ class Updater(AsyncContextManager["Updater"]): webhook_ready = asyncio.Event() await self._start_webhook( - listen=listen, - port=port, + listen=DefaultValue.get_value(listen), + port=DefaultValue.get_value(port), url_path=url_path, cert=cert, key=key, @@ -574,6 +600,7 @@ class Updater(AsyncContextManager["Updater"]): ip_address=ip_address, max_connections=max_connections, secret_token=secret_token, + unix=unix, ) _LOGGER.debug("Waiting for webhook server to start") @@ -601,6 +628,7 @@ class Updater(AsyncContextManager["Updater"]): ip_address: Optional[str] = None, max_connections: int = 40, secret_token: Optional[str] = None, + unix: Optional[Union[str, Path]] = None, ) -> None: _LOGGER.debug("Updater thread started (webhook)") @@ -625,14 +653,13 @@ class Updater(AsyncContextManager["Updater"]): raise TelegramError("Invalid SSL Certificate") from exc else: ssl_ctx = None - # Create and start server - self._httpd = WebhookServer(listen, port, app, ssl_ctx) + self._httpd = WebhookServer(listen, port, app, ssl_ctx, unix) if not webhook_url: webhook_url = self._gen_webhook_url( protocol="https" if ssl_ctx else "http", - listen=listen, + listen=DefaultValue.get_value(listen), port=port, url_path=url_path, ) diff --git a/telegram/ext/_utils/webhookhandler.py b/telegram/ext/_utils/webhookhandler.py index 65f37ce7b..e677c13a1 100644 --- a/telegram/ext/_utils/webhookhandler.py +++ b/telegram/ext/_utils/webhookhandler.py @@ -20,15 +20,23 @@ import asyncio import json from http import HTTPStatus +from pathlib import Path from ssl import SSLContext from types import TracebackType -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional, Type, Union # Instead of checking for ImportError here, we do that in `updater.py`, where we import from # this module. Doing it here would be tricky, as the classes below subclass tornado classes import tornado.web from tornado.httpserver import HTTPServer +try: + from tornado.netutil import bind_unix_socket + + UNIX_AVAILABLE = True +except ImportError: + UNIX_AVAILABLE = False + from telegram import Update from telegram._utils.logging import get_logger from telegram.ext._extbot import ExtBot @@ -50,21 +58,34 @@ class WebhookServer: "is_running", "_server_lock", "_shutdown_lock", + "unix", ) def __init__( - self, listen: str, port: int, webhook_app: "WebhookAppClass", ssl_ctx: Optional[SSLContext] + self, + listen: str, + port: int, + webhook_app: "WebhookAppClass", + ssl_ctx: Optional[SSLContext], + unix: Optional[Union[str, Path]] = None, ): + if unix and not UNIX_AVAILABLE: + raise RuntimeError("This OS does not support binding unix sockets.") self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx) self.listen = listen self.port = port self.is_running = False + self.unix = unix self._server_lock = asyncio.Lock() self._shutdown_lock = asyncio.Lock() async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None: async with self._server_lock: - self._http_server.listen(self.port, address=self.listen) + if self.unix: + socket = bind_unix_socket(str(self.unix)) + self._http_server.add_socket(socket) + else: + self._http_server.listen(self.port, address=self.listen) self.is_running = True if ready is not None: diff --git a/tests/auxil/networking.py b/tests/auxil/networking.py index dec83df23..9966fb329 100644 --- a/tests/auxil/networking.py +++ b/tests/auxil/networking.py @@ -16,10 +16,11 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +from pathlib import Path from typing import Optional import pytest -from httpx import AsyncClient, Response +from httpx import AsyncClient, AsyncHTTPTransport, Response from telegram._utils.defaultvalue import DEFAULT_NONE from telegram._utils.types import ODVInput @@ -90,6 +91,7 @@ async def send_webhook_message( content_type: str = "application/json", get_method: Optional[str] = None, secret_token: Optional[str] = None, + unix: Optional[Path] = None, ) -> Response: headers = { "content-type": content_type, @@ -111,7 +113,9 @@ async def send_webhook_message( url = f"http://{ip}:{port}/{url_path}" - async with AsyncClient() as client: + transport = AsyncHTTPTransport(uds=unix) if unix else None + + async with AsyncClient(transport=transport) as client: return await client.request( url=url, method=get_method or "POST", data=payload, headers=headers ) diff --git a/tests/ext/test_updater.py b/tests/ext/test_updater.py index 58c8804f9..f4c7a906e 100644 --- a/tests/ext/test_updater.py +++ b/tests/ext/test_updater.py @@ -18,6 +18,7 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio import logging +import platform from collections import defaultdict from http import HTTPStatus from pathlib import Path @@ -32,7 +33,7 @@ from telegram.ext import ExtBot, InvalidCallbackData, Updater from telegram.request import HTTPXRequest from tests.auxil.build_messages import make_message, make_message_update from tests.auxil.envvars import TEST_WITH_OPT_DEPS -from tests.auxil.files import data_file +from tests.auxil.files import TEST_DATA_PATH, data_file from tests.auxil.networking import send_webhook_message from tests.auxil.pytest_classes import PytestBot, make_bot from tests.auxil.slots import mro_slots @@ -74,6 +75,14 @@ class TestUpdater: self.cb_handler_called = None self.test_flag = False + # This is needed instead of pytest's temp_path because the file path gets too long on macOS + # otherwise + @pytest.fixture() + def file_path(self) -> str: + path = TEST_DATA_PATH / "test.sock" + yield str(path) + path.unlink(missing_ok=True) + def error_callback(self, error): self.received = error self.err_handler_called.set() @@ -680,9 +689,13 @@ class TestUpdater: @pytest.mark.parametrize("ext_bot", [True, False]) @pytest.mark.parametrize("drop_pending_updates", [True, False]) @pytest.mark.parametrize("secret_token", ["SecretToken", None]) + @pytest.mark.parametrize("unix", [None, True]) async def test_webhook_basic( - self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token + self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token, unix, file_path ): + # Skipping unix test on windows since they fail + if unix and platform.system() == "Windows": + pytest.skip("Windows doesn't support unix bind") # Testing with both ExtBot and Bot to make sure any logic in WebhookHandler # that depends on this distinction works if ext_bot and not isinstance(updater.bot, ExtBot): @@ -706,34 +719,70 @@ class TestUpdater: port = randrange(1024, 49152) # Select random port async with updater: - return_value = await updater.start_webhook( - drop_pending_updates=drop_pending_updates, - ip_address=ip, - port=port, - url_path="TOKEN", - secret_token=secret_token, - ) + if unix: + return_value = await updater.start_webhook( + drop_pending_updates=drop_pending_updates, + secret_token=secret_token, + url_path="TOKEN", + unix=file_path, + webhook_url="string", + ) + else: + return_value = await updater.start_webhook( + drop_pending_updates=drop_pending_updates, + ip_address=ip, + port=port, + url_path="TOKEN", + secret_token=secret_token, + webhook_url="string", + ) assert return_value is updater.update_queue assert updater.running # Now, we send an update to the server update = make_message_update("Webhook") await send_webhook_message( - ip, port, update.to_json(), "TOKEN", secret_token=secret_token + ip, + port, + update.to_json(), + "TOKEN", + secret_token=secret_token, + unix=file_path if unix else None, ) assert (await updater.update_queue.get()).to_dict() == update.to_dict() # Returns Not Found if path is incorrect - response = await send_webhook_message(ip, port, "123456", "webhook_handler.py") + response = await send_webhook_message( + ip, + port, + "123456", + "webhook_handler.py", + unix=file_path if unix else None, + ) assert response.status_code == HTTPStatus.NOT_FOUND # Returns METHOD_NOT_ALLOWED if method is not allowed - response = await send_webhook_message(ip, port, None, "TOKEN", get_method="HEAD") + response = await send_webhook_message( + ip, + port, + None, + "TOKEN", + get_method="HEAD", + unix=file_path if unix else None, + ) assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED if secret_token: # Returns Forbidden if no secret token is set - response = await send_webhook_message(ip, port, update.to_json(), "TOKEN") + + response = await send_webhook_message( + ip, + port, + update.to_json(), + "TOKEN", + unix=file_path if unix else None, + ) + assert response.status_code == HTTPStatus.FORBIDDEN assert response.text == self.response_text.format( "Request did not include the secret token", HTTPStatus.FORBIDDEN @@ -741,7 +790,12 @@ class TestUpdater: # Returns Forbidden if the secret token is wrong response = await send_webhook_message( - ip, port, update.to_json(), "TOKEN", secret_token="NotTheSecretToken" + ip, + port, + update.to_json(), + "TOKEN", + secret_token="NotTheSecretToken", + unix=file_path if unix else None, ) assert response.status_code == HTTPStatus.FORBIDDEN assert response.text == self.response_text.format( @@ -757,19 +811,54 @@ class TestUpdater: assert self.message_count == 0 # We call the same logic twice to make sure that restarting the updater works as well - await updater.start_webhook( - drop_pending_updates=drop_pending_updates, - ip_address=ip, - port=port, - url_path="TOKEN", - ) + if unix: + await updater.start_webhook( + drop_pending_updates=drop_pending_updates, + secret_token=secret_token, + unix=file_path, + webhook_url="string", + ) + else: + await updater.start_webhook( + drop_pending_updates=drop_pending_updates, + ip_address=ip, + port=port, + url_path="TOKEN", + secret_token=secret_token, + webhook_url="string", + ) assert updater.running update = make_message_update("Webhook") - await send_webhook_message(ip, port, update.to_json(), "TOKEN") + await send_webhook_message( + ip, + port, + update.to_json(), + "" if unix else "TOKEN", + secret_token=secret_token, + unix=file_path if unix else None, + ) assert (await updater.update_queue.get()).to_dict() == update.to_dict() await updater.stop() assert not updater.running + async def test_unix_webhook_mutually_exclusive_params(self, updater): + async with updater: + with pytest.raises(RuntimeError, match="You can not pass unix and listen"): + await updater.start_webhook(listen="127.0.0.1", unix="DoesntMatter") + with pytest.raises(RuntimeError, match="You can not pass unix and port"): + await updater.start_webhook(port=20, unix="DoesntMatter") + with pytest.raises(RuntimeError, match="you set unix, you also need to set the URL"): + await updater.start_webhook(unix="DoesntMatter") + + @pytest.mark.skipif( + platform.system() != "Windows", + reason="Windows is the only platform without unix", + ) + async def test_no_unix(self, updater): + async with updater: + with pytest.raises(RuntimeError, match="binding unix sockets."): + await updater.start_webhook(unix="DoesntMatter", webhook_url="TOKEN") + async def test_start_webhook_already_running(self, updater, monkeypatch): async def return_true(*args, **kwargs): return True