mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2024-12-22 06:25:12 +01:00
Add Support for Unix Sockets to Updater.start_webhook
(#3986)
Co-authored-by: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com>
This commit is contained in:
parent
cc45f49a4f
commit
2345bfbb53
6 changed files with 213 additions and 38 deletions
|
@ -132,3 +132,19 @@ DEFAULT_TRUE: DefaultValue[bool] = DefaultValue(True)
|
|||
|
||||
.. versionadded:: 20.0
|
||||
"""
|
||||
|
||||
|
||||
DEFAULT_20: DefaultValue[int] = DefaultValue(20)
|
||||
""":class:`DefaultValue`: Default :obj:`20`"""
|
||||
|
||||
DEFAULT_IP: DefaultValue[str] = DefaultValue("127.0.0.1")
|
||||
""":class:`DefaultValue`: Default :obj:`127.0.0.1`
|
||||
|
||||
.. versionadded:: NEXT.VERSION
|
||||
"""
|
||||
|
||||
DEFAULT_80: DefaultValue[int] = DefaultValue(80)
|
||||
""":class:`DefaultValue`: Default :obj:`80`
|
||||
|
||||
.. versionadded:: NEXT.VERSION
|
||||
"""
|
||||
|
|
|
@ -52,7 +52,13 @@ from typing import (
|
|||
)
|
||||
|
||||
from telegram._update import Update
|
||||
from telegram._utils.defaultvalue import DEFAULT_NONE, DEFAULT_TRUE, DefaultValue
|
||||
from telegram._utils.defaultvalue import (
|
||||
DEFAULT_80,
|
||||
DEFAULT_IP,
|
||||
DEFAULT_NONE,
|
||||
DEFAULT_TRUE,
|
||||
DefaultValue,
|
||||
)
|
||||
from telegram._utils.logging import get_logger
|
||||
from telegram._utils.repr import build_repr_with_selected_attrs
|
||||
from telegram._utils.types import SCT, DVType, ODVInput
|
||||
|
@ -834,8 +840,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
|
||||
def run_webhook(
|
||||
self,
|
||||
listen: str = "127.0.0.1",
|
||||
port: int = 80,
|
||||
listen: DVType[str] = DEFAULT_IP,
|
||||
port: DVType[int] = DEFAULT_80,
|
||||
url_path: str = "",
|
||||
cert: Optional[Union[str, Path]] = None,
|
||||
key: Optional[Union[str, Path]] = None,
|
||||
|
@ -848,6 +854,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
close_loop: bool = True,
|
||||
stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
|
||||
secret_token: Optional[str] = None,
|
||||
unix: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""Convenience method that takes care of initializing and starting the app,
|
||||
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
|
||||
|
@ -940,6 +947,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
header isn't set or it is set to a wrong token.
|
||||
|
||||
.. versionadded:: 20.0
|
||||
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
|
||||
does not need to exist, in which case the file will be created.
|
||||
|
||||
Caution:
|
||||
This parameter is a replacement for the default TCP bind. Therefore, it is
|
||||
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
|
||||
this param, you must also run a reverse proxy to the unix socket and set the
|
||||
appropriate :paramref:`webhook_url`.
|
||||
|
||||
.. versionadded:: NEXT.VERSION
|
||||
"""
|
||||
if not self.updater:
|
||||
raise RuntimeError(
|
||||
|
@ -960,6 +977,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
|||
ip_address=ip_address,
|
||||
max_connections=max_connections,
|
||||
secret_token=secret_token,
|
||||
unix=unix,
|
||||
),
|
||||
close_loop=close_loop,
|
||||
stop_signals=stop_signals,
|
||||
|
|
|
@ -35,10 +35,10 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from telegram._utils.defaultvalue import DEFAULT_NONE
|
||||
from telegram._utils.defaultvalue import DEFAULT_80, DEFAULT_IP, DEFAULT_NONE, DefaultValue
|
||||
from telegram._utils.logging import get_logger
|
||||
from telegram._utils.repr import build_repr_with_selected_attrs
|
||||
from telegram._utils.types import ODVInput
|
||||
from telegram._utils.types import DVType, ODVInput
|
||||
from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut
|
||||
|
||||
try:
|
||||
|
@ -456,8 +456,8 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
|
||||
async def start_webhook(
|
||||
self,
|
||||
listen: str = "127.0.0.1",
|
||||
port: int = 80,
|
||||
listen: DVType[str] = DEFAULT_IP,
|
||||
port: DVType[int] = DEFAULT_80,
|
||||
url_path: str = "",
|
||||
cert: Optional[Union[str, Path]] = None,
|
||||
key: Optional[Union[str, Path]] = None,
|
||||
|
@ -468,6 +468,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
ip_address: Optional[str] = None,
|
||||
max_connections: int = 40,
|
||||
secret_token: Optional[str] = None,
|
||||
unix: Optional[Union[str, Path]] = None,
|
||||
) -> "asyncio.Queue[object]":
|
||||
"""
|
||||
Starts a small http server to listen for updates via webhook. If :paramref:`cert`
|
||||
|
@ -536,6 +537,16 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
header isn't set or it is set to a wrong token.
|
||||
|
||||
.. versionadded:: 20.0
|
||||
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
|
||||
does not need to exist, in which case the file will be created.
|
||||
|
||||
Caution:
|
||||
This parameter is a replacement for the default TCP bind. Therefore, it is
|
||||
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
|
||||
this param, you must also run a reverse proxy to the unix socket and set the
|
||||
appropriate :paramref:`webhook_url`.
|
||||
|
||||
.. versionadded:: NEXT.VERSION
|
||||
Returns:
|
||||
:class:`queue.Queue`: The update queue that can be filled from the main thread.
|
||||
|
||||
|
@ -547,6 +558,21 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
"To use `start_webhook`, PTB must be installed via `pip install "
|
||||
'"python-telegram-bot[webhooks]"`.'
|
||||
)
|
||||
# unix has special requirements what must and mustn't be set when using it
|
||||
if unix:
|
||||
error_msg = (
|
||||
"You can not pass unix and {0}, only use one. Unix if you want to "
|
||||
"initialize a unix socket, or {0} for a standard TCP server."
|
||||
)
|
||||
if not isinstance(listen, DefaultValue):
|
||||
raise RuntimeError(error_msg.format("listen"))
|
||||
if not isinstance(port, DefaultValue):
|
||||
raise RuntimeError(error_msg.format("port"))
|
||||
if not webhook_url:
|
||||
raise RuntimeError(
|
||||
"Since you set unix, you also need to set the URL to the webhook "
|
||||
"of the proxy you run in front of the unix socket."
|
||||
)
|
||||
|
||||
async with self.__lock:
|
||||
if self.running:
|
||||
|
@ -561,8 +587,8 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
webhook_ready = asyncio.Event()
|
||||
|
||||
await self._start_webhook(
|
||||
listen=listen,
|
||||
port=port,
|
||||
listen=DefaultValue.get_value(listen),
|
||||
port=DefaultValue.get_value(port),
|
||||
url_path=url_path,
|
||||
cert=cert,
|
||||
key=key,
|
||||
|
@ -574,6 +600,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
ip_address=ip_address,
|
||||
max_connections=max_connections,
|
||||
secret_token=secret_token,
|
||||
unix=unix,
|
||||
)
|
||||
|
||||
_LOGGER.debug("Waiting for webhook server to start")
|
||||
|
@ -601,6 +628,7 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
ip_address: Optional[str] = None,
|
||||
max_connections: int = 40,
|
||||
secret_token: Optional[str] = None,
|
||||
unix: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
_LOGGER.debug("Updater thread started (webhook)")
|
||||
|
||||
|
@ -625,14 +653,13 @@ class Updater(AsyncContextManager["Updater"]):
|
|||
raise TelegramError("Invalid SSL Certificate") from exc
|
||||
else:
|
||||
ssl_ctx = None
|
||||
|
||||
# Create and start server
|
||||
self._httpd = WebhookServer(listen, port, app, ssl_ctx)
|
||||
self._httpd = WebhookServer(listen, port, app, ssl_ctx, unix)
|
||||
|
||||
if not webhook_url:
|
||||
webhook_url = self._gen_webhook_url(
|
||||
protocol="https" if ssl_ctx else "http",
|
||||
listen=listen,
|
||||
listen=DefaultValue.get_value(listen),
|
||||
port=port,
|
||||
url_path=url_path,
|
||||
)
|
||||
|
|
|
@ -20,15 +20,23 @@
|
|||
import asyncio
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from ssl import SSLContext
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
from typing import TYPE_CHECKING, Optional, Type, Union
|
||||
|
||||
# Instead of checking for ImportError here, we do that in `updater.py`, where we import from
|
||||
# this module. Doing it here would be tricky, as the classes below subclass tornado classes
|
||||
import tornado.web
|
||||
from tornado.httpserver import HTTPServer
|
||||
|
||||
try:
|
||||
from tornado.netutil import bind_unix_socket
|
||||
|
||||
UNIX_AVAILABLE = True
|
||||
except ImportError:
|
||||
UNIX_AVAILABLE = False
|
||||
|
||||
from telegram import Update
|
||||
from telegram._utils.logging import get_logger
|
||||
from telegram.ext._extbot import ExtBot
|
||||
|
@ -50,21 +58,34 @@ class WebhookServer:
|
|||
"is_running",
|
||||
"_server_lock",
|
||||
"_shutdown_lock",
|
||||
"unix",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, listen: str, port: int, webhook_app: "WebhookAppClass", ssl_ctx: Optional[SSLContext]
|
||||
self,
|
||||
listen: str,
|
||||
port: int,
|
||||
webhook_app: "WebhookAppClass",
|
||||
ssl_ctx: Optional[SSLContext],
|
||||
unix: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
if unix and not UNIX_AVAILABLE:
|
||||
raise RuntimeError("This OS does not support binding unix sockets.")
|
||||
self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
|
||||
self.listen = listen
|
||||
self.port = port
|
||||
self.is_running = False
|
||||
self.unix = unix
|
||||
self._server_lock = asyncio.Lock()
|
||||
self._shutdown_lock = asyncio.Lock()
|
||||
|
||||
async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None:
|
||||
async with self._server_lock:
|
||||
self._http_server.listen(self.port, address=self.listen)
|
||||
if self.unix:
|
||||
socket = bind_unix_socket(str(self.unix))
|
||||
self._http_server.add_socket(socket)
|
||||
else:
|
||||
self._http_server.listen(self.port, address=self.listen)
|
||||
|
||||
self.is_running = True
|
||||
if ready is not None:
|
||||
|
|
|
@ -16,10 +16,11 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Lesser Public License
|
||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, Response
|
||||
from httpx import AsyncClient, AsyncHTTPTransport, Response
|
||||
|
||||
from telegram._utils.defaultvalue import DEFAULT_NONE
|
||||
from telegram._utils.types import ODVInput
|
||||
|
@ -90,6 +91,7 @@ async def send_webhook_message(
|
|||
content_type: str = "application/json",
|
||||
get_method: Optional[str] = None,
|
||||
secret_token: Optional[str] = None,
|
||||
unix: Optional[Path] = None,
|
||||
) -> Response:
|
||||
headers = {
|
||||
"content-type": content_type,
|
||||
|
@ -111,7 +113,9 @@ async def send_webhook_message(
|
|||
|
||||
url = f"http://{ip}:{port}/{url_path}"
|
||||
|
||||
async with AsyncClient() as client:
|
||||
transport = AsyncHTTPTransport(uds=unix) if unix else None
|
||||
|
||||
async with AsyncClient(transport=transport) as client:
|
||||
return await client.request(
|
||||
url=url, method=get_method or "POST", data=payload, headers=headers
|
||||
)
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||
import asyncio
|
||||
import logging
|
||||
import platform
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
@ -32,7 +33,7 @@ from telegram.ext import ExtBot, InvalidCallbackData, Updater
|
|||
from telegram.request import HTTPXRequest
|
||||
from tests.auxil.build_messages import make_message, make_message_update
|
||||
from tests.auxil.envvars import TEST_WITH_OPT_DEPS
|
||||
from tests.auxil.files import data_file
|
||||
from tests.auxil.files import TEST_DATA_PATH, data_file
|
||||
from tests.auxil.networking import send_webhook_message
|
||||
from tests.auxil.pytest_classes import PytestBot, make_bot
|
||||
from tests.auxil.slots import mro_slots
|
||||
|
@ -74,6 +75,14 @@ class TestUpdater:
|
|||
self.cb_handler_called = None
|
||||
self.test_flag = False
|
||||
|
||||
# This is needed instead of pytest's temp_path because the file path gets too long on macOS
|
||||
# otherwise
|
||||
@pytest.fixture()
|
||||
def file_path(self) -> str:
|
||||
path = TEST_DATA_PATH / "test.sock"
|
||||
yield str(path)
|
||||
path.unlink(missing_ok=True)
|
||||
|
||||
def error_callback(self, error):
|
||||
self.received = error
|
||||
self.err_handler_called.set()
|
||||
|
@ -680,9 +689,13 @@ class TestUpdater:
|
|||
@pytest.mark.parametrize("ext_bot", [True, False])
|
||||
@pytest.mark.parametrize("drop_pending_updates", [True, False])
|
||||
@pytest.mark.parametrize("secret_token", ["SecretToken", None])
|
||||
@pytest.mark.parametrize("unix", [None, True])
|
||||
async def test_webhook_basic(
|
||||
self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token
|
||||
self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token, unix, file_path
|
||||
):
|
||||
# Skipping unix test on windows since they fail
|
||||
if unix and platform.system() == "Windows":
|
||||
pytest.skip("Windows doesn't support unix bind")
|
||||
# Testing with both ExtBot and Bot to make sure any logic in WebhookHandler
|
||||
# that depends on this distinction works
|
||||
if ext_bot and not isinstance(updater.bot, ExtBot):
|
||||
|
@ -706,34 +719,70 @@ class TestUpdater:
|
|||
port = randrange(1024, 49152) # Select random port
|
||||
|
||||
async with updater:
|
||||
return_value = await updater.start_webhook(
|
||||
drop_pending_updates=drop_pending_updates,
|
||||
ip_address=ip,
|
||||
port=port,
|
||||
url_path="TOKEN",
|
||||
secret_token=secret_token,
|
||||
)
|
||||
if unix:
|
||||
return_value = await updater.start_webhook(
|
||||
drop_pending_updates=drop_pending_updates,
|
||||
secret_token=secret_token,
|
||||
url_path="TOKEN",
|
||||
unix=file_path,
|
||||
webhook_url="string",
|
||||
)
|
||||
else:
|
||||
return_value = await updater.start_webhook(
|
||||
drop_pending_updates=drop_pending_updates,
|
||||
ip_address=ip,
|
||||
port=port,
|
||||
url_path="TOKEN",
|
||||
secret_token=secret_token,
|
||||
webhook_url="string",
|
||||
)
|
||||
assert return_value is updater.update_queue
|
||||
assert updater.running
|
||||
|
||||
# Now, we send an update to the server
|
||||
update = make_message_update("Webhook")
|
||||
await send_webhook_message(
|
||||
ip, port, update.to_json(), "TOKEN", secret_token=secret_token
|
||||
ip,
|
||||
port,
|
||||
update.to_json(),
|
||||
"TOKEN",
|
||||
secret_token=secret_token,
|
||||
unix=file_path if unix else None,
|
||||
)
|
||||
assert (await updater.update_queue.get()).to_dict() == update.to_dict()
|
||||
|
||||
# Returns Not Found if path is incorrect
|
||||
response = await send_webhook_message(ip, port, "123456", "webhook_handler.py")
|
||||
response = await send_webhook_message(
|
||||
ip,
|
||||
port,
|
||||
"123456",
|
||||
"webhook_handler.py",
|
||||
unix=file_path if unix else None,
|
||||
)
|
||||
assert response.status_code == HTTPStatus.NOT_FOUND
|
||||
|
||||
# Returns METHOD_NOT_ALLOWED if method is not allowed
|
||||
response = await send_webhook_message(ip, port, None, "TOKEN", get_method="HEAD")
|
||||
response = await send_webhook_message(
|
||||
ip,
|
||||
port,
|
||||
None,
|
||||
"TOKEN",
|
||||
get_method="HEAD",
|
||||
unix=file_path if unix else None,
|
||||
)
|
||||
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED
|
||||
|
||||
if secret_token:
|
||||
# Returns Forbidden if no secret token is set
|
||||
response = await send_webhook_message(ip, port, update.to_json(), "TOKEN")
|
||||
|
||||
response = await send_webhook_message(
|
||||
ip,
|
||||
port,
|
||||
update.to_json(),
|
||||
"TOKEN",
|
||||
unix=file_path if unix else None,
|
||||
)
|
||||
|
||||
assert response.status_code == HTTPStatus.FORBIDDEN
|
||||
assert response.text == self.response_text.format(
|
||||
"Request did not include the secret token", HTTPStatus.FORBIDDEN
|
||||
|
@ -741,7 +790,12 @@ class TestUpdater:
|
|||
|
||||
# Returns Forbidden if the secret token is wrong
|
||||
response = await send_webhook_message(
|
||||
ip, port, update.to_json(), "TOKEN", secret_token="NotTheSecretToken"
|
||||
ip,
|
||||
port,
|
||||
update.to_json(),
|
||||
"TOKEN",
|
||||
secret_token="NotTheSecretToken",
|
||||
unix=file_path if unix else None,
|
||||
)
|
||||
assert response.status_code == HTTPStatus.FORBIDDEN
|
||||
assert response.text == self.response_text.format(
|
||||
|
@ -757,19 +811,54 @@ class TestUpdater:
|
|||
assert self.message_count == 0
|
||||
|
||||
# We call the same logic twice to make sure that restarting the updater works as well
|
||||
await updater.start_webhook(
|
||||
drop_pending_updates=drop_pending_updates,
|
||||
ip_address=ip,
|
||||
port=port,
|
||||
url_path="TOKEN",
|
||||
)
|
||||
if unix:
|
||||
await updater.start_webhook(
|
||||
drop_pending_updates=drop_pending_updates,
|
||||
secret_token=secret_token,
|
||||
unix=file_path,
|
||||
webhook_url="string",
|
||||
)
|
||||
else:
|
||||
await updater.start_webhook(
|
||||
drop_pending_updates=drop_pending_updates,
|
||||
ip_address=ip,
|
||||
port=port,
|
||||
url_path="TOKEN",
|
||||
secret_token=secret_token,
|
||||
webhook_url="string",
|
||||
)
|
||||
assert updater.running
|
||||
update = make_message_update("Webhook")
|
||||
await send_webhook_message(ip, port, update.to_json(), "TOKEN")
|
||||
await send_webhook_message(
|
||||
ip,
|
||||
port,
|
||||
update.to_json(),
|
||||
"" if unix else "TOKEN",
|
||||
secret_token=secret_token,
|
||||
unix=file_path if unix else None,
|
||||
)
|
||||
assert (await updater.update_queue.get()).to_dict() == update.to_dict()
|
||||
await updater.stop()
|
||||
assert not updater.running
|
||||
|
||||
async def test_unix_webhook_mutually_exclusive_params(self, updater):
|
||||
async with updater:
|
||||
with pytest.raises(RuntimeError, match="You can not pass unix and listen"):
|
||||
await updater.start_webhook(listen="127.0.0.1", unix="DoesntMatter")
|
||||
with pytest.raises(RuntimeError, match="You can not pass unix and port"):
|
||||
await updater.start_webhook(port=20, unix="DoesntMatter")
|
||||
with pytest.raises(RuntimeError, match="you set unix, you also need to set the URL"):
|
||||
await updater.start_webhook(unix="DoesntMatter")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() != "Windows",
|
||||
reason="Windows is the only platform without unix",
|
||||
)
|
||||
async def test_no_unix(self, updater):
|
||||
async with updater:
|
||||
with pytest.raises(RuntimeError, match="binding unix sockets."):
|
||||
await updater.start_webhook(unix="DoesntMatter", webhook_url="TOKEN")
|
||||
|
||||
async def test_start_webhook_already_running(self, updater, monkeypatch):
|
||||
async def return_true(*args, **kwargs):
|
||||
return True
|
||||
|
|
Loading…
Reference in a new issue