Don't Set Signal Handlers On Windows By Default (#3065)

This commit is contained in:
Poolitzer 2022-06-02 09:43:03 +02:00 committed by GitHub
parent 42955ecddf
commit 306cc64170
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 8 deletions

View file

@ -12,6 +12,7 @@ isort==5.10.1
pytest==7.1.2 pytest==7.1.2
pytest-asyncio==0.18.3 pytest-asyncio==0.18.3
pytest-timeout==2.1.0 # used to timeout tests
flaky # Used for flaky tests (flaky decorator) flaky # Used for flaky tests (flaky decorator)
beautifulsoup4 # used in test_official for parsing tg docs beautifulsoup4 # used in test_official for parsing tg docs

View file

@ -21,6 +21,7 @@ import asyncio
import inspect import inspect
import itertools import itertools
import logging import logging
import platform
import signal import signal
from collections import defaultdict from collections import defaultdict
from contextlib import AbstractAsyncContextManager from contextlib import AbstractAsyncContextManager
@ -547,7 +548,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
allowed_updates: List[str] = None, allowed_updates: List[str] = None,
drop_pending_updates: bool = None, drop_pending_updates: bool = None,
close_loop: bool = True, close_loop: bool = True,
stop_signals: Optional[Sequence[int]] = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT), stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
) -> None: ) -> None:
"""Convenience method that takes care of initializing and starting the app, """Convenience method that takes care of initializing and starting the app,
polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and
@ -596,7 +597,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
stop_signals (Sequence[:obj:`int`] | :obj:`None`, optional): Signals that will shut stop_signals (Sequence[:obj:`int`] | :obj:`None`, optional): Signals that will shut
down the app. Pass :obj:`None` to not use stop signals. down the app. Pass :obj:`None` to not use stop signals.
Defaults to :data:`signal.SIGINT`, :data:`signal.SIGTERM` and Defaults to :data:`signal.SIGINT`, :data:`signal.SIGTERM` and
:data:`signal.SIGABRT`. :data:`signal.SIGABRT` on non Windows platforms.
Caution: Caution:
Not every :class:`asyncio.AbstractEventLoop` implements Not every :class:`asyncio.AbstractEventLoop` implements
@ -646,7 +647,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
ip_address: str = None, ip_address: str = None,
max_connections: int = 40, max_connections: int = 40,
close_loop: bool = True, close_loop: bool = True,
stop_signals: Optional[Sequence[int]] = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT), stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
) -> None: ) -> None:
"""Convenience method that takes care of initializing and starting the app, """Convenience method that takes care of initializing and starting the app,
polling updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and polling updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
@ -736,7 +737,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
def __run( def __run(
self, self,
updater_coroutine: Coroutine, updater_coroutine: Coroutine,
stop_signals: Optional[Sequence[int]], stop_signals: ODVInput[Sequence[int]],
close_loop: bool = True, close_loop: bool = True,
) -> None: ) -> None:
# Calling get_event_loop() should still be okay even in py3.10+ as long as there is a # Calling get_event_loop() should still be okay even in py3.10+ as long as there is a
@ -744,7 +745,11 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager)
# See the docs of get_event_loop() and get_running_loop() for more info # See the docs of get_event_loop() and get_running_loop() for more info
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if stop_signals is DEFAULT_NONE and platform.system() != "Windows":
stop_signals = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT)
try: try:
if not isinstance(stop_signals, DefaultValue):
for sig in stop_signals or []: for sig in stop_signals or []:
loop.add_signal_handler(sig, self._raise_system_exit) loop.add_signal_handler(sig, self._raise_system_exit)
except NotImplementedError as exc: except NotImplementedError as exc:

View file

@ -1659,9 +1659,9 @@ class TestApplication:
with pytest.raises(RuntimeError, match="Prevent Actually Running"): with pytest.raises(RuntimeError, match="Prevent Actually Running"):
if "polling" in method: if "polling" in method:
app.run_polling(close_loop=False) app.run_polling(close_loop=False, stop_signals=(signal.SIGINT,))
else: else:
app.run_webhook(close_loop=False) app.run_webhook(close_loop=False, stop_signals=(signal.SIGTERM,))
assert len(recwarn) >= 1 assert len(recwarn) >= 1
found = False found = False
@ -1680,3 +1680,39 @@ class TestApplication:
app.run_webhook(close_loop=False, stop_signals=None) app.run_webhook(close_loop=False, stop_signals=None)
assert len(recwarn) == 0 assert len(recwarn) == 0
@pytest.mark.timeout(6)
def test_signal_handlers(self, app, monkeypatch):
# this test should make sure that signal handlers are set by default on Linux + Mac,
# and not on Windows.
received_signals = []
def signal_handler_test(*args, **kwargs):
# args[0] is the signal, [1] the callback
received_signals.append(args[0])
loop = asyncio.get_event_loop()
monkeypatch.setattr(loop, "add_signal_handler", signal_handler_test)
async def abort_app():
await asyncio.sleep(2)
raise SystemExit
loop.create_task(abort_app())
app.run_polling(close_loop=False)
if platform.system() == "Windows":
assert received_signals == []
else:
assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]
received_signals.clear()
loop.create_task(abort_app())
app.run_webhook(port=49152, webhook_url="example.com", close_loop=False)
if platform.system() == "Windows":
assert received_signals == []
else:
assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]