mirror of
https://github.com/python-telegram-bot/python-telegram-bot.git
synced 2025-01-03 09:49:21 +01:00
Add Support for Unix Sockets to Updater.start_webhook
(#3986)
Co-authored-by: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com>
This commit is contained in:
parent
cc45f49a4f
commit
2345bfbb53
6 changed files with 213 additions and 38 deletions
|
@ -132,3 +132,19 @@ DEFAULT_TRUE: DefaultValue[bool] = DefaultValue(True)
|
||||||
|
|
||||||
.. versionadded:: 20.0
|
.. versionadded:: 20.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_20: DefaultValue[int] = DefaultValue(20)
|
||||||
|
""":class:`DefaultValue`: Default :obj:`20`"""
|
||||||
|
|
||||||
|
DEFAULT_IP: DefaultValue[str] = DefaultValue("127.0.0.1")
|
||||||
|
""":class:`DefaultValue`: Default :obj:`127.0.0.1`
|
||||||
|
|
||||||
|
.. versionadded:: NEXT.VERSION
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_80: DefaultValue[int] = DefaultValue(80)
|
||||||
|
""":class:`DefaultValue`: Default :obj:`80`
|
||||||
|
|
||||||
|
.. versionadded:: NEXT.VERSION
|
||||||
|
"""
|
||||||
|
|
|
@ -52,7 +52,13 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from telegram._update import Update
|
from telegram._update import Update
|
||||||
from telegram._utils.defaultvalue import DEFAULT_NONE, DEFAULT_TRUE, DefaultValue
|
from telegram._utils.defaultvalue import (
|
||||||
|
DEFAULT_80,
|
||||||
|
DEFAULT_IP,
|
||||||
|
DEFAULT_NONE,
|
||||||
|
DEFAULT_TRUE,
|
||||||
|
DefaultValue,
|
||||||
|
)
|
||||||
from telegram._utils.logging import get_logger
|
from telegram._utils.logging import get_logger
|
||||||
from telegram._utils.repr import build_repr_with_selected_attrs
|
from telegram._utils.repr import build_repr_with_selected_attrs
|
||||||
from telegram._utils.types import SCT, DVType, ODVInput
|
from telegram._utils.types import SCT, DVType, ODVInput
|
||||||
|
@ -834,8 +840,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
||||||
|
|
||||||
def run_webhook(
|
def run_webhook(
|
||||||
self,
|
self,
|
||||||
listen: str = "127.0.0.1",
|
listen: DVType[str] = DEFAULT_IP,
|
||||||
port: int = 80,
|
port: DVType[int] = DEFAULT_80,
|
||||||
url_path: str = "",
|
url_path: str = "",
|
||||||
cert: Optional[Union[str, Path]] = None,
|
cert: Optional[Union[str, Path]] = None,
|
||||||
key: Optional[Union[str, Path]] = None,
|
key: Optional[Union[str, Path]] = None,
|
||||||
|
@ -848,6 +854,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
||||||
close_loop: bool = True,
|
close_loop: bool = True,
|
||||||
stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
|
stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
|
||||||
secret_token: Optional[str] = None,
|
secret_token: Optional[str] = None,
|
||||||
|
unix: Optional[Union[str, Path]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Convenience method that takes care of initializing and starting the app,
|
"""Convenience method that takes care of initializing and starting the app,
|
||||||
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
|
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
|
||||||
|
@ -940,6 +947,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
||||||
header isn't set or it is set to a wrong token.
|
header isn't set or it is set to a wrong token.
|
||||||
|
|
||||||
.. versionadded:: 20.0
|
.. versionadded:: 20.0
|
||||||
|
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
|
||||||
|
does not need to exist, in which case the file will be created.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
This parameter is a replacement for the default TCP bind. Therefore, it is
|
||||||
|
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
|
||||||
|
this param, you must also run a reverse proxy to the unix socket and set the
|
||||||
|
appropriate :paramref:`webhook_url`.
|
||||||
|
|
||||||
|
.. versionadded:: NEXT.VERSION
|
||||||
"""
|
"""
|
||||||
if not self.updater:
|
if not self.updater:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -960,6 +977,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
max_connections=max_connections,
|
max_connections=max_connections,
|
||||||
secret_token=secret_token,
|
secret_token=secret_token,
|
||||||
|
unix=unix,
|
||||||
),
|
),
|
||||||
close_loop=close_loop,
|
close_loop=close_loop,
|
||||||
stop_signals=stop_signals,
|
stop_signals=stop_signals,
|
||||||
|
|
|
@ -35,10 +35,10 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from telegram._utils.defaultvalue import DEFAULT_NONE
|
from telegram._utils.defaultvalue import DEFAULT_80, DEFAULT_IP, DEFAULT_NONE, DefaultValue
|
||||||
from telegram._utils.logging import get_logger
|
from telegram._utils.logging import get_logger
|
||||||
from telegram._utils.repr import build_repr_with_selected_attrs
|
from telegram._utils.repr import build_repr_with_selected_attrs
|
||||||
from telegram._utils.types import ODVInput
|
from telegram._utils.types import DVType, ODVInput
|
||||||
from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut
|
from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -456,8 +456,8 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
|
|
||||||
async def start_webhook(
|
async def start_webhook(
|
||||||
self,
|
self,
|
||||||
listen: str = "127.0.0.1",
|
listen: DVType[str] = DEFAULT_IP,
|
||||||
port: int = 80,
|
port: DVType[int] = DEFAULT_80,
|
||||||
url_path: str = "",
|
url_path: str = "",
|
||||||
cert: Optional[Union[str, Path]] = None,
|
cert: Optional[Union[str, Path]] = None,
|
||||||
key: Optional[Union[str, Path]] = None,
|
key: Optional[Union[str, Path]] = None,
|
||||||
|
@ -468,6 +468,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
ip_address: Optional[str] = None,
|
ip_address: Optional[str] = None,
|
||||||
max_connections: int = 40,
|
max_connections: int = 40,
|
||||||
secret_token: Optional[str] = None,
|
secret_token: Optional[str] = None,
|
||||||
|
unix: Optional[Union[str, Path]] = None,
|
||||||
) -> "asyncio.Queue[object]":
|
) -> "asyncio.Queue[object]":
|
||||||
"""
|
"""
|
||||||
Starts a small http server to listen for updates via webhook. If :paramref:`cert`
|
Starts a small http server to listen for updates via webhook. If :paramref:`cert`
|
||||||
|
@ -536,6 +537,16 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
header isn't set or it is set to a wrong token.
|
header isn't set or it is set to a wrong token.
|
||||||
|
|
||||||
.. versionadded:: 20.0
|
.. versionadded:: 20.0
|
||||||
|
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
|
||||||
|
does not need to exist, in which case the file will be created.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
This parameter is a replacement for the default TCP bind. Therefore, it is
|
||||||
|
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
|
||||||
|
this param, you must also run a reverse proxy to the unix socket and set the
|
||||||
|
appropriate :paramref:`webhook_url`.
|
||||||
|
|
||||||
|
.. versionadded:: NEXT.VERSION
|
||||||
Returns:
|
Returns:
|
||||||
:class:`queue.Queue`: The update queue that can be filled from the main thread.
|
:class:`queue.Queue`: The update queue that can be filled from the main thread.
|
||||||
|
|
||||||
|
@ -547,6 +558,21 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
"To use `start_webhook`, PTB must be installed via `pip install "
|
"To use `start_webhook`, PTB must be installed via `pip install "
|
||||||
'"python-telegram-bot[webhooks]"`.'
|
'"python-telegram-bot[webhooks]"`.'
|
||||||
)
|
)
|
||||||
|
# unix has special requirements what must and mustn't be set when using it
|
||||||
|
if unix:
|
||||||
|
error_msg = (
|
||||||
|
"You can not pass unix and {0}, only use one. Unix if you want to "
|
||||||
|
"initialize a unix socket, or {0} for a standard TCP server."
|
||||||
|
)
|
||||||
|
if not isinstance(listen, DefaultValue):
|
||||||
|
raise RuntimeError(error_msg.format("listen"))
|
||||||
|
if not isinstance(port, DefaultValue):
|
||||||
|
raise RuntimeError(error_msg.format("port"))
|
||||||
|
if not webhook_url:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Since you set unix, you also need to set the URL to the webhook "
|
||||||
|
"of the proxy you run in front of the unix socket."
|
||||||
|
)
|
||||||
|
|
||||||
async with self.__lock:
|
async with self.__lock:
|
||||||
if self.running:
|
if self.running:
|
||||||
|
@ -561,8 +587,8 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
webhook_ready = asyncio.Event()
|
webhook_ready = asyncio.Event()
|
||||||
|
|
||||||
await self._start_webhook(
|
await self._start_webhook(
|
||||||
listen=listen,
|
listen=DefaultValue.get_value(listen),
|
||||||
port=port,
|
port=DefaultValue.get_value(port),
|
||||||
url_path=url_path,
|
url_path=url_path,
|
||||||
cert=cert,
|
cert=cert,
|
||||||
key=key,
|
key=key,
|
||||||
|
@ -574,6 +600,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
max_connections=max_connections,
|
max_connections=max_connections,
|
||||||
secret_token=secret_token,
|
secret_token=secret_token,
|
||||||
|
unix=unix,
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER.debug("Waiting for webhook server to start")
|
_LOGGER.debug("Waiting for webhook server to start")
|
||||||
|
@ -601,6 +628,7 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
ip_address: Optional[str] = None,
|
ip_address: Optional[str] = None,
|
||||||
max_connections: int = 40,
|
max_connections: int = 40,
|
||||||
secret_token: Optional[str] = None,
|
secret_token: Optional[str] = None,
|
||||||
|
unix: Optional[Union[str, Path]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
_LOGGER.debug("Updater thread started (webhook)")
|
_LOGGER.debug("Updater thread started (webhook)")
|
||||||
|
|
||||||
|
@ -625,14 +653,13 @@ class Updater(AsyncContextManager["Updater"]):
|
||||||
raise TelegramError("Invalid SSL Certificate") from exc
|
raise TelegramError("Invalid SSL Certificate") from exc
|
||||||
else:
|
else:
|
||||||
ssl_ctx = None
|
ssl_ctx = None
|
||||||
|
|
||||||
# Create and start server
|
# Create and start server
|
||||||
self._httpd = WebhookServer(listen, port, app, ssl_ctx)
|
self._httpd = WebhookServer(listen, port, app, ssl_ctx, unix)
|
||||||
|
|
||||||
if not webhook_url:
|
if not webhook_url:
|
||||||
webhook_url = self._gen_webhook_url(
|
webhook_url = self._gen_webhook_url(
|
||||||
protocol="https" if ssl_ctx else "http",
|
protocol="https" if ssl_ctx else "http",
|
||||||
listen=listen,
|
listen=DefaultValue.get_value(listen),
|
||||||
port=port,
|
port=port,
|
||||||
url_path=url_path,
|
url_path=url_path,
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,15 +20,23 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from pathlib import Path
|
||||||
from ssl import SSLContext
|
from ssl import SSLContext
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import TYPE_CHECKING, Optional, Type
|
from typing import TYPE_CHECKING, Optional, Type, Union
|
||||||
|
|
||||||
# Instead of checking for ImportError here, we do that in `updater.py`, where we import from
|
# Instead of checking for ImportError here, we do that in `updater.py`, where we import from
|
||||||
# this module. Doing it here would be tricky, as the classes below subclass tornado classes
|
# this module. Doing it here would be tricky, as the classes below subclass tornado classes
|
||||||
import tornado.web
|
import tornado.web
|
||||||
from tornado.httpserver import HTTPServer
|
from tornado.httpserver import HTTPServer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tornado.netutil import bind_unix_socket
|
||||||
|
|
||||||
|
UNIX_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
UNIX_AVAILABLE = False
|
||||||
|
|
||||||
from telegram import Update
|
from telegram import Update
|
||||||
from telegram._utils.logging import get_logger
|
from telegram._utils.logging import get_logger
|
||||||
from telegram.ext._extbot import ExtBot
|
from telegram.ext._extbot import ExtBot
|
||||||
|
@ -50,20 +58,33 @@ class WebhookServer:
|
||||||
"is_running",
|
"is_running",
|
||||||
"_server_lock",
|
"_server_lock",
|
||||||
"_shutdown_lock",
|
"_shutdown_lock",
|
||||||
|
"unix",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, listen: str, port: int, webhook_app: "WebhookAppClass", ssl_ctx: Optional[SSLContext]
|
self,
|
||||||
|
listen: str,
|
||||||
|
port: int,
|
||||||
|
webhook_app: "WebhookAppClass",
|
||||||
|
ssl_ctx: Optional[SSLContext],
|
||||||
|
unix: Optional[Union[str, Path]] = None,
|
||||||
):
|
):
|
||||||
|
if unix and not UNIX_AVAILABLE:
|
||||||
|
raise RuntimeError("This OS does not support binding unix sockets.")
|
||||||
self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
|
self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
|
||||||
self.listen = listen
|
self.listen = listen
|
||||||
self.port = port
|
self.port = port
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
self.unix = unix
|
||||||
self._server_lock = asyncio.Lock()
|
self._server_lock = asyncio.Lock()
|
||||||
self._shutdown_lock = asyncio.Lock()
|
self._shutdown_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None:
|
async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None:
|
||||||
async with self._server_lock:
|
async with self._server_lock:
|
||||||
|
if self.unix:
|
||||||
|
socket = bind_unix_socket(str(self.unix))
|
||||||
|
self._http_server.add_socket(socket)
|
||||||
|
else:
|
||||||
self._http_server.listen(self.port, address=self.listen)
|
self._http_server.listen(self.port, address=self.listen)
|
||||||
|
|
||||||
self.is_running = True
|
self.is_running = True
|
||||||
|
|
|
@ -16,10 +16,11 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Lesser Public License
|
# You should have received a copy of the GNU Lesser Public License
|
||||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient, Response
|
from httpx import AsyncClient, AsyncHTTPTransport, Response
|
||||||
|
|
||||||
from telegram._utils.defaultvalue import DEFAULT_NONE
|
from telegram._utils.defaultvalue import DEFAULT_NONE
|
||||||
from telegram._utils.types import ODVInput
|
from telegram._utils.types import ODVInput
|
||||||
|
@ -90,6 +91,7 @@ async def send_webhook_message(
|
||||||
content_type: str = "application/json",
|
content_type: str = "application/json",
|
||||||
get_method: Optional[str] = None,
|
get_method: Optional[str] = None,
|
||||||
secret_token: Optional[str] = None,
|
secret_token: Optional[str] = None,
|
||||||
|
unix: Optional[Path] = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
headers = {
|
headers = {
|
||||||
"content-type": content_type,
|
"content-type": content_type,
|
||||||
|
@ -111,7 +113,9 @@ async def send_webhook_message(
|
||||||
|
|
||||||
url = f"http://{ip}:{port}/{url_path}"
|
url = f"http://{ip}:{port}/{url_path}"
|
||||||
|
|
||||||
async with AsyncClient() as client:
|
transport = AsyncHTTPTransport(uds=unix) if unix else None
|
||||||
|
|
||||||
|
async with AsyncClient(transport=transport) as client:
|
||||||
return await client.request(
|
return await client.request(
|
||||||
url=url, method=get_method or "POST", data=payload, headers=headers
|
url=url, method=get_method or "POST", data=payload, headers=headers
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
# along with this program. If not, see [http://www.gnu.org/licenses/].
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import platform
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -32,7 +33,7 @@ from telegram.ext import ExtBot, InvalidCallbackData, Updater
|
||||||
from telegram.request import HTTPXRequest
|
from telegram.request import HTTPXRequest
|
||||||
from tests.auxil.build_messages import make_message, make_message_update
|
from tests.auxil.build_messages import make_message, make_message_update
|
||||||
from tests.auxil.envvars import TEST_WITH_OPT_DEPS
|
from tests.auxil.envvars import TEST_WITH_OPT_DEPS
|
||||||
from tests.auxil.files import data_file
|
from tests.auxil.files import TEST_DATA_PATH, data_file
|
||||||
from tests.auxil.networking import send_webhook_message
|
from tests.auxil.networking import send_webhook_message
|
||||||
from tests.auxil.pytest_classes import PytestBot, make_bot
|
from tests.auxil.pytest_classes import PytestBot, make_bot
|
||||||
from tests.auxil.slots import mro_slots
|
from tests.auxil.slots import mro_slots
|
||||||
|
@ -74,6 +75,14 @@ class TestUpdater:
|
||||||
self.cb_handler_called = None
|
self.cb_handler_called = None
|
||||||
self.test_flag = False
|
self.test_flag = False
|
||||||
|
|
||||||
|
# This is needed instead of pytest's temp_path because the file path gets too long on macOS
|
||||||
|
# otherwise
|
||||||
|
@pytest.fixture()
|
||||||
|
def file_path(self) -> str:
|
||||||
|
path = TEST_DATA_PATH / "test.sock"
|
||||||
|
yield str(path)
|
||||||
|
path.unlink(missing_ok=True)
|
||||||
|
|
||||||
def error_callback(self, error):
|
def error_callback(self, error):
|
||||||
self.received = error
|
self.received = error
|
||||||
self.err_handler_called.set()
|
self.err_handler_called.set()
|
||||||
|
@ -680,9 +689,13 @@ class TestUpdater:
|
||||||
@pytest.mark.parametrize("ext_bot", [True, False])
|
@pytest.mark.parametrize("ext_bot", [True, False])
|
||||||
@pytest.mark.parametrize("drop_pending_updates", [True, False])
|
@pytest.mark.parametrize("drop_pending_updates", [True, False])
|
||||||
@pytest.mark.parametrize("secret_token", ["SecretToken", None])
|
@pytest.mark.parametrize("secret_token", ["SecretToken", None])
|
||||||
|
@pytest.mark.parametrize("unix", [None, True])
|
||||||
async def test_webhook_basic(
|
async def test_webhook_basic(
|
||||||
self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token
|
self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token, unix, file_path
|
||||||
):
|
):
|
||||||
|
# Skipping unix test on windows since they fail
|
||||||
|
if unix and platform.system() == "Windows":
|
||||||
|
pytest.skip("Windows doesn't support unix bind")
|
||||||
# Testing with both ExtBot and Bot to make sure any logic in WebhookHandler
|
# Testing with both ExtBot and Bot to make sure any logic in WebhookHandler
|
||||||
# that depends on this distinction works
|
# that depends on this distinction works
|
||||||
if ext_bot and not isinstance(updater.bot, ExtBot):
|
if ext_bot and not isinstance(updater.bot, ExtBot):
|
||||||
|
@ -706,12 +719,22 @@ class TestUpdater:
|
||||||
port = randrange(1024, 49152) # Select random port
|
port = randrange(1024, 49152) # Select random port
|
||||||
|
|
||||||
async with updater:
|
async with updater:
|
||||||
|
if unix:
|
||||||
|
return_value = await updater.start_webhook(
|
||||||
|
drop_pending_updates=drop_pending_updates,
|
||||||
|
secret_token=secret_token,
|
||||||
|
url_path="TOKEN",
|
||||||
|
unix=file_path,
|
||||||
|
webhook_url="string",
|
||||||
|
)
|
||||||
|
else:
|
||||||
return_value = await updater.start_webhook(
|
return_value = await updater.start_webhook(
|
||||||
drop_pending_updates=drop_pending_updates,
|
drop_pending_updates=drop_pending_updates,
|
||||||
ip_address=ip,
|
ip_address=ip,
|
||||||
port=port,
|
port=port,
|
||||||
url_path="TOKEN",
|
url_path="TOKEN",
|
||||||
secret_token=secret_token,
|
secret_token=secret_token,
|
||||||
|
webhook_url="string",
|
||||||
)
|
)
|
||||||
assert return_value is updater.update_queue
|
assert return_value is updater.update_queue
|
||||||
assert updater.running
|
assert updater.running
|
||||||
|
@ -719,21 +742,47 @@ class TestUpdater:
|
||||||
# Now, we send an update to the server
|
# Now, we send an update to the server
|
||||||
update = make_message_update("Webhook")
|
update = make_message_update("Webhook")
|
||||||
await send_webhook_message(
|
await send_webhook_message(
|
||||||
ip, port, update.to_json(), "TOKEN", secret_token=secret_token
|
ip,
|
||||||
|
port,
|
||||||
|
update.to_json(),
|
||||||
|
"TOKEN",
|
||||||
|
secret_token=secret_token,
|
||||||
|
unix=file_path if unix else None,
|
||||||
)
|
)
|
||||||
assert (await updater.update_queue.get()).to_dict() == update.to_dict()
|
assert (await updater.update_queue.get()).to_dict() == update.to_dict()
|
||||||
|
|
||||||
# Returns Not Found if path is incorrect
|
# Returns Not Found if path is incorrect
|
||||||
response = await send_webhook_message(ip, port, "123456", "webhook_handler.py")
|
response = await send_webhook_message(
|
||||||
|
ip,
|
||||||
|
port,
|
||||||
|
"123456",
|
||||||
|
"webhook_handler.py",
|
||||||
|
unix=file_path if unix else None,
|
||||||
|
)
|
||||||
assert response.status_code == HTTPStatus.NOT_FOUND
|
assert response.status_code == HTTPStatus.NOT_FOUND
|
||||||
|
|
||||||
# Returns METHOD_NOT_ALLOWED if method is not allowed
|
# Returns METHOD_NOT_ALLOWED if method is not allowed
|
||||||
response = await send_webhook_message(ip, port, None, "TOKEN", get_method="HEAD")
|
response = await send_webhook_message(
|
||||||
|
ip,
|
||||||
|
port,
|
||||||
|
None,
|
||||||
|
"TOKEN",
|
||||||
|
get_method="HEAD",
|
||||||
|
unix=file_path if unix else None,
|
||||||
|
)
|
||||||
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED
|
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED
|
||||||
|
|
||||||
if secret_token:
|
if secret_token:
|
||||||
# Returns Forbidden if no secret token is set
|
# Returns Forbidden if no secret token is set
|
||||||
response = await send_webhook_message(ip, port, update.to_json(), "TOKEN")
|
|
||||||
|
response = await send_webhook_message(
|
||||||
|
ip,
|
||||||
|
port,
|
||||||
|
update.to_json(),
|
||||||
|
"TOKEN",
|
||||||
|
unix=file_path if unix else None,
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == HTTPStatus.FORBIDDEN
|
assert response.status_code == HTTPStatus.FORBIDDEN
|
||||||
assert response.text == self.response_text.format(
|
assert response.text == self.response_text.format(
|
||||||
"Request did not include the secret token", HTTPStatus.FORBIDDEN
|
"Request did not include the secret token", HTTPStatus.FORBIDDEN
|
||||||
|
@ -741,7 +790,12 @@ class TestUpdater:
|
||||||
|
|
||||||
# Returns Forbidden if the secret token is wrong
|
# Returns Forbidden if the secret token is wrong
|
||||||
response = await send_webhook_message(
|
response = await send_webhook_message(
|
||||||
ip, port, update.to_json(), "TOKEN", secret_token="NotTheSecretToken"
|
ip,
|
||||||
|
port,
|
||||||
|
update.to_json(),
|
||||||
|
"TOKEN",
|
||||||
|
secret_token="NotTheSecretToken",
|
||||||
|
unix=file_path if unix else None,
|
||||||
)
|
)
|
||||||
assert response.status_code == HTTPStatus.FORBIDDEN
|
assert response.status_code == HTTPStatus.FORBIDDEN
|
||||||
assert response.text == self.response_text.format(
|
assert response.text == self.response_text.format(
|
||||||
|
@ -757,19 +811,54 @@ class TestUpdater:
|
||||||
assert self.message_count == 0
|
assert self.message_count == 0
|
||||||
|
|
||||||
# We call the same logic twice to make sure that restarting the updater works as well
|
# We call the same logic twice to make sure that restarting the updater works as well
|
||||||
|
if unix:
|
||||||
|
await updater.start_webhook(
|
||||||
|
drop_pending_updates=drop_pending_updates,
|
||||||
|
secret_token=secret_token,
|
||||||
|
unix=file_path,
|
||||||
|
webhook_url="string",
|
||||||
|
)
|
||||||
|
else:
|
||||||
await updater.start_webhook(
|
await updater.start_webhook(
|
||||||
drop_pending_updates=drop_pending_updates,
|
drop_pending_updates=drop_pending_updates,
|
||||||
ip_address=ip,
|
ip_address=ip,
|
||||||
port=port,
|
port=port,
|
||||||
url_path="TOKEN",
|
url_path="TOKEN",
|
||||||
|
secret_token=secret_token,
|
||||||
|
webhook_url="string",
|
||||||
)
|
)
|
||||||
assert updater.running
|
assert updater.running
|
||||||
update = make_message_update("Webhook")
|
update = make_message_update("Webhook")
|
||||||
await send_webhook_message(ip, port, update.to_json(), "TOKEN")
|
await send_webhook_message(
|
||||||
|
ip,
|
||||||
|
port,
|
||||||
|
update.to_json(),
|
||||||
|
"" if unix else "TOKEN",
|
||||||
|
secret_token=secret_token,
|
||||||
|
unix=file_path if unix else None,
|
||||||
|
)
|
||||||
assert (await updater.update_queue.get()).to_dict() == update.to_dict()
|
assert (await updater.update_queue.get()).to_dict() == update.to_dict()
|
||||||
await updater.stop()
|
await updater.stop()
|
||||||
assert not updater.running
|
assert not updater.running
|
||||||
|
|
||||||
|
async def test_unix_webhook_mutually_exclusive_params(self, updater):
|
||||||
|
async with updater:
|
||||||
|
with pytest.raises(RuntimeError, match="You can not pass unix and listen"):
|
||||||
|
await updater.start_webhook(listen="127.0.0.1", unix="DoesntMatter")
|
||||||
|
with pytest.raises(RuntimeError, match="You can not pass unix and port"):
|
||||||
|
await updater.start_webhook(port=20, unix="DoesntMatter")
|
||||||
|
with pytest.raises(RuntimeError, match="you set unix, you also need to set the URL"):
|
||||||
|
await updater.start_webhook(unix="DoesntMatter")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
platform.system() != "Windows",
|
||||||
|
reason="Windows is the only platform without unix",
|
||||||
|
)
|
||||||
|
async def test_no_unix(self, updater):
|
||||||
|
async with updater:
|
||||||
|
with pytest.raises(RuntimeError, match="binding unix sockets."):
|
||||||
|
await updater.start_webhook(unix="DoesntMatter", webhook_url="TOKEN")
|
||||||
|
|
||||||
async def test_start_webhook_already_running(self, updater, monkeypatch):
|
async def test_start_webhook_already_running(self, updater, monkeypatch):
|
||||||
async def return_true(*args, **kwargs):
|
async def return_true(*args, **kwargs):
|
||||||
return True
|
return True
|
||||||
|
|
Loading…
Reference in a new issue