diff --git a/telegram/ext/_utils/stack.py b/telegram/ext/_utils/stack.py index 12b4fe87b..e7d91d230 100644 --- a/telegram/ext/_utils/stack.py +++ b/telegram/ext/_utils/stack.py @@ -25,10 +25,13 @@ Warning: user. Changes to this module are not considered breaking changes and may not be documented in the changelog. """ +import logging from pathlib import Path from types import FrameType from typing import Optional +_logger = logging.getLogger(__name__) + def was_called_by(frame: Optional[FrameType], caller: Path) -> bool: """Checks if the passed frame was called by the specified file. @@ -51,11 +54,22 @@ def was_called_by(frame: Optional[FrameType], caller: Path) -> bool: if frame is None: return False + try: + return _was_called_by(frame, caller) + except Exception as exc: + _logger.debug( + "Failed to check if frame was called by `caller`. Assuming that it was not.", + exc_info=exc, + ) + return False + + +def _was_called_by(frame: FrameType, caller: Path) -> bool: # https://stackoverflow.com/a/57712700/10606962 - if Path(frame.f_code.co_filename) == caller: + if Path(frame.f_code.co_filename).resolve() == caller: return True while frame.f_back: frame = frame.f_back - if Path(frame.f_code.co_filename) == caller: + if Path(frame.f_code.co_filename).resolve() == caller: return True return False diff --git a/tests/ext/_utils/test_stack.py b/tests/ext/_utils/test_stack.py index d82d26110..fc5d42ef5 100644 --- a/tests/ext/_utils/test_stack.py +++ b/tests/ext/_utils/test_stack.py @@ -17,19 +17,100 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import inspect +import logging +import sys from pathlib import Path +import pytest + from telegram.ext._utils.stack import was_called_by +def symlink_to(source: Path, target: Path) -> None: + """Wrapper around Path.symlink_to that pytest-skips OS Errors. + Useful e.g. for making tests not fail locally due to permission errors. + """ + try: + source.symlink_to(target) + except OSError as exc: + pytest.skip(f"Skipping due to OS error while creating symlink: {exc!r}") + + class TestStack: def test_none_input(self): assert not was_called_by(None, None) def test_called_by_current_file(self): + # Testing a call by a different file is somewhat hard but it's covered in + # TestUpdater/Application.test_manual_init_warning frame = inspect.currentframe() file = Path(__file__) assert was_called_by(frame, file) - # Testing a call by a different file is somewhat hard but it's covered in - # TestUpdater/Application.test_manual_init_warning + def test_exception(self, monkeypatch, caplog): + def resolve(self): + raise RuntimeError("Can Not Resolve") + + with caplog.at_level(logging.DEBUG): + monkeypatch.setattr(Path, "resolve", resolve) + assert not was_called_by(inspect.currentframe(), None) + + assert len(caplog.records) == 1 + assert caplog.records[0].levelno == logging.DEBUG + assert caplog.records[0].getMessage().startswith("Failed to check") + assert caplog.records[0].exc_info[0] is RuntimeError + assert "Can Not Resolve" in str(caplog.records[0].exc_info[1]) + + def test_called_by_symlink_file(self, tmp_path): + # Set up a call from a linked file in a temp directory, + # then test it with its resolved path. + # Here we expect `was_called_by` to recognize + # "`tmp_path`/caller_link.py" as same as "`tmp_path`/caller.py". + temp_file = tmp_path / "caller.py" + caller_content = """ +import inspect +def caller_func(): + return inspect.currentframe() + """ + with temp_file.open("w") as f: + f.write(caller_content) + + symlink_file = tmp_path / "caller_link.py" + symlink_to(symlink_file, temp_file) + + sys.path.append(tmp_path.as_posix()) + from caller_link import caller_func + + frame = caller_func() + assert was_called_by(frame, temp_file) + + def test_called_by_symlink_file_nested(self, tmp_path): + # Same as test_called_by_symlink_file except + # inner_func is nested inside outer_func to test + # if `was_called_by` can resolve paths in recursion. + temp_file1 = tmp_path / "inner.py" + inner_content = """ +import inspect +def inner_func(): + return inspect.currentframe() + """ + with temp_file1.open("w") as f: + f.write(inner_content) + + temp_file2 = tmp_path / "outer.py" + outer_content = """ +from inner import inner_func +def outer_func(): + return inner_func() + """ + with temp_file2.open("w") as f: + f.write(outer_content) + + symlink_file2 = tmp_path / "outer_link.py" + symlink_to(symlink_file2, temp_file2) + + sys.path.append(tmp_path.as_posix()) + from outer_link import outer_func + + frame = outer_func() + assert was_called_by(frame, temp_file2)