From 2d8d43f2a5cb278a9681d7af44750f0fae9b190f Mon Sep 17 00:00:00 2001 From: Poolitzer Date: Sun, 24 Mar 2024 21:04:10 +0100 Subject: [PATCH] Accept Socket Objects for Webhooks (#4161) --- telegram/ext/_application.py | 19 ++++++++++++++++--- telegram/ext/_updater.py | 21 +++++++++++++++++---- telegram/ext/_utils/webhookhandler.py | 12 ++++++++---- tests/ext/test_updater.py | 22 ++++++++++++++++------ 4 files changed, 57 insertions(+), 17 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 4f951741d..abcc57360 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -75,6 +75,8 @@ from telegram.ext._utils.types import BD, BT, CCT, CD, JQ, RT, UD, ConversationK from telegram.warnings import PTBDeprecationWarning if TYPE_CHECKING: + from socket import socket + from telegram import Message from telegram.ext import ConversationHandler, JobQueue from telegram.ext._applicationbuilder import InitApplicationBuilder @@ -866,7 +868,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica close_loop: bool = True, stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE, secret_token: Optional[str] = None, - unix: Optional[Union[str, Path]] = None, + unix: Optional[Union[str, Path, "socket"]] = None, ) -> None: """Convenience method that takes care of initializing and starting the app, listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and @@ -959,8 +961,17 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica header isn't set or it is set to a wrong token. .. versionadded:: 20.0 - unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path - does not need to exist, in which case the file will be created. + unix (:class:`pathlib.Path` | :obj:`str` | :class:`socket.socket`, optional): Can be + either: + + * the path to the unix socket file as :class:`pathlib.Path` or :obj:`str`. This + will be passed to `tornado.netutil.bind_unix_socket `_ to create the socket. + If the Path does not exist, the file will be created. + + * or the socket itself. This option allows you to e.g. restrict the permissions of + the socket for improved security. Note that you need to pass the correct family, + type and socket options yourself. Caution: This parameter is a replacement for the default TCP bind. Therefore, it is @@ -969,6 +980,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica appropriate :paramref:`webhook_url`. .. versionadded:: 20.8 + .. versionchanged:: NEXT.VERSION + Added support to pass a socket instance itself. """ if not self.updater: raise RuntimeError( diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 7279c9f8a..1be427e4f 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -49,6 +49,8 @@ except ImportError: WEBHOOKS_AVAILABLE = False if TYPE_CHECKING: + from socket import socket + from telegram import Bot @@ -472,7 +474,7 @@ class Updater(AsyncContextManager["Updater"]): ip_address: Optional[str] = None, max_connections: int = 40, secret_token: Optional[str] = None, - unix: Optional[Union[str, Path]] = None, + unix: Optional[Union[str, Path, "socket"]] = None, ) -> "asyncio.Queue[object]": """ Starts a small http server to listen for updates via webhook. If :paramref:`cert` @@ -541,8 +543,17 @@ class Updater(AsyncContextManager["Updater"]): header isn't set or it is set to a wrong token. .. versionadded:: 20.0 - unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path - does not need to exist, in which case the file will be created. + unix (:class:`pathlib.Path` | :obj:`str` | :class:`socket.socket`, optional): Can be + either: + + * the path to the unix socket file as :class:`pathlib.Path` or :obj:`str`. This + will be passed to `tornado.netutil.bind_unix_socket `_ to create the socket. + If the Path does not exist, the file will be created. + + * or the socket itself. This option allows you to e.g. restrict the permissions of + the socket for improved security. Note that you need to pass the correct family, + type and socket options yourself. Caution: This parameter is a replacement for the default TCP bind. Therefore, it is @@ -551,6 +562,8 @@ class Updater(AsyncContextManager["Updater"]): appropriate :paramref:`webhook_url`. .. versionadded:: 20.8 + .. versionchanged:: NEXT.VERSION + Added support to pass a socket instance itself. Returns: :class:`queue.Queue`: The update queue that can be filled from the main thread. @@ -632,7 +645,7 @@ class Updater(AsyncContextManager["Updater"]): ip_address: Optional[str] = None, max_connections: int = 40, secret_token: Optional[str] = None, - unix: Optional[Union[str, Path]] = None, + unix: Optional[Union[str, Path, "socket"]] = None, ) -> None: _LOGGER.debug("Updater thread started (webhook)") diff --git a/telegram/ext/_utils/webhookhandler.py b/telegram/ext/_utils/webhookhandler.py index f630ec941..828dbca47 100644 --- a/telegram/ext/_utils/webhookhandler.py +++ b/telegram/ext/_utils/webhookhandler.py @@ -21,6 +21,7 @@ import asyncio import json from http import HTTPStatus from pathlib import Path +from socket import socket from ssl import SSLContext from types import TracebackType from typing import TYPE_CHECKING, Optional, Type, Union @@ -67,7 +68,7 @@ class WebhookServer: port: int, webhook_app: "WebhookAppClass", ssl_ctx: Optional[SSLContext], - unix: Optional[Union[str, Path]] = None, + unix: Optional[Union[str, Path, socket]] = None, ): if unix and not UNIX_AVAILABLE: raise RuntimeError("This OS does not support binding unix sockets.") @@ -75,15 +76,18 @@ class WebhookServer: self.listen = listen self.port = port self.is_running = False - self.unix = unix + self.unix = None + if unix and isinstance(unix, socket): + self.unix = unix + elif unix: + self.unix = bind_unix_socket(str(unix)) self._server_lock = asyncio.Lock() self._shutdown_lock = asyncio.Lock() async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None: async with self._server_lock: if self.unix: - socket = bind_unix_socket(str(self.unix)) - self._http_server.add_socket(socket) + self._http_server.add_socket(self.unix) else: self._http_server.listen(self.port, address=self.listen) diff --git a/tests/ext/test_updater.py b/tests/ext/test_updater.py index 0c81144ae..808eaef62 100644 --- a/tests/ext/test_updater.py +++ b/tests/ext/test_updater.py @@ -38,7 +38,16 @@ from tests.auxil.networking import send_webhook_message from tests.auxil.pytest_classes import PytestBot, make_bot from tests.auxil.slots import mro_slots +UNIX_AVAILABLE = False + if TEST_WITH_OPT_DEPS: + try: + from tornado.netutil import bind_unix_socket + + UNIX_AVAILABLE = True + except ImportError: + UNIX_AVAILABLE = False + from telegram.ext._utils.webhookhandler import WebhookServer @@ -692,13 +701,12 @@ class TestUpdater: @pytest.mark.parametrize("ext_bot", [True, False]) @pytest.mark.parametrize("drop_pending_updates", [True, False]) @pytest.mark.parametrize("secret_token", ["SecretToken", None]) - @pytest.mark.parametrize("unix", [None, True]) + @pytest.mark.parametrize( + "unix", [None, "file_path", "socket_object"] if UNIX_AVAILABLE else [None] + ) async def test_webhook_basic( self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token, unix, file_path ): - # Skipping unix test on windows since they fail - if unix and platform.system() == "Windows": - pytest.skip("Windows doesn't support unix bind") # Testing with both ExtBot and Bot to make sure any logic in WebhookHandler # that depends on this distinction works if ext_bot and not isinstance(updater.bot, ExtBot): @@ -723,11 +731,12 @@ class TestUpdater: async with updater: if unix: + socket = file_path if unix == "file_path" else bind_unix_socket(file_path) return_value = await updater.start_webhook( drop_pending_updates=drop_pending_updates, secret_token=secret_token, url_path="TOKEN", - unix=file_path, + unix=socket, webhook_url="string", ) else: @@ -815,10 +824,11 @@ class TestUpdater: # We call the same logic twice to make sure that restarting the updater works as well if unix: + socket = file_path if unix == "file_path" else bind_unix_socket(file_path) await updater.start_webhook( drop_pending_updates=drop_pending_updates, secret_token=secret_token, - unix=file_path, + unix=socket, webhook_url="string", ) else: