Add Support for Unix Sockets to Updater.start_webhook (#3986)

Co-authored-by: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com>
This commit is contained in:
Poolitzer 2023-12-14 21:37:00 +01:00 committed by GitHub
parent cc45f49a4f
commit 2345bfbb53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 213 additions and 38 deletions

View file

@ -132,3 +132,19 @@ DEFAULT_TRUE: DefaultValue[bool] = DefaultValue(True)
.. versionadded:: 20.0 .. versionadded:: 20.0
""" """
DEFAULT_20: DefaultValue[int] = DefaultValue(20)
""":class:`DefaultValue`: Default :obj:`20`"""
DEFAULT_IP: DefaultValue[str] = DefaultValue("127.0.0.1")
""":class:`DefaultValue`: Default :obj:`127.0.0.1`
.. versionadded:: NEXT.VERSION
"""
DEFAULT_80: DefaultValue[int] = DefaultValue(80)
""":class:`DefaultValue`: Default :obj:`80`
.. versionadded:: NEXT.VERSION
"""

View file

@ -52,7 +52,13 @@ from typing import (
) )
from telegram._update import Update from telegram._update import Update
from telegram._utils.defaultvalue import DEFAULT_NONE, DEFAULT_TRUE, DefaultValue from telegram._utils.defaultvalue import (
DEFAULT_80,
DEFAULT_IP,
DEFAULT_NONE,
DEFAULT_TRUE,
DefaultValue,
)
from telegram._utils.logging import get_logger from telegram._utils.logging import get_logger
from telegram._utils.repr import build_repr_with_selected_attrs from telegram._utils.repr import build_repr_with_selected_attrs
from telegram._utils.types import SCT, DVType, ODVInput from telegram._utils.types import SCT, DVType, ODVInput
@ -834,8 +840,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
def run_webhook( def run_webhook(
self, self,
listen: str = "127.0.0.1", listen: DVType[str] = DEFAULT_IP,
port: int = 80, port: DVType[int] = DEFAULT_80,
url_path: str = "", url_path: str = "",
cert: Optional[Union[str, Path]] = None, cert: Optional[Union[str, Path]] = None,
key: Optional[Union[str, Path]] = None, key: Optional[Union[str, Path]] = None,
@ -848,6 +854,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
close_loop: bool = True, close_loop: bool = True,
stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE, stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
secret_token: Optional[str] = None, secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> None: ) -> None:
"""Convenience method that takes care of initializing and starting the app, """Convenience method that takes care of initializing and starting the app,
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
@ -940,6 +947,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
header isn't set or it is set to a wrong token. header isn't set or it is set to a wrong token.
.. versionadded:: 20.0 .. versionadded:: 20.0
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
does not need to exist, in which case the file will be created.
Caution:
This parameter is a replacement for the default TCP bind. Therefore, it is
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
this param, you must also run a reverse proxy to the unix socket and set the
appropriate :paramref:`webhook_url`.
.. versionadded:: NEXT.VERSION
""" """
if not self.updater: if not self.updater:
raise RuntimeError( raise RuntimeError(
@ -960,6 +977,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
ip_address=ip_address, ip_address=ip_address,
max_connections=max_connections, max_connections=max_connections,
secret_token=secret_token, secret_token=secret_token,
unix=unix,
), ),
close_loop=close_loop, close_loop=close_loop,
stop_signals=stop_signals, stop_signals=stop_signals,

View file

@ -35,10 +35,10 @@ from typing import (
Union, Union,
) )
from telegram._utils.defaultvalue import DEFAULT_NONE from telegram._utils.defaultvalue import DEFAULT_80, DEFAULT_IP, DEFAULT_NONE, DefaultValue
from telegram._utils.logging import get_logger from telegram._utils.logging import get_logger
from telegram._utils.repr import build_repr_with_selected_attrs from telegram._utils.repr import build_repr_with_selected_attrs
from telegram._utils.types import ODVInput from telegram._utils.types import DVType, ODVInput
from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut
try: try:
@ -456,8 +456,8 @@ class Updater(AsyncContextManager["Updater"]):
async def start_webhook( async def start_webhook(
self, self,
listen: str = "127.0.0.1", listen: DVType[str] = DEFAULT_IP,
port: int = 80, port: DVType[int] = DEFAULT_80,
url_path: str = "", url_path: str = "",
cert: Optional[Union[str, Path]] = None, cert: Optional[Union[str, Path]] = None,
key: Optional[Union[str, Path]] = None, key: Optional[Union[str, Path]] = None,
@ -468,6 +468,7 @@ class Updater(AsyncContextManager["Updater"]):
ip_address: Optional[str] = None, ip_address: Optional[str] = None,
max_connections: int = 40, max_connections: int = 40,
secret_token: Optional[str] = None, secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> "asyncio.Queue[object]": ) -> "asyncio.Queue[object]":
""" """
Starts a small http server to listen for updates via webhook. If :paramref:`cert` Starts a small http server to listen for updates via webhook. If :paramref:`cert`
@ -536,6 +537,16 @@ class Updater(AsyncContextManager["Updater"]):
header isn't set or it is set to a wrong token. header isn't set or it is set to a wrong token.
.. versionadded:: 20.0 .. versionadded:: 20.0
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
does not need to exist, in which case the file will be created.
Caution:
This parameter is a replacement for the default TCP bind. Therefore, it is
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
this param, you must also run a reverse proxy to the unix socket and set the
appropriate :paramref:`webhook_url`.
.. versionadded:: NEXT.VERSION
Returns: Returns:
:class:`queue.Queue`: The update queue that can be filled from the main thread. :class:`queue.Queue`: The update queue that can be filled from the main thread.
@ -547,6 +558,21 @@ class Updater(AsyncContextManager["Updater"]):
"To use `start_webhook`, PTB must be installed via `pip install " "To use `start_webhook`, PTB must be installed via `pip install "
'"python-telegram-bot[webhooks]"`.' '"python-telegram-bot[webhooks]"`.'
) )
# unix has special requirements what must and mustn't be set when using it
if unix:
error_msg = (
"You can not pass unix and {0}, only use one. Unix if you want to "
"initialize a unix socket, or {0} for a standard TCP server."
)
if not isinstance(listen, DefaultValue):
raise RuntimeError(error_msg.format("listen"))
if not isinstance(port, DefaultValue):
raise RuntimeError(error_msg.format("port"))
if not webhook_url:
raise RuntimeError(
"Since you set unix, you also need to set the URL to the webhook "
"of the proxy you run in front of the unix socket."
)
async with self.__lock: async with self.__lock:
if self.running: if self.running:
@ -561,8 +587,8 @@ class Updater(AsyncContextManager["Updater"]):
webhook_ready = asyncio.Event() webhook_ready = asyncio.Event()
await self._start_webhook( await self._start_webhook(
listen=listen, listen=DefaultValue.get_value(listen),
port=port, port=DefaultValue.get_value(port),
url_path=url_path, url_path=url_path,
cert=cert, cert=cert,
key=key, key=key,
@ -574,6 +600,7 @@ class Updater(AsyncContextManager["Updater"]):
ip_address=ip_address, ip_address=ip_address,
max_connections=max_connections, max_connections=max_connections,
secret_token=secret_token, secret_token=secret_token,
unix=unix,
) )
_LOGGER.debug("Waiting for webhook server to start") _LOGGER.debug("Waiting for webhook server to start")
@ -601,6 +628,7 @@ class Updater(AsyncContextManager["Updater"]):
ip_address: Optional[str] = None, ip_address: Optional[str] = None,
max_connections: int = 40, max_connections: int = 40,
secret_token: Optional[str] = None, secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> None: ) -> None:
_LOGGER.debug("Updater thread started (webhook)") _LOGGER.debug("Updater thread started (webhook)")
@ -625,14 +653,13 @@ class Updater(AsyncContextManager["Updater"]):
raise TelegramError("Invalid SSL Certificate") from exc raise TelegramError("Invalid SSL Certificate") from exc
else: else:
ssl_ctx = None ssl_ctx = None
# Create and start server # Create and start server
self._httpd = WebhookServer(listen, port, app, ssl_ctx) self._httpd = WebhookServer(listen, port, app, ssl_ctx, unix)
if not webhook_url: if not webhook_url:
webhook_url = self._gen_webhook_url( webhook_url = self._gen_webhook_url(
protocol="https" if ssl_ctx else "http", protocol="https" if ssl_ctx else "http",
listen=listen, listen=DefaultValue.get_value(listen),
port=port, port=port,
url_path=url_path, url_path=url_path,
) )

View file

@ -20,15 +20,23 @@
import asyncio import asyncio
import json import json
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path
from ssl import SSLContext from ssl import SSLContext
from types import TracebackType from types import TracebackType
from typing import TYPE_CHECKING, Optional, Type from typing import TYPE_CHECKING, Optional, Type, Union
# Instead of checking for ImportError here, we do that in `updater.py`, where we import from # Instead of checking for ImportError here, we do that in `updater.py`, where we import from
# this module. Doing it here would be tricky, as the classes below subclass tornado classes # this module. Doing it here would be tricky, as the classes below subclass tornado classes
import tornado.web import tornado.web
from tornado.httpserver import HTTPServer from tornado.httpserver import HTTPServer
try:
from tornado.netutil import bind_unix_socket
UNIX_AVAILABLE = True
except ImportError:
UNIX_AVAILABLE = False
from telegram import Update from telegram import Update
from telegram._utils.logging import get_logger from telegram._utils.logging import get_logger
from telegram.ext._extbot import ExtBot from telegram.ext._extbot import ExtBot
@ -50,20 +58,33 @@ class WebhookServer:
"is_running", "is_running",
"_server_lock", "_server_lock",
"_shutdown_lock", "_shutdown_lock",
"unix",
) )
def __init__( def __init__(
self, listen: str, port: int, webhook_app: "WebhookAppClass", ssl_ctx: Optional[SSLContext] self,
listen: str,
port: int,
webhook_app: "WebhookAppClass",
ssl_ctx: Optional[SSLContext],
unix: Optional[Union[str, Path]] = None,
): ):
if unix and not UNIX_AVAILABLE:
raise RuntimeError("This OS does not support binding unix sockets.")
self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx) self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
self.listen = listen self.listen = listen
self.port = port self.port = port
self.is_running = False self.is_running = False
self.unix = unix
self._server_lock = asyncio.Lock() self._server_lock = asyncio.Lock()
self._shutdown_lock = asyncio.Lock() self._shutdown_lock = asyncio.Lock()
async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None: async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None:
async with self._server_lock: async with self._server_lock:
if self.unix:
socket = bind_unix_socket(str(self.unix))
self._http_server.add_socket(socket)
else:
self._http_server.listen(self.port, address=self.listen) self._http_server.listen(self.port, address=self.listen)
self.is_running = True self.is_running = True

View file

@ -16,10 +16,11 @@
# #
# You should have received a copy of the GNU Lesser Public License # You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
from pathlib import Path
from typing import Optional from typing import Optional
import pytest import pytest
from httpx import AsyncClient, Response from httpx import AsyncClient, AsyncHTTPTransport, Response
from telegram._utils.defaultvalue import DEFAULT_NONE from telegram._utils.defaultvalue import DEFAULT_NONE
from telegram._utils.types import ODVInput from telegram._utils.types import ODVInput
@ -90,6 +91,7 @@ async def send_webhook_message(
content_type: str = "application/json", content_type: str = "application/json",
get_method: Optional[str] = None, get_method: Optional[str] = None,
secret_token: Optional[str] = None, secret_token: Optional[str] = None,
unix: Optional[Path] = None,
) -> Response: ) -> Response:
headers = { headers = {
"content-type": content_type, "content-type": content_type,
@ -111,7 +113,9 @@ async def send_webhook_message(
url = f"http://{ip}:{port}/{url_path}" url = f"http://{ip}:{port}/{url_path}"
async with AsyncClient() as client: transport = AsyncHTTPTransport(uds=unix) if unix else None
async with AsyncClient(transport=transport) as client:
return await client.request( return await client.request(
url=url, method=get_method or "POST", data=payload, headers=headers url=url, method=get_method or "POST", data=payload, headers=headers
) )

View file

@ -18,6 +18,7 @@
# along with this program. If not, see [http://www.gnu.org/licenses/]. # along with this program. If not, see [http://www.gnu.org/licenses/].
import asyncio import asyncio
import logging import logging
import platform
from collections import defaultdict from collections import defaultdict
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path from pathlib import Path
@ -32,7 +33,7 @@ from telegram.ext import ExtBot, InvalidCallbackData, Updater
from telegram.request import HTTPXRequest from telegram.request import HTTPXRequest
from tests.auxil.build_messages import make_message, make_message_update from tests.auxil.build_messages import make_message, make_message_update
from tests.auxil.envvars import TEST_WITH_OPT_DEPS from tests.auxil.envvars import TEST_WITH_OPT_DEPS
from tests.auxil.files import data_file from tests.auxil.files import TEST_DATA_PATH, data_file
from tests.auxil.networking import send_webhook_message from tests.auxil.networking import send_webhook_message
from tests.auxil.pytest_classes import PytestBot, make_bot from tests.auxil.pytest_classes import PytestBot, make_bot
from tests.auxil.slots import mro_slots from tests.auxil.slots import mro_slots
@ -74,6 +75,14 @@ class TestUpdater:
self.cb_handler_called = None self.cb_handler_called = None
self.test_flag = False self.test_flag = False
# This is needed instead of pytest's temp_path because the file path gets too long on macOS
# otherwise
@pytest.fixture()
def file_path(self) -> str:
path = TEST_DATA_PATH / "test.sock"
yield str(path)
path.unlink(missing_ok=True)
def error_callback(self, error): def error_callback(self, error):
self.received = error self.received = error
self.err_handler_called.set() self.err_handler_called.set()
@ -680,9 +689,13 @@ class TestUpdater:
@pytest.mark.parametrize("ext_bot", [True, False]) @pytest.mark.parametrize("ext_bot", [True, False])
@pytest.mark.parametrize("drop_pending_updates", [True, False]) @pytest.mark.parametrize("drop_pending_updates", [True, False])
@pytest.mark.parametrize("secret_token", ["SecretToken", None]) @pytest.mark.parametrize("secret_token", ["SecretToken", None])
@pytest.mark.parametrize("unix", [None, True])
async def test_webhook_basic( async def test_webhook_basic(
self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token, unix, file_path
): ):
# Skipping unix test on windows since they fail
if unix and platform.system() == "Windows":
pytest.skip("Windows doesn't support unix bind")
# Testing with both ExtBot and Bot to make sure any logic in WebhookHandler # Testing with both ExtBot and Bot to make sure any logic in WebhookHandler
# that depends on this distinction works # that depends on this distinction works
if ext_bot and not isinstance(updater.bot, ExtBot): if ext_bot and not isinstance(updater.bot, ExtBot):
@ -706,12 +719,22 @@ class TestUpdater:
port = randrange(1024, 49152) # Select random port port = randrange(1024, 49152) # Select random port
async with updater: async with updater:
if unix:
return_value = await updater.start_webhook(
drop_pending_updates=drop_pending_updates,
secret_token=secret_token,
url_path="TOKEN",
unix=file_path,
webhook_url="string",
)
else:
return_value = await updater.start_webhook( return_value = await updater.start_webhook(
drop_pending_updates=drop_pending_updates, drop_pending_updates=drop_pending_updates,
ip_address=ip, ip_address=ip,
port=port, port=port,
url_path="TOKEN", url_path="TOKEN",
secret_token=secret_token, secret_token=secret_token,
webhook_url="string",
) )
assert return_value is updater.update_queue assert return_value is updater.update_queue
assert updater.running assert updater.running
@ -719,21 +742,47 @@ class TestUpdater:
# Now, we send an update to the server # Now, we send an update to the server
update = make_message_update("Webhook") update = make_message_update("Webhook")
await send_webhook_message( await send_webhook_message(
ip, port, update.to_json(), "TOKEN", secret_token=secret_token ip,
port,
update.to_json(),
"TOKEN",
secret_token=secret_token,
unix=file_path if unix else None,
) )
assert (await updater.update_queue.get()).to_dict() == update.to_dict() assert (await updater.update_queue.get()).to_dict() == update.to_dict()
# Returns Not Found if path is incorrect # Returns Not Found if path is incorrect
response = await send_webhook_message(ip, port, "123456", "webhook_handler.py") response = await send_webhook_message(
ip,
port,
"123456",
"webhook_handler.py",
unix=file_path if unix else None,
)
assert response.status_code == HTTPStatus.NOT_FOUND assert response.status_code == HTTPStatus.NOT_FOUND
# Returns METHOD_NOT_ALLOWED if method is not allowed # Returns METHOD_NOT_ALLOWED if method is not allowed
response = await send_webhook_message(ip, port, None, "TOKEN", get_method="HEAD") response = await send_webhook_message(
ip,
port,
None,
"TOKEN",
get_method="HEAD",
unix=file_path if unix else None,
)
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED
if secret_token: if secret_token:
# Returns Forbidden if no secret token is set # Returns Forbidden if no secret token is set
response = await send_webhook_message(ip, port, update.to_json(), "TOKEN")
response = await send_webhook_message(
ip,
port,
update.to_json(),
"TOKEN",
unix=file_path if unix else None,
)
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status_code == HTTPStatus.FORBIDDEN
assert response.text == self.response_text.format( assert response.text == self.response_text.format(
"Request did not include the secret token", HTTPStatus.FORBIDDEN "Request did not include the secret token", HTTPStatus.FORBIDDEN
@ -741,7 +790,12 @@ class TestUpdater:
# Returns Forbidden if the secret token is wrong # Returns Forbidden if the secret token is wrong
response = await send_webhook_message( response = await send_webhook_message(
ip, port, update.to_json(), "TOKEN", secret_token="NotTheSecretToken" ip,
port,
update.to_json(),
"TOKEN",
secret_token="NotTheSecretToken",
unix=file_path if unix else None,
) )
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status_code == HTTPStatus.FORBIDDEN
assert response.text == self.response_text.format( assert response.text == self.response_text.format(
@ -757,19 +811,54 @@ class TestUpdater:
assert self.message_count == 0 assert self.message_count == 0
# We call the same logic twice to make sure that restarting the updater works as well # We call the same logic twice to make sure that restarting the updater works as well
if unix:
await updater.start_webhook(
drop_pending_updates=drop_pending_updates,
secret_token=secret_token,
unix=file_path,
webhook_url="string",
)
else:
await updater.start_webhook( await updater.start_webhook(
drop_pending_updates=drop_pending_updates, drop_pending_updates=drop_pending_updates,
ip_address=ip, ip_address=ip,
port=port, port=port,
url_path="TOKEN", url_path="TOKEN",
secret_token=secret_token,
webhook_url="string",
) )
assert updater.running assert updater.running
update = make_message_update("Webhook") update = make_message_update("Webhook")
await send_webhook_message(ip, port, update.to_json(), "TOKEN") await send_webhook_message(
ip,
port,
update.to_json(),
"" if unix else "TOKEN",
secret_token=secret_token,
unix=file_path if unix else None,
)
assert (await updater.update_queue.get()).to_dict() == update.to_dict() assert (await updater.update_queue.get()).to_dict() == update.to_dict()
await updater.stop() await updater.stop()
assert not updater.running assert not updater.running
async def test_unix_webhook_mutually_exclusive_params(self, updater):
async with updater:
with pytest.raises(RuntimeError, match="You can not pass unix and listen"):
await updater.start_webhook(listen="127.0.0.1", unix="DoesntMatter")
with pytest.raises(RuntimeError, match="You can not pass unix and port"):
await updater.start_webhook(port=20, unix="DoesntMatter")
with pytest.raises(RuntimeError, match="you set unix, you also need to set the URL"):
await updater.start_webhook(unix="DoesntMatter")
@pytest.mark.skipif(
platform.system() != "Windows",
reason="Windows is the only platform without unix",
)
async def test_no_unix(self, updater):
async with updater:
with pytest.raises(RuntimeError, match="binding unix sockets."):
await updater.start_webhook(unix="DoesntMatter", webhook_url="TOKEN")
async def test_start_webhook_already_running(self, updater, monkeypatch): async def test_start_webhook_already_running(self, updater, monkeypatch):
async def return_true(*args, **kwargs): async def return_true(*args, **kwargs):
return True return True