Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: be stricter on str/Filename #350

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions lib/python/pyflyby/_cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@



from builtins import input
import optparse
import os
import signal
from builtins import input
import sys
from textwrap import dedent
import traceback
Expand Down Expand Up @@ -285,15 +285,18 @@ def unaligned_callback(option, opt_str, value, parser):
def _default_on_error(filename):
raise SystemExit("bad filename %s" % (filename,))

def filename_args(args, on_error=_default_on_error):

def filename_args(args: List[str], on_error=_default_on_error):
"""
Return list of filenames given command-line arguments.

:rtype:
``list`` of `Filename`
"""
if args:
return expand_py_files_from_args(args, on_error)
for a in args:
assert isinstance(a, str)
return expand_py_files_from_args([Filename(f) for f in args], on_error)
elif not os.isatty(0):
return [Filename.STDIN]
else:
Expand Down
53 changes: 35 additions & 18 deletions lib/python/pyflyby/_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
# License: MIT http://opensource.org/licenses/MIT
from __future__ import annotations

from functools import total_ordering, cached_property
from functools import cached_property, total_ordering
import io
import os
import re
import sys
from typing import Optional, Tuple, ClassVar
from typing import ClassVar, List, Optional, Tuple, Union

from pyflyby._util import cmp, memoize

if sys.version_info < (3,10):
NoneType = type(None)
else:
from types import NoneType


class UnsafeFilenameError(ValueError):
pass
Expand All @@ -33,20 +38,23 @@ class Filename(object):

def __new__(cls, arg):
if isinstance(arg, cls):
return arg
# TODO make this assert False
return cls._from_filename(arg._filename)
if isinstance(arg, str):
return cls._from_filename(arg)
raise TypeError

@classmethod
def _from_filename(cls, filename):
def _from_filename(cls, filename: str):
if not isinstance(filename, str):
raise TypeError
filename = str(filename)
if not filename:
raise UnsafeFilenameError("(empty string)")
if re.search("[^a-zA-Z0-9_=+{}/.,~@-]", filename):
raise UnsafeFilenameError(filename)
# we only allow filename with given character set
match = re.search("[^a-zA-Z0-9_=+{}/.,~@-]", filename)
if match:
raise UnsafeFilenameError((filename, match))
if re.search("(^|/)~", filename):
raise UnsafeFilenameError(filename)
self = object.__new__(cls)
Expand Down Expand Up @@ -373,6 +381,8 @@ def __new__(cls, arg, filename=None, startpos=None):
:rtype:
``FileText``
"""
if isinstance(filename, str):
filename = Filename(filename)
if isinstance(arg, cls):
if filename is startpos is None:
return arg
Expand All @@ -387,8 +397,8 @@ def __new__(cls, arg, filename=None, startpos=None):
else:
raise TypeError("%s: unexpected %s"
% (cls.__name__, type(arg).__name__))
if filename is not None:
filename = Filename(filename)

assert isinstance(filename, (Filename, NoneType))
startpos = FilePos(startpos)
self.filename = filename
self.startpos = startpos
Expand Down Expand Up @@ -439,9 +449,9 @@ def joined(self) -> str:
def from_filename(cls, filename):
return cls.from_lines(Filename(filename))

def alter(self, filename=None, startpos=None):
def alter(self, filename: Optional[Filename] = None, startpos=None):
if filename is not None:
filename = Filename(filename)
assert isinstance(filename, Filename)
else:
filename = self.filename
if startpos is not None:
Expand Down Expand Up @@ -652,23 +662,24 @@ def __hash__(self):
return h


def read_file(filename):
filename = Filename(filename)
def read_file(filename: Filename) -> FileText:
assert isinstance(filename, Filename)
if filename == Filename.STDIN:
data = sys.stdin.read()
else:
with io.open(str(filename), 'r') as f:
data = f.read()
return FileText(data, filename=filename)

def write_file(filename, data):
filename = Filename(filename)

def write_file(filename: Filename, data):
assert isinstance(filename, Filename)
data = FileText(data)
with open(str(filename), 'w') as f:
f.write(data.joined)

def atomic_write_file(filename, data):
filename = Filename(filename)
def atomic_write_file(filename: Filename, data):
assert isinstance(filename, Filename)
data = FileText(data)
temp_filename = Filename("%s.tmp.%s" % (filename, os.getpid(),))
write_file(temp_filename, data)
Expand All @@ -680,7 +691,10 @@ def atomic_write_file(filename, data):
pass
os.rename(str(temp_filename), str(filename))

def expand_py_files_from_args(pathnames, on_error=lambda filename: None):

def expand_py_files_from_args(
pathnames: Union[List[Filename], Filename], on_error=lambda filename: None
):
"""
Enumerate ``*.py`` files, recursively.

Expand All @@ -698,8 +712,11 @@ def expand_py_files_from_args(pathnames, on_error=lambda filename: None):
``list`` of `Filename` s
"""
if not isinstance(pathnames, (tuple, list)):
# July 2024 DeprecationWarning
# this seem to be used only internally, maybe deprecate not passing a list.
pathnames = [pathnames]
pathnames = [Filename(f) for f in pathnames]
for f in pathnames:
assert isinstance(f, Filename)
result = []
# Check for problematic arguments. Note that we intentionally only do
# this for directly specified arguments, not for recursively traversed
Expand Down
34 changes: 22 additions & 12 deletions lib/python/pyflyby/_importdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import re
import sys

from typing import Dict, Any, Tuple
from typing import Any, Dict, Tuple, Union, List


from pyflyby._file import Filename, expand_py_files_from_args, UnsafeFilenameError
from pyflyby._file import (Filename, UnsafeFilenameError,
expand_py_files_from_args)
from pyflyby._idents import dotted_prefixes
from pyflyby._importclns import ImportMap, ImportSet
from pyflyby._importstmt import Import, ImportStatement
Expand Down Expand Up @@ -90,8 +90,11 @@ def _get_python_path(env_var_name, default_path, target_dirname):
.format(env_var_name=env_var_name, p=p))
pathnames = [os.path.expanduser(p) for p in pathnames]
pathnames = _expand_tripledots(pathnames, target_dirname)
pathnames = [Filename(fn) for fn in pathnames]
for fn in pathnames:
assert isinstance(fn, Filename)
pathnames = stable_unique(pathnames)
for p in pathnames:
assert isinstance(p, Filename)
pathnames = expand_py_files_from_args(pathnames)
if not pathnames:
logger.warning(
Expand All @@ -103,8 +106,8 @@ def _get_python_path(env_var_name, default_path, target_dirname):
# TODO: stop memoizing here after using StatCache. Actually just inline into
# _ancestors_on_same_partition
@memoize
def _get_st_dev(filename):
filename = Filename(filename)
def _get_st_dev(filename: Filename):
assert isinstance(filename, Filename)
try:
return os.stat(str(filename)).st_dev
except OSError:
Expand Down Expand Up @@ -159,7 +162,7 @@ def _expand_tripledots(pathnames, target_dirname):
:rtype:
``list`` of `Filename`
"""
target_dirname = Filename(target_dirname)
assert isinstance(target_dirname, Filename)
if not isinstance(pathnames, (tuple, list)):
pathnames = [pathnames]
result = []
Expand Down Expand Up @@ -238,7 +241,7 @@ def clear_default_cache(cls):
cls._default_cache.clear()

@classmethod
def get_default(cls, target_filename):
def get_default(cls, target_filename: Union[Filename, str]):
"""
Return the default import library for the given target filename.

Expand All @@ -263,11 +266,15 @@ def get_default(cls, target_filename):
# TODO: Consider refreshing periodically. Check if files have
# been touched, and if so, return new data. Check file timestamps at
# most once every 60 seconds.
cache_keys = []
target_filename = Filename(target_filename or ".")
cache_keys:List[Tuple[Any,...]] = []
if target_filename:
if isinstance(target_filename, str):
target_filename = Filename(target_filename)
target_filename = target_filename or Filename(".")
assert isinstance(target_filename, Filename)
if target_filename.startswith("/dev"):
target_filename = Filename(".")
target_dirname = target_filename
target_dirname:Filename = target_filename
# TODO: with StatCache
while True:
cache_keys.append((1,
Expand Down Expand Up @@ -493,8 +500,11 @@ def _from_filenames(cls, filenames, _mandatory_filenames_deprecated=[]):
`ImportDB`
"""
if not isinstance(filenames, (tuple, list)):
# TODO DeprecationWarning July 2024,
# this is internal deprecate not passing a list;
filenames = [filenames]
filenames = [Filename(f) for f in filenames]
for f in filenames:
assert isinstance(f, Filename)
logger.debug("ImportDB: loading [%s], mandatory=[%s]",
', '.join(map(str, filenames)),
', '.join(map(str, _mandatory_filenames_deprecated)))
Expand Down
7 changes: 7 additions & 0 deletions lib/python/pyflyby/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
from typing import Any, List, Optional, Tuple, Union, cast
import warnings


_sentinel = object()

if sys.version_info < (3, 10):

NoneType = type(None)

class MatchAs:
name: str
pattern: ast.AST
Expand All @@ -36,6 +39,7 @@ class MatchMapping:
patterns: List[MatchAs]

else:
from types import NoneType
from ast import MatchAs, MatchMapping


Expand Down Expand Up @@ -963,6 +967,9 @@ def from_text(cls, text, filename=None, startpos=None, flags=None,
:rtype:
`PythonBlock`
"""
if isinstance(filename, str):
filename = Filename(filename)
assert isinstance(filename, (Filename, NoneType)), filename
text = FileText(text, filename=filename, startpos=startpos)
self = object.__new__(cls)
self.text = text
Expand Down
4 changes: 2 additions & 2 deletions lib/python/pyflyby/_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,11 +343,11 @@
# TODO: add --profile, --runsnake

import ast
import builtins
from contextlib import contextmanager
import inspect
import os
import re
import builtins
import sys
import types
from types import FunctionType, MethodType
Expand Down Expand Up @@ -1303,7 +1303,7 @@ def _has_python_shebang(filename):
Note that this test is only needed for scripts found via which(), since
otherwise the shebang is not necessary.
"""
filename = Filename(filename)
assert isinstance(filename, Filename)
with open(str(filename), 'rb') as f:
line = f.readline(1024)
return line.startswith(b"#!") and b"python" in line
Expand Down
20 changes: 18 additions & 2 deletions tests/test_autoimp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
auto_import, find_missing_imports)
from pyflyby._autoimp import (LoadSymbolError, load_symbol,
scan_for_import_issues)
from pyflyby._flags import CompilerFlags
from pyflyby._idents import DottedIdentifier
from pyflyby._importstmt import Import
from pyflyby._flags import CompilerFlags
from pyflyby._util import CwdCtx


Expand All @@ -46,7 +46,7 @@ def cleanup():

def writetext(filename, text, mode='w'):
text = dedent(text)
filename = Filename(filename)
assert isinstance(filename, Filename)
with open(str(filename), mode) as f:
f.write(text)
return filename
Expand Down Expand Up @@ -2201,3 +2201,19 @@ def test_unsafe_filename_warning(tpp, capsys):
[PYFLYBY] import pyflyby
""").lstrip()
assert out.startswith(expected)


def test_unsafe_filename_warning_II(tpp, capsys):
filepath = os.path.join(tpp._filename, "foo#bar")
os.mkdir(filepath)
filepath = os.path.join(filepath, "qux#baz")
os.mkdir(filepath)
with CwdCtx(filepath):
auto_import("pyflyby", [{}])
out, _ = capsys.readouterr()
expected = dedent(
"""
[PYFLYBY] import pyflyby
"""
).lstrip()
assert out.startswith(expected)
Loading
Loading