diff --git a/requirements-dev.txt b/requirements-dev.txt index 5cd0fa698..6e30d0d3f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,6 +12,7 @@ isort==5.10.1 pytest==7.1.2 pytest-asyncio==0.18.3 +pytest-timeout==2.1.0 # used to timeout tests flaky # Used for flaky tests (flaky decorator) beautifulsoup4 # used in test_official for parsing tg docs diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 96a332a7c..63320ec69 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -21,6 +21,7 @@ import asyncio import inspect import itertools import logging +import platform import signal from collections import defaultdict from contextlib import AbstractAsyncContextManager @@ -547,7 +548,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) allowed_updates: List[str] = None, drop_pending_updates: bool = None, close_loop: bool = True, - stop_signals: Optional[Sequence[int]] = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT), + stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE, ) -> None: """Convenience method that takes care of initializing and starting the app, polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and @@ -596,7 +597,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) stop_signals (Sequence[:obj:`int`] | :obj:`None`, optional): Signals that will shut down the app. Pass :obj:`None` to not use stop signals. Defaults to :data:`signal.SIGINT`, :data:`signal.SIGTERM` and - :data:`signal.SIGABRT`. + :data:`signal.SIGABRT` on non Windows platforms. Caution: Not every :class:`asyncio.AbstractEventLoop` implements @@ -646,7 +647,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) ip_address: str = None, max_connections: int = 40, close_loop: bool = True, - stop_signals: Optional[Sequence[int]] = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT), + stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE, ) -> None: """Convenience method that takes care of initializing and starting the app, polling updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and @@ -736,7 +737,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) def __run( self, updater_coroutine: Coroutine, - stop_signals: Optional[Sequence[int]], + stop_signals: ODVInput[Sequence[int]], close_loop: bool = True, ) -> None: # Calling get_event_loop() should still be okay even in py3.10+ as long as there is a @@ -744,9 +745,13 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) # See the docs of get_event_loop() and get_running_loop() for more info loop = asyncio.get_event_loop() + if stop_signals is DEFAULT_NONE and platform.system() != "Windows": + stop_signals = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT) + try: - for sig in stop_signals or []: - loop.add_signal_handler(sig, self._raise_system_exit) + if not isinstance(stop_signals, DefaultValue): + for sig in stop_signals or []: + loop.add_signal_handler(sig, self._raise_system_exit) except NotImplementedError as exc: warn( f"Could not add signal handlers for the stop signals {stop_signals} due to " diff --git a/tests/test_application.py b/tests/test_application.py index ce7f48c13..8070bafb4 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1659,9 +1659,9 @@ class TestApplication: with pytest.raises(RuntimeError, match="Prevent Actually Running"): if "polling" in method: - app.run_polling(close_loop=False) + app.run_polling(close_loop=False, stop_signals=(signal.SIGINT,)) else: - app.run_webhook(close_loop=False) + app.run_webhook(close_loop=False, stop_signals=(signal.SIGTERM,)) assert len(recwarn) >= 1 found = False @@ -1680,3 +1680,39 @@ class TestApplication: app.run_webhook(close_loop=False, stop_signals=None) assert len(recwarn) == 0 + + @pytest.mark.timeout(6) + def test_signal_handlers(self, app, monkeypatch): + # this test should make sure that signal handlers are set by default on Linux + Mac, + # and not on Windows. + + received_signals = [] + + def signal_handler_test(*args, **kwargs): + # args[0] is the signal, [1] the callback + received_signals.append(args[0]) + + loop = asyncio.get_event_loop() + monkeypatch.setattr(loop, "add_signal_handler", signal_handler_test) + + async def abort_app(): + await asyncio.sleep(2) + raise SystemExit + + loop.create_task(abort_app()) + + app.run_polling(close_loop=False) + + if platform.system() == "Windows": + assert received_signals == [] + else: + assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT] + + received_signals.clear() + loop.create_task(abort_app()) + app.run_webhook(port=49152, webhook_url="example.com", close_loop=False) + + if platform.system() == "Windows": + assert received_signals == [] + else: + assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]