Add Support for Unix Sockets to Updater.start_webhook (#3986)

Co-authored-by: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com>
This commit is contained in:
Poolitzer 2023-12-14 21:37:00 +01:00 committed by GitHub
parent cc45f49a4f
commit 2345bfbb53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 213 additions and 38 deletions

View file

@ -132,3 +132,19 @@ DEFAULT_TRUE: DefaultValue[bool] = DefaultValue(True)
.. versionadded:: 20.0
"""
DEFAULT_20: DefaultValue[int] = DefaultValue(20)
""":class:`DefaultValue`: Default :obj:`20`"""
DEFAULT_IP: DefaultValue[str] = DefaultValue("127.0.0.1")
""":class:`DefaultValue`: Default :obj:`127.0.0.1`
.. versionadded:: NEXT.VERSION
"""
DEFAULT_80: DefaultValue[int] = DefaultValue(80)
""":class:`DefaultValue`: Default :obj:`80`
.. versionadded:: NEXT.VERSION
"""

View file

@ -52,7 +52,13 @@ from typing import (
)
from telegram._update import Update
from telegram._utils.defaultvalue import DEFAULT_NONE, DEFAULT_TRUE, DefaultValue
from telegram._utils.defaultvalue import (
DEFAULT_80,
DEFAULT_IP,
DEFAULT_NONE,
DEFAULT_TRUE,
DefaultValue,
)
from telegram._utils.logging import get_logger
from telegram._utils.repr import build_repr_with_selected_attrs
from telegram._utils.types import SCT, DVType, ODVInput
@ -834,8 +840,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
def run_webhook(
self,
listen: str = "127.0.0.1",
port: int = 80,
listen: DVType[str] = DEFAULT_IP,
port: DVType[int] = DEFAULT_80,
url_path: str = "",
cert: Optional[Union[str, Path]] = None,
key: Optional[Union[str, Path]] = None,
@ -848,6 +854,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
close_loop: bool = True,
stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> None:
"""Convenience method that takes care of initializing and starting the app,
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
@ -940,6 +947,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
header isn't set or it is set to a wrong token.
.. versionadded:: 20.0
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
does not need to exist, in which case the file will be created.
Caution:
This parameter is a replacement for the default TCP bind. Therefore, it is
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
this param, you must also run a reverse proxy to the unix socket and set the
appropriate :paramref:`webhook_url`.
.. versionadded:: NEXT.VERSION
"""
if not self.updater:
raise RuntimeError(
@ -960,6 +977,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
ip_address=ip_address,
max_connections=max_connections,
secret_token=secret_token,
unix=unix,
),
close_loop=close_loop,
stop_signals=stop_signals,

View file

@ -35,10 +35,10 @@ from typing import (
Union,
)
from telegram._utils.defaultvalue import DEFAULT_NONE
from telegram._utils.defaultvalue import DEFAULT_80, DEFAULT_IP, DEFAULT_NONE, DefaultValue
from telegram._utils.logging import get_logger
from telegram._utils.repr import build_repr_with_selected_attrs
from telegram._utils.types import ODVInput
from telegram._utils.types import DVType, ODVInput
from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut
try:
@ -456,8 +456,8 @@ class Updater(AsyncContextManager["Updater"]):
async def start_webhook(
self,
listen: str = "127.0.0.1",
port: int = 80,
listen: DVType[str] = DEFAULT_IP,
port: DVType[int] = DEFAULT_80,
url_path: str = "",
cert: Optional[Union[str, Path]] = None,
key: Optional[Union[str, Path]] = None,
@ -468,6 +468,7 @@ class Updater(AsyncContextManager["Updater"]):
ip_address: Optional[str] = None,
max_connections: int = 40,
secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> "asyncio.Queue[object]":
"""
Starts a small http server to listen for updates via webhook. If :paramref:`cert`
@ -536,6 +537,16 @@ class Updater(AsyncContextManager["Updater"]):
header isn't set or it is set to a wrong token.
.. versionadded:: 20.0
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
does not need to exist, in which case the file will be created.
Caution:
This parameter is a replacement for the default TCP bind. Therefore, it is
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
this param, you must also run a reverse proxy to the unix socket and set the
appropriate :paramref:`webhook_url`.
.. versionadded:: NEXT.VERSION
Returns:
:class:`queue.Queue`: The update queue that can be filled from the main thread.
@ -547,6 +558,21 @@ class Updater(AsyncContextManager["Updater"]):
"To use `start_webhook`, PTB must be installed via `pip install "
'"python-telegram-bot[webhooks]"`.'
)
# unix has special requirements what must and mustn't be set when using it
if unix:
error_msg = (
"You can not pass unix and {0}, only use one. Unix if you want to "
"initialize a unix socket, or {0} for a standard TCP server."
)
if not isinstance(listen, DefaultValue):
raise RuntimeError(error_msg.format("listen"))
if not isinstance(port, DefaultValue):
raise RuntimeError(error_msg.format("port"))
if not webhook_url:
raise RuntimeError(
"Since you set unix, you also need to set the URL to the webhook "
"of the proxy you run in front of the unix socket."
)
async with self.__lock:
if self.running:
@ -561,8 +587,8 @@ class Updater(AsyncContextManager["Updater"]):
webhook_ready = asyncio.Event()
await self._start_webhook(
listen=listen,
port=port,
listen=DefaultValue.get_value(listen),
port=DefaultValue.get_value(port),
url_path=url_path,
cert=cert,
key=key,
@ -574,6 +600,7 @@ class Updater(AsyncContextManager["Updater"]):
ip_address=ip_address,
max_connections=max_connections,
secret_token=secret_token,
unix=unix,
)
_LOGGER.debug("Waiting for webhook server to start")
@ -601,6 +628,7 @@ class Updater(AsyncContextManager["Updater"]):
ip_address: Optional[str] = None,
max_connections: int = 40,
secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> None:
_LOGGER.debug("Updater thread started (webhook)")
@ -625,14 +653,13 @@ class Updater(AsyncContextManager["Updater"]):
raise TelegramError("Invalid SSL Certificate") from exc
else:
ssl_ctx = None
# Create and start server
self._httpd = WebhookServer(listen, port, app, ssl_ctx)
self._httpd = WebhookServer(listen, port, app, ssl_ctx, unix)
if not webhook_url:
webhook_url = self._gen_webhook_url(
protocol="https" if ssl_ctx else "http",
listen=listen,
listen=DefaultValue.get_value(listen),
port=port,
url_path=url_path,
)

View file

@ -20,15 +20,23 @@
import asyncio
import json
from http import HTTPStatus
from pathlib import Path
from ssl import SSLContext
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Optional, Type, Union
# Instead of checking for ImportError here, we do that in `updater.py`, where we import from
# this module. Doing it here would be tricky, as the classes below subclass tornado classes
import tornado.web
from tornado.httpserver import HTTPServer
try:
from tornado.netutil import bind_unix_socket
UNIX_AVAILABLE = True
except ImportError:
UNIX_AVAILABLE = False
from telegram import Update
from telegram._utils.logging import get_logger
from telegram.ext._extbot import ExtBot
@ -50,21 +58,34 @@ class WebhookServer:
"is_running",
"_server_lock",
"_shutdown_lock",
"unix",
)
def __init__(
self, listen: str, port: int, webhook_app: "WebhookAppClass", ssl_ctx: Optional[SSLContext]
self,
listen: str,
port: int,
webhook_app: "WebhookAppClass",
ssl_ctx: Optional[SSLContext],
unix: Optional[Union[str, Path]] = None,
):
if unix and not UNIX_AVAILABLE:
raise RuntimeError("This OS does not support binding unix sockets.")
self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
self.listen = listen
self.port = port
self.is_running = False
self.unix = unix
self._server_lock = asyncio.Lock()
self._shutdown_lock = asyncio.Lock()
async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None:
async with self._server_lock:
self._http_server.listen(self.port, address=self.listen)
if self.unix:
socket = bind_unix_socket(str(self.unix))
self._http_server.add_socket(socket)
else:
self._http_server.listen(self.port, address=self.listen)
self.is_running = True
if ready is not None:

View file

@ -16,10 +16,11 @@
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
from pathlib import Path
from typing import Optional
import pytest
from httpx import AsyncClient, Response
from httpx import AsyncClient, AsyncHTTPTransport, Response
from telegram._utils.defaultvalue import DEFAULT_NONE
from telegram._utils.types import ODVInput
@ -90,6 +91,7 @@ async def send_webhook_message(
content_type: str = "application/json",
get_method: Optional[str] = None,
secret_token: Optional[str] = None,
unix: Optional[Path] = None,
) -> Response:
headers = {
"content-type": content_type,
@ -111,7 +113,9 @@ async def send_webhook_message(
url = f"http://{ip}:{port}/{url_path}"
async with AsyncClient() as client:
transport = AsyncHTTPTransport(uds=unix) if unix else None
async with AsyncClient(transport=transport) as client:
return await client.request(
url=url, method=get_method or "POST", data=payload, headers=headers
)

View file

@ -18,6 +18,7 @@
# along with this program. If not, see [http://www.gnu.org/licenses/].
import asyncio
import logging
import platform
from collections import defaultdict
from http import HTTPStatus
from pathlib import Path
@ -32,7 +33,7 @@ from telegram.ext import ExtBot, InvalidCallbackData, Updater
from telegram.request import HTTPXRequest
from tests.auxil.build_messages import make_message, make_message_update
from tests.auxil.envvars import TEST_WITH_OPT_DEPS
from tests.auxil.files import data_file
from tests.auxil.files import TEST_DATA_PATH, data_file
from tests.auxil.networking import send_webhook_message
from tests.auxil.pytest_classes import PytestBot, make_bot
from tests.auxil.slots import mro_slots
@ -74,6 +75,14 @@ class TestUpdater:
self.cb_handler_called = None
self.test_flag = False
# This is needed instead of pytest's temp_path because the file path gets too long on macOS
# otherwise
@pytest.fixture()
def file_path(self) -> str:
path = TEST_DATA_PATH / "test.sock"
yield str(path)
path.unlink(missing_ok=True)
def error_callback(self, error):
self.received = error
self.err_handler_called.set()
@ -680,9 +689,13 @@ class TestUpdater:
@pytest.mark.parametrize("ext_bot", [True, False])
@pytest.mark.parametrize("drop_pending_updates", [True, False])
@pytest.mark.parametrize("secret_token", ["SecretToken", None])
@pytest.mark.parametrize("unix", [None, True])
async def test_webhook_basic(
self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token
self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token, unix, file_path
):
# Skipping unix test on windows since they fail
if unix and platform.system() == "Windows":
pytest.skip("Windows doesn't support unix bind")
# Testing with both ExtBot and Bot to make sure any logic in WebhookHandler
# that depends on this distinction works
if ext_bot and not isinstance(updater.bot, ExtBot):
@ -706,34 +719,70 @@ class TestUpdater:
port = randrange(1024, 49152) # Select random port
async with updater:
return_value = await updater.start_webhook(
drop_pending_updates=drop_pending_updates,
ip_address=ip,
port=port,
url_path="TOKEN",
secret_token=secret_token,
)
if unix:
return_value = await updater.start_webhook(
drop_pending_updates=drop_pending_updates,
secret_token=secret_token,
url_path="TOKEN",
unix=file_path,
webhook_url="string",
)
else:
return_value = await updater.start_webhook(
drop_pending_updates=drop_pending_updates,
ip_address=ip,
port=port,
url_path="TOKEN",
secret_token=secret_token,
webhook_url="string",
)
assert return_value is updater.update_queue
assert updater.running
# Now, we send an update to the server
update = make_message_update("Webhook")
await send_webhook_message(
ip, port, update.to_json(), "TOKEN", secret_token=secret_token
ip,
port,
update.to_json(),
"TOKEN",
secret_token=secret_token,
unix=file_path if unix else None,
)
assert (await updater.update_queue.get()).to_dict() == update.to_dict()
# Returns Not Found if path is incorrect
response = await send_webhook_message(ip, port, "123456", "webhook_handler.py")
response = await send_webhook_message(
ip,
port,
"123456",
"webhook_handler.py",
unix=file_path if unix else None,
)
assert response.status_code == HTTPStatus.NOT_FOUND
# Returns METHOD_NOT_ALLOWED if method is not allowed
response = await send_webhook_message(ip, port, None, "TOKEN", get_method="HEAD")
response = await send_webhook_message(
ip,
port,
None,
"TOKEN",
get_method="HEAD",
unix=file_path if unix else None,
)
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED
if secret_token:
# Returns Forbidden if no secret token is set
response = await send_webhook_message(ip, port, update.to_json(), "TOKEN")
response = await send_webhook_message(
ip,
port,
update.to_json(),
"TOKEN",
unix=file_path if unix else None,
)
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.text == self.response_text.format(
"Request did not include the secret token", HTTPStatus.FORBIDDEN
@ -741,7 +790,12 @@ class TestUpdater:
# Returns Forbidden if the secret token is wrong
response = await send_webhook_message(
ip, port, update.to_json(), "TOKEN", secret_token="NotTheSecretToken"
ip,
port,
update.to_json(),
"TOKEN",
secret_token="NotTheSecretToken",
unix=file_path if unix else None,
)
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.text == self.response_text.format(
@ -757,19 +811,54 @@ class TestUpdater:
assert self.message_count == 0
# We call the same logic twice to make sure that restarting the updater works as well
await updater.start_webhook(
drop_pending_updates=drop_pending_updates,
ip_address=ip,
port=port,
url_path="TOKEN",
)
if unix:
await updater.start_webhook(
drop_pending_updates=drop_pending_updates,
secret_token=secret_token,
unix=file_path,
webhook_url="string",
)
else:
await updater.start_webhook(
drop_pending_updates=drop_pending_updates,
ip_address=ip,
port=port,
url_path="TOKEN",
secret_token=secret_token,
webhook_url="string",
)
assert updater.running
update = make_message_update("Webhook")
await send_webhook_message(ip, port, update.to_json(), "TOKEN")
await send_webhook_message(
ip,
port,
update.to_json(),
"" if unix else "TOKEN",
secret_token=secret_token,
unix=file_path if unix else None,
)
assert (await updater.update_queue.get()).to_dict() == update.to_dict()
await updater.stop()
assert not updater.running
async def test_unix_webhook_mutually_exclusive_params(self, updater):
async with updater:
with pytest.raises(RuntimeError, match="You can not pass unix and listen"):
await updater.start_webhook(listen="127.0.0.1", unix="DoesntMatter")
with pytest.raises(RuntimeError, match="You can not pass unix and port"):
await updater.start_webhook(port=20, unix="DoesntMatter")
with pytest.raises(RuntimeError, match="you set unix, you also need to set the URL"):
await updater.start_webhook(unix="DoesntMatter")
@pytest.mark.skipif(
platform.system() != "Windows",
reason="Windows is the only platform without unix",
)
async def test_no_unix(self, updater):
async with updater:
with pytest.raises(RuntimeError, match="binding unix sockets."):
await updater.start_webhook(unix="DoesntMatter", webhook_url="TOKEN")
async def test_start_webhook_already_running(self, updater, monkeypatch):
async def return_true(*args, **kwargs):
return True