Skip to content

Commit

Permalink
MAINT: be stricter on str/Filename
Browse files Browse the repository at this point in the history
There are a bit too many places where there is a str/Filename Union.
This make it hard to track issues like #346.

Thus we do some refactor and enforce Filename in many places in order to
make reasonning on the code easier.
  • Loading branch information
Carreau committed Jul 16, 2024
1 parent 6e1f2ab commit e415af6
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 74 deletions.
9 changes: 6 additions & 3 deletions lib/python/pyflyby/_cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@



from builtins import input
import optparse
import os
import signal
from builtins import input
import sys
from textwrap import dedent
import traceback
Expand Down Expand Up @@ -285,15 +285,18 @@ def unaligned_callback(option, opt_str, value, parser):
def _default_on_error(filename):
raise SystemExit("bad filename %s" % (filename,))

def filename_args(args, on_error=_default_on_error):

def filename_args(args: List[str], on_error=_default_on_error):
"""
Return list of filenames given command-line arguments.
:rtype:
``list`` of `Filename`
"""
if args:
return expand_py_files_from_args(args, on_error)
for a in args:
assert isinstance(a, str)
return expand_py_files_from_args([Filename(f) for f in args], on_error)
elif not os.isatty(0):
return [Filename.STDIN]
else:
Expand Down
53 changes: 35 additions & 18 deletions lib/python/pyflyby/_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
# License: MIT http://opensource.org/licenses/MIT
from __future__ import annotations

from functools import total_ordering, cached_property
from functools import cached_property, total_ordering
import io
import os
import re
import sys
from typing import Optional, Tuple, ClassVar
from typing import ClassVar, List, Optional, Tuple, Union

from pyflyby._util import cmp, memoize

if sys.version_info < (3,10):
NoneType = type(None)
else:
from types import NoneType


class UnsafeFilenameError(ValueError):
pass
Expand All @@ -33,20 +38,23 @@ class Filename(object):

def __new__(cls, arg):
if isinstance(arg, cls):
return arg
# TODO make this assert False
return cls._from_filename(arg._filename)
if isinstance(arg, str):
return cls._from_filename(arg)
raise TypeError

@classmethod
def _from_filename(cls, filename):
def _from_filename(cls, filename: str):
if not isinstance(filename, str):
raise TypeError
filename = str(filename)
if not filename:
raise UnsafeFilenameError("(empty string)")
if re.search("[^a-zA-Z0-9_=+{}/.,~@-]", filename):
raise UnsafeFilenameError(filename)
# we only allow filename with given character set
match = re.search("[^a-zA-Z0-9_=+{}/.,~@-]", filename)
if match:
raise UnsafeFilenameError((filename, match))
if re.search("(^|/)~", filename):
raise UnsafeFilenameError(filename)
self = object.__new__(cls)
Expand Down Expand Up @@ -373,6 +381,8 @@ def __new__(cls, arg, filename=None, startpos=None):
:rtype:
``FileText``
"""
if isinstance(filename, str):
filename = Filename(filename)
if isinstance(arg, cls):
if filename is startpos is None:
return arg
Expand All @@ -387,8 +397,8 @@ def __new__(cls, arg, filename=None, startpos=None):
else:
raise TypeError("%s: unexpected %s"
% (cls.__name__, type(arg).__name__))
if filename is not None:
filename = Filename(filename)

assert isinstance(filename, (Filename, NoneType))
startpos = FilePos(startpos)
self.filename = filename
self.startpos = startpos
Expand Down Expand Up @@ -439,9 +449,9 @@ def joined(self) -> str:
def from_filename(cls, filename):
return cls.from_lines(Filename(filename))

def alter(self, filename=None, startpos=None):
def alter(self, filename: Optional[Filename] = None, startpos=None):
if filename is not None:
filename = Filename(filename)
assert isinstance(filename, Filename)
else:
filename = self.filename
if startpos is not None:
Expand Down Expand Up @@ -652,23 +662,24 @@ def __hash__(self):
return h


def read_file(filename):
filename = Filename(filename)
def read_file(filename: Filename) -> FileText:
assert isinstance(filename, Filename)
if filename == Filename.STDIN:
data = sys.stdin.read()
else:
with io.open(str(filename), 'r') as f:
data = f.read()
return FileText(data, filename=filename)

def write_file(filename, data):
filename = Filename(filename)

def write_file(filename: Filename, data):
assert isinstance(filename, Filename)
data = FileText(data)
with open(str(filename), 'w') as f:
f.write(data.joined)

def atomic_write_file(filename, data):
filename = Filename(filename)
def atomic_write_file(filename: Filename, data):
assert isinstance(filename, Filename)
data = FileText(data)
temp_filename = Filename("%s.tmp.%s" % (filename, os.getpid(),))
write_file(temp_filename, data)
Expand All @@ -680,7 +691,10 @@ def atomic_write_file(filename, data):
pass
os.rename(str(temp_filename), str(filename))

def expand_py_files_from_args(pathnames, on_error=lambda filename: None):

def expand_py_files_from_args(
pathnames: Union[List[Filename], Filename], on_error=lambda filename: None
):
"""
Enumerate ``*.py`` files, recursively.
Expand All @@ -698,8 +712,11 @@ def expand_py_files_from_args(pathnames, on_error=lambda filename: None):
``list`` of `Filename` s
"""
if not isinstance(pathnames, (tuple, list)):
# July 2024 DeprecationWarning
# this seem to be used only internally, maybe deprecate not passing a list.
pathnames = [pathnames]
pathnames = [Filename(f) for f in pathnames]
for f in pathnames:
assert isinstance(f, Filename)
result = []
# Check for problematic arguments. Note that we intentionally only do
# this for directly specified arguments, not for recursively traversed
Expand Down
34 changes: 22 additions & 12 deletions lib/python/pyflyby/_importdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import re
import sys

from typing import Dict, Any, Tuple
from typing import Any, Dict, Tuple, Union, List


from pyflyby._file import Filename, expand_py_files_from_args, UnsafeFilenameError
from pyflyby._file import (Filename, UnsafeFilenameError,
expand_py_files_from_args)
from pyflyby._idents import dotted_prefixes
from pyflyby._importclns import ImportMap, ImportSet
from pyflyby._importstmt import Import, ImportStatement
Expand Down Expand Up @@ -90,8 +90,11 @@ def _get_python_path(env_var_name, default_path, target_dirname):
.format(env_var_name=env_var_name, p=p))
pathnames = [os.path.expanduser(p) for p in pathnames]
pathnames = _expand_tripledots(pathnames, target_dirname)
pathnames = [Filename(fn) for fn in pathnames]
for fn in pathnames:
assert isinstance(fn, Filename)
pathnames = stable_unique(pathnames)
for p in pathnames:
assert isinstance(p, Filename)
pathnames = expand_py_files_from_args(pathnames)
if not pathnames:
logger.warning(
Expand All @@ -103,8 +106,8 @@ def _get_python_path(env_var_name, default_path, target_dirname):
# TODO: stop memoizing here after using StatCache. Actually just inline into
# _ancestors_on_same_partition
@memoize
def _get_st_dev(filename):
filename = Filename(filename)
def _get_st_dev(filename: Filename):
assert isinstance(filename, Filename)
try:
return os.stat(str(filename)).st_dev
except OSError:
Expand Down Expand Up @@ -159,7 +162,7 @@ def _expand_tripledots(pathnames, target_dirname):
:rtype:
``list`` of `Filename`
"""
target_dirname = Filename(target_dirname)
assert isinstance(target_dirname, Filename)
if not isinstance(pathnames, (tuple, list)):
pathnames = [pathnames]
result = []
Expand Down Expand Up @@ -238,7 +241,7 @@ def clear_default_cache(cls):
cls._default_cache.clear()

@classmethod
def get_default(cls, target_filename):
def get_default(cls, target_filename: Union[Filename, str]):
"""
Return the default import library for the given target filename.
Expand All @@ -263,11 +266,15 @@ def get_default(cls, target_filename):
# TODO: Consider refreshing periodically. Check if files have
# been touched, and if so, return new data. Check file timestamps at
# most once every 60 seconds.
cache_keys = []
target_filename = Filename(target_filename or ".")
cache_keys:List[Tuple[Any,...]] = []
if target_filename:
if isinstance(target_filename, str):
target_filename = Filename(target_filename)
target_filename = target_filename or Filename(".")
assert isinstance(target_filename, Filename)
if target_filename.startswith("/dev"):
target_filename = Filename(".")
target_dirname = target_filename
target_dirname:Filename = target_filename
# TODO: with StatCache
while True:
cache_keys.append((1,
Expand Down Expand Up @@ -493,8 +500,11 @@ def _from_filenames(cls, filenames, _mandatory_filenames_deprecated=[]):
`ImportDB`
"""
if not isinstance(filenames, (tuple, list)):
# TODO DeprecationWarning July 2024,
# this is internal deprecate not passing a list;
filenames = [filenames]
filenames = [Filename(f) for f in filenames]
for f in filenames:
assert isinstance(f, Filename)
logger.debug("ImportDB: loading [%s], mandatory=[%s]",
', '.join(map(str, filenames)),
', '.join(map(str, _mandatory_filenames_deprecated)))
Expand Down
7 changes: 7 additions & 0 deletions lib/python/pyflyby/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
from typing import Any, List, Optional, Tuple, Union, cast
import warnings


_sentinel = object()

if sys.version_info < (3, 10):

NoneType = type(None)

class MatchAs:
name: str
pattern: ast.AST
Expand All @@ -36,6 +39,7 @@ class MatchMapping:
patterns: List[MatchAs]

else:
from types import NoneType
from ast import MatchAs, MatchMapping


Expand Down Expand Up @@ -963,6 +967,9 @@ def from_text(cls, text, filename=None, startpos=None, flags=None,
:rtype:
`PythonBlock`
"""
if isinstance(filename, str):
filename = Filename(filename)
assert isinstance(filename, (Filename, NoneType)), filename
text = FileText(text, filename=filename, startpos=startpos)
self = object.__new__(cls)
self.text = text
Expand Down
4 changes: 2 additions & 2 deletions lib/python/pyflyby/_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,11 +343,11 @@
# TODO: add --profile, --runsnake

import ast
import builtins
from contextlib import contextmanager
import inspect
import os
import re
import builtins
import sys
import types
from types import FunctionType, MethodType
Expand Down Expand Up @@ -1303,7 +1303,7 @@ def _has_python_shebang(filename):
Note that this test is only needed for scripts found via which(), since
otherwise the shebang is not necessary.
"""
filename = Filename(filename)
assert isinstance(filename, Filename)
with open(str(filename), 'rb') as f:
line = f.readline(1024)
return line.startswith(b"#!") and b"python" in line
Expand Down
20 changes: 18 additions & 2 deletions tests/test_autoimp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
auto_import, find_missing_imports)
from pyflyby._autoimp import (LoadSymbolError, load_symbol,
scan_for_import_issues)
from pyflyby._flags import CompilerFlags
from pyflyby._idents import DottedIdentifier
from pyflyby._importstmt import Import
from pyflyby._flags import CompilerFlags
from pyflyby._util import CwdCtx


Expand All @@ -46,7 +46,7 @@ def cleanup():

def writetext(filename, text, mode='w'):
text = dedent(text)
filename = Filename(filename)
assert isinstance(filename, Filename)
with open(str(filename), mode) as f:
f.write(text)
return filename
Expand Down Expand Up @@ -2201,3 +2201,19 @@ def test_unsafe_filename_warning(tpp, capsys):
[PYFLYBY] import pyflyby
""").lstrip()
assert out.startswith(expected)


def test_unsafe_filename_warning_II(tpp, capsys):
filepath = os.path.join(tpp._filename, "foo#bar")
os.mkdir(filepath)
filepath = os.path.join(filepath, "qux#baz")
os.mkdir(filepath)
with CwdCtx(filepath):
auto_import("pyflyby", [{}])
out, _ = capsys.readouterr()
expected = dedent(
"""
[PYFLYBY] import pyflyby
"""
).lstrip()
assert out.startswith(expected)
Loading

0 comments on commit e415af6

Please sign in to comment.