From e415af6c04ea8001ee4fc3a3404a5cf71712b9a8 Mon Sep 17 00:00:00 2001 From: M Bussonnier Date: Tue, 16 Jul 2024 11:00:07 +0200 Subject: [PATCH] MAINT: be stricter on str/Filename There are a bit too many places where there is a str/Filename Union. This make it hard to track issues like #346. Thus we do some refactor and enforce Filename in many places in order to make reasonning on the code easier. --- lib/python/pyflyby/_cmdline.py | 9 ++++-- lib/python/pyflyby/_file.py | 53 ++++++++++++++++++++++----------- lib/python/pyflyby/_importdb.py | 34 +++++++++++++-------- lib/python/pyflyby/_parse.py | 7 +++++ lib/python/pyflyby/_py.py | 4 +-- tests/test_autoimp.py | 20 +++++++++++-- tests/test_interactive.py | 46 +++++++++++++++++----------- tests/test_livepatch.py | 2 +- tests/test_py.py | 26 +++++----------- 9 files changed, 127 insertions(+), 74 deletions(-) diff --git a/lib/python/pyflyby/_cmdline.py b/lib/python/pyflyby/_cmdline.py index a8519b73..a0011d50 100644 --- a/lib/python/pyflyby/_cmdline.py +++ b/lib/python/pyflyby/_cmdline.py @@ -4,10 +4,10 @@ +from builtins import input import optparse import os import signal -from builtins import input import sys from textwrap import dedent import traceback @@ -285,7 +285,8 @@ def unaligned_callback(option, opt_str, value, parser): def _default_on_error(filename): raise SystemExit("bad filename %s" % (filename,)) -def filename_args(args, on_error=_default_on_error): + +def filename_args(args: List[str], on_error=_default_on_error): """ Return list of filenames given command-line arguments. @@ -293,7 +294,9 @@ def filename_args(args, on_error=_default_on_error): ``list`` of `Filename` """ if args: - return expand_py_files_from_args(args, on_error) + for a in args: + assert isinstance(a, str) + return expand_py_files_from_args([Filename(f) for f in args], on_error) elif not os.isatty(0): return [Filename.STDIN] else: diff --git a/lib/python/pyflyby/_file.py b/lib/python/pyflyby/_file.py index 9f748afd..51e4b2ba 100644 --- a/lib/python/pyflyby/_file.py +++ b/lib/python/pyflyby/_file.py @@ -3,15 +3,20 @@ # License: MIT http://opensource.org/licenses/MIT from __future__ import annotations -from functools import total_ordering, cached_property +from functools import cached_property, total_ordering import io import os import re import sys -from typing import Optional, Tuple, ClassVar +from typing import ClassVar, List, Optional, Tuple, Union from pyflyby._util import cmp, memoize +if sys.version_info < (3,10): + NoneType = type(None) +else: + from types import NoneType + class UnsafeFilenameError(ValueError): pass @@ -33,20 +38,23 @@ class Filename(object): def __new__(cls, arg): if isinstance(arg, cls): - return arg + # TODO make this assert False + return cls._from_filename(arg._filename) if isinstance(arg, str): return cls._from_filename(arg) raise TypeError @classmethod - def _from_filename(cls, filename): + def _from_filename(cls, filename: str): if not isinstance(filename, str): raise TypeError filename = str(filename) if not filename: raise UnsafeFilenameError("(empty string)") - if re.search("[^a-zA-Z0-9_=+{}/.,~@-]", filename): - raise UnsafeFilenameError(filename) + # we only allow filename with given character set + match = re.search("[^a-zA-Z0-9_=+{}/.,~@-]", filename) + if match: + raise UnsafeFilenameError((filename, match)) if re.search("(^|/)~", filename): raise UnsafeFilenameError(filename) self = object.__new__(cls) @@ -373,6 +381,8 @@ def __new__(cls, arg, filename=None, startpos=None): :rtype: ``FileText`` """ + if isinstance(filename, str): + filename = Filename(filename) if isinstance(arg, cls): if filename is startpos is None: return arg @@ -387,8 +397,8 @@ def __new__(cls, arg, filename=None, startpos=None): else: raise TypeError("%s: unexpected %s" % (cls.__name__, type(arg).__name__)) - if filename is not None: - filename = Filename(filename) + + assert isinstance(filename, (Filename, NoneType)) startpos = FilePos(startpos) self.filename = filename self.startpos = startpos @@ -439,9 +449,9 @@ def joined(self) -> str: def from_filename(cls, filename): return cls.from_lines(Filename(filename)) - def alter(self, filename=None, startpos=None): + def alter(self, filename: Optional[Filename] = None, startpos=None): if filename is not None: - filename = Filename(filename) + assert isinstance(filename, Filename) else: filename = self.filename if startpos is not None: @@ -652,8 +662,8 @@ def __hash__(self): return h -def read_file(filename): - filename = Filename(filename) +def read_file(filename: Filename) -> FileText: + assert isinstance(filename, Filename) if filename == Filename.STDIN: data = sys.stdin.read() else: @@ -661,14 +671,15 @@ def read_file(filename): data = f.read() return FileText(data, filename=filename) -def write_file(filename, data): - filename = Filename(filename) + +def write_file(filename: Filename, data): + assert isinstance(filename, Filename) data = FileText(data) with open(str(filename), 'w') as f: f.write(data.joined) -def atomic_write_file(filename, data): - filename = Filename(filename) +def atomic_write_file(filename: Filename, data): + assert isinstance(filename, Filename) data = FileText(data) temp_filename = Filename("%s.tmp.%s" % (filename, os.getpid(),)) write_file(temp_filename, data) @@ -680,7 +691,10 @@ def atomic_write_file(filename, data): pass os.rename(str(temp_filename), str(filename)) -def expand_py_files_from_args(pathnames, on_error=lambda filename: None): + +def expand_py_files_from_args( + pathnames: Union[List[Filename], Filename], on_error=lambda filename: None +): """ Enumerate ``*.py`` files, recursively. @@ -698,8 +712,11 @@ def expand_py_files_from_args(pathnames, on_error=lambda filename: None): ``list`` of `Filename` s """ if not isinstance(pathnames, (tuple, list)): + # July 2024 DeprecationWarning + # this seem to be used only internally, maybe deprecate not passing a list. pathnames = [pathnames] - pathnames = [Filename(f) for f in pathnames] + for f in pathnames: + assert isinstance(f, Filename) result = [] # Check for problematic arguments. Note that we intentionally only do # this for directly specified arguments, not for recursively traversed diff --git a/lib/python/pyflyby/_importdb.py b/lib/python/pyflyby/_importdb.py index 8e30708a..832cc803 100644 --- a/lib/python/pyflyby/_importdb.py +++ b/lib/python/pyflyby/_importdb.py @@ -9,10 +9,10 @@ import re import sys -from typing import Dict, Any, Tuple +from typing import Any, Dict, Tuple, Union, List - -from pyflyby._file import Filename, expand_py_files_from_args, UnsafeFilenameError +from pyflyby._file import (Filename, UnsafeFilenameError, + expand_py_files_from_args) from pyflyby._idents import dotted_prefixes from pyflyby._importclns import ImportMap, ImportSet from pyflyby._importstmt import Import, ImportStatement @@ -90,8 +90,11 @@ def _get_python_path(env_var_name, default_path, target_dirname): .format(env_var_name=env_var_name, p=p)) pathnames = [os.path.expanduser(p) for p in pathnames] pathnames = _expand_tripledots(pathnames, target_dirname) - pathnames = [Filename(fn) for fn in pathnames] + for fn in pathnames: + assert isinstance(fn, Filename) pathnames = stable_unique(pathnames) + for p in pathnames: + assert isinstance(p, Filename) pathnames = expand_py_files_from_args(pathnames) if not pathnames: logger.warning( @@ -103,8 +106,8 @@ def _get_python_path(env_var_name, default_path, target_dirname): # TODO: stop memoizing here after using StatCache. Actually just inline into # _ancestors_on_same_partition @memoize -def _get_st_dev(filename): - filename = Filename(filename) +def _get_st_dev(filename: Filename): + assert isinstance(filename, Filename) try: return os.stat(str(filename)).st_dev except OSError: @@ -159,7 +162,7 @@ def _expand_tripledots(pathnames, target_dirname): :rtype: ``list`` of `Filename` """ - target_dirname = Filename(target_dirname) + assert isinstance(target_dirname, Filename) if not isinstance(pathnames, (tuple, list)): pathnames = [pathnames] result = [] @@ -238,7 +241,7 @@ def clear_default_cache(cls): cls._default_cache.clear() @classmethod - def get_default(cls, target_filename): + def get_default(cls, target_filename: Union[Filename, str]): """ Return the default import library for the given target filename. @@ -263,11 +266,15 @@ def get_default(cls, target_filename): # TODO: Consider refreshing periodically. Check if files have # been touched, and if so, return new data. Check file timestamps at # most once every 60 seconds. - cache_keys = [] - target_filename = Filename(target_filename or ".") + cache_keys:List[Tuple[Any,...]] = [] + if target_filename: + if isinstance(target_filename, str): + target_filename = Filename(target_filename) + target_filename = target_filename or Filename(".") + assert isinstance(target_filename, Filename) if target_filename.startswith("/dev"): target_filename = Filename(".") - target_dirname = target_filename + target_dirname:Filename = target_filename # TODO: with StatCache while True: cache_keys.append((1, @@ -493,8 +500,11 @@ def _from_filenames(cls, filenames, _mandatory_filenames_deprecated=[]): `ImportDB` """ if not isinstance(filenames, (tuple, list)): + # TODO DeprecationWarning July 2024, + # this is internal deprecate not passing a list; filenames = [filenames] - filenames = [Filename(f) for f in filenames] + for f in filenames: + assert isinstance(f, Filename) logger.debug("ImportDB: loading [%s], mandatory=[%s]", ', '.join(map(str, filenames)), ', '.join(map(str, _mandatory_filenames_deprecated))) diff --git a/lib/python/pyflyby/_parse.py b/lib/python/pyflyby/_parse.py index f5bc3fc1..1f9920f5 100644 --- a/lib/python/pyflyby/_parse.py +++ b/lib/python/pyflyby/_parse.py @@ -23,10 +23,13 @@ from typing import Any, List, Optional, Tuple, Union, cast import warnings + _sentinel = object() if sys.version_info < (3, 10): + NoneType = type(None) + class MatchAs: name: str pattern: ast.AST @@ -36,6 +39,7 @@ class MatchMapping: patterns: List[MatchAs] else: + from types import NoneType from ast import MatchAs, MatchMapping @@ -963,6 +967,9 @@ def from_text(cls, text, filename=None, startpos=None, flags=None, :rtype: `PythonBlock` """ + if isinstance(filename, str): + filename = Filename(filename) + assert isinstance(filename, (Filename, NoneType)), filename text = FileText(text, filename=filename, startpos=startpos) self = object.__new__(cls) self.text = text diff --git a/lib/python/pyflyby/_py.py b/lib/python/pyflyby/_py.py index 9ad18369..ee65d71e 100644 --- a/lib/python/pyflyby/_py.py +++ b/lib/python/pyflyby/_py.py @@ -343,11 +343,11 @@ # TODO: add --profile, --runsnake import ast +import builtins from contextlib import contextmanager import inspect import os import re -import builtins import sys import types from types import FunctionType, MethodType @@ -1303,7 +1303,7 @@ def _has_python_shebang(filename): Note that this test is only needed for scripts found via which(), since otherwise the shebang is not necessary. """ - filename = Filename(filename) + assert isinstance(filename, Filename) with open(str(filename), 'rb') as f: line = f.readline(1024) return line.startswith(b"#!") and b"python" in line diff --git a/tests/test_autoimp.py b/tests/test_autoimp.py index 8c4fdf67..d5485c5d 100644 --- a/tests/test_autoimp.py +++ b/tests/test_autoimp.py @@ -17,9 +17,9 @@ auto_import, find_missing_imports) from pyflyby._autoimp import (LoadSymbolError, load_symbol, scan_for_import_issues) +from pyflyby._flags import CompilerFlags from pyflyby._idents import DottedIdentifier from pyflyby._importstmt import Import -from pyflyby._flags import CompilerFlags from pyflyby._util import CwdCtx @@ -46,7 +46,7 @@ def cleanup(): def writetext(filename, text, mode='w'): text = dedent(text) - filename = Filename(filename) + assert isinstance(filename, Filename) with open(str(filename), mode) as f: f.write(text) return filename @@ -2201,3 +2201,19 @@ def test_unsafe_filename_warning(tpp, capsys): [PYFLYBY] import pyflyby """).lstrip() assert out.startswith(expected) + + +def test_unsafe_filename_warning_II(tpp, capsys): + filepath = os.path.join(tpp._filename, "foo#bar") + os.mkdir(filepath) + filepath = os.path.join(filepath, "qux#baz") + os.mkdir(filepath) + with CwdCtx(filepath): + auto_import("pyflyby", [{}]) + out, _ = capsys.readouterr() + expected = dedent( + """ + [PYFLYBY] import pyflyby + """ + ).lstrip() + assert out.startswith(expected) diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 5d707ba3..b4714f93 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -29,6 +29,7 @@ import pyflyby from pyflyby._file import Filename from pyflyby._util import EnvVarCtx, cached_attribute, memoize +from typing import Union # To debug test_interactive.py itself, set the env var DEBUG_TEST_PYFLYBY. @@ -134,10 +135,9 @@ def new_tempfile(self, dir=None): return Filename(f) - -def writetext(filename, text, mode='w'): +def writetext(filename: Filename, text: str, mode: str = "w") -> Filename: + assert isinstance(filename, Filename) text = dedent(text) - filename = Filename(filename) with open(str(filename), mode) as f: f.write(text) return filename @@ -473,7 +473,7 @@ def _extra_readline_pythonpath_dirs(): return (dir,) -def _build_pythonpath(PYTHONPATH): +def _build_pythonpath(PYTHONPATH) -> str: """ Build PYTHONPATH value to use. @@ -483,16 +483,22 @@ def _build_pythonpath(PYTHONPATH): pypath = [os.path.dirname(os.path.dirname(pyflyby.__file__))] if sys.version_info < (3, 9): pypath += _extra_readline_pythonpath_dirs() - if isinstance(PYTHONPATH, (Filename, str)): + if isinstance(PYTHONPATH, Filename): + PYTHONPATH = [str(PYTHONPATH)] + if isinstance(PYTHONPATH, str): PYTHONPATH = [PYTHONPATH] + for p in PYTHONPATH: + assert isinstance(p, str) PYTHONPATH = [str(Filename(d)) for d in PYTHONPATH] pypath += PYTHONPATH pypath += os.environ["PYTHONPATH"].split(":") return ":".join(pypath) -def _init_ipython_dir(ipython_dir): - ipython_dir = Filename(ipython_dir) +def _init_ipython_dir(ipython_dir: Union[Filename, str]): + if isinstance(ipython_dir, str): + ipython_dir = Filename(ipython_dir) + assert isinstance(ipython_dir, Filename) if _IPYTHON_VERSION >= (0, 11): os.makedirs(str(ipython_dir/"profile_default")) os.makedirs(str(ipython_dir/"profile_default/startup")) @@ -538,12 +544,14 @@ def _init_ipython_dir(ipython_dir): writetext(ipython_dir/"ipy_user_conf.py", "") -def _build_ipython_cmd(ipython_dir, prog="ipython", args=[], autocall=False, frontend=None): +def _build_ipython_cmd( + ipython_dir: Filename, prog="ipython", args=[], autocall=False, frontend=None +): """ Prepare the command to run IPython. """ python = sys.executable - ipython_dir = Filename(ipython_dir) + assert isinstance(ipython_dir, Filename) cmd = [python] if prog == "ipython" and _IPYTHON_VERSION >= (4,) and args and args[0] in ["console", "notebook"]: prog = "jupyter" @@ -834,7 +842,8 @@ def IPythonCtx(prog="ipython", __tracebackhide__ = True if hasattr(PYFLYBY_PATH, "write"): PYFLYBY_PATH = PYFLYBY_PATH.name - PYFLYBY_PATH = str(Filename(PYFLYBY_PATH)) + assert isinstance(PYFLYBY_PATH, Filename) + PYFLYBY_PATH = str(PYFLYBY_PATH) cleanup_dirs = [] # Create a temporary directory which we'll use as our IPYTHONDIR. if not ipython_dir: @@ -854,13 +863,16 @@ def IPythonCtx(prog="ipython", env = {} env["PYFLYBY_PATH"] = PYFLYBY_PATH env["PYFLYBY_LOG_LEVEL"] = PYFLYBY_LOG_LEVEL - env["PYTHONPATH"] = _build_pythonpath(PYTHONPATH) - env["PYTHONSTARTUP"] = "" - env["MPLCONFIGDIR"] = mplconfigdir - env["PATH"] = str(PYFLYBY_BIN.real) + os.path.pathsep + os.environ["PATH"] - env["PYTHONBREAKPOINT"] = "IPython.terminal.debugger.set_trace" - cmd = _build_ipython_cmd(ipython_dir, prog, args, autocall=autocall, - frontend=frontend) + env["PYTHONPATH"] = _build_pythonpath(PYTHONPATH) + env["PYTHONSTARTUP"] = "" + env["MPLCONFIGDIR"] = mplconfigdir + env["PATH"] = str(PYFLYBY_BIN.real) + os.path.pathsep + os.environ["PATH"] + env["PYTHONBREAKPOINT"] = "IPython.terminal.debugger.set_trace" + if isinstance(ipython_dir, str): + ipython_dir = Filename(ipython_dir) + cmd = _build_ipython_cmd( + ipython_dir, prog, args, autocall=autocall, frontend=frontend + ) # Spawn IPython. with EnvVarCtx(**env): print("# Spawning: %s" % (" ".join(cmd))) diff --git a/tests/test_livepatch.py b/tests/test_livepatch.py index aee226b4..cd9c2072 100644 --- a/tests/test_livepatch.py +++ b/tests/test_livepatch.py @@ -35,7 +35,7 @@ def cleanup(): def writetext(filename, text, mode='w'): text = dedent(text) - filename = Filename(filename) + assert isinstance(filename, Filename) with open(str(filename), mode) as f: f.write(text) return filename diff --git a/tests/test_py.py b/tests/test_py.py index 29a24d30..56de1fc2 100644 --- a/tests/test_py.py +++ b/tests/test_py.py @@ -14,11 +14,12 @@ from tempfile import NamedTemporaryFile, mkdtemp from textwrap import dedent -import pyflyby from pyflyby._file import Filename from pyflyby._util import cached_attribute +from tests.test_interactive import _build_pythonpath + PYFLYBY_HOME = Filename(__file__).real.dir.dir BIN_DIR = PYFLYBY_HOME / "bin" PYFLYBY_PATH = PYFLYBY_HOME / "etc/pyflyby" @@ -89,28 +90,15 @@ def new_tempdir(self): return Filename(d).real -def _build_pythonpath(PYTHONPATH): - """ - Build PYTHONPATH value to use. - - :rtype: - ``str`` - """ - pypath = [os.path.dirname(os.path.dirname(pyflyby.__file__))] - if isinstance(PYTHONPATH, (Filename, str)): - PYTHONPATH = [PYTHONPATH] - PYTHONPATH = [str(Filename(d)) for d in PYTHONPATH] - pypath += PYTHONPATH - pypath += os.environ["PYTHONPATH"].split(":") - return ":".join(pypath) - def _py_internal_1(args, stdin="", PYTHONPATH=[], PYFLYBY_PATH=PYFLYBY_PATH): env = dict(os.environ) - env["PYFLYBY_PATH" ] = str(Filename(PYFLYBY_PATH)) - env["PYTHONPATH" ] = _build_pythonpath(PYTHONPATH) + if isinstance(PYFLYBY_PATH, str): + PYFLYBY_PATH = Filename(PYFLYBY_PATH) + env["PYFLYBY_PATH"] = str(PYFLYBY_PATH) + env["PYTHONPATH"] = _build_pythonpath(PYTHONPATH) env["PYTHONSTARTUP"] = "" prog = str(BIN_DIR/"py") return pipe((prog,) + args, stdin=stdin, env=env) @@ -127,7 +115,7 @@ def py(*args, **kwargs): def writetext(filename, text, mode='w'): text = dedent(text) - filename = Filename(filename) + assert isinstance(filename, Filename) with open(str(filename), mode) as f: f.write(text) return filename