Skip to content

Commit

Permalink
Merge pull request #269 from aktech/forward-ref-fix
Browse files Browse the repository at this point in the history
Fix importing issues with forward references
  • Loading branch information
Carreau authored Dec 5, 2023
2 parents a60a494 + dbd1a2f commit 78d68d5
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 12 deletions.
9 changes: 6 additions & 3 deletions lib/python/pyflyby/_autoimp.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,14 +878,17 @@ def _visit_Store(self, fullname, value=None):
def _remove_from_missing_imports(self, fullname):
for missing_import in self.missing_imports:
# If it was defined inside a class method, then it wouldn't have been added to
# the missing imports anyways.
# the missing imports anyways (except in that case of annotations)
# See the following tests:
# - tests.test_autoimp.test_method_reference_current_class
# - tests.test_autoimp.test_find_missing_imports_class_name_1
# - tests.test_autoimp.test_scan_for_import_issues_class_defined_after_use
scopestack = missing_import[1].scope_info['scopestack']
in_class_scope = isinstance(scopestack[-1], _ClassScope)
inside_class = missing_import[1].scope_info.get('_in_class_def')
if missing_import[1].startswith(fullname) and not inside_class:
self.missing_imports.remove(missing_import)
if missing_import[1].startswith(fullname):
if in_class_scope or not inside_class:
self.missing_imports.remove(missing_import)

def _get_scope_info(self):
return {
Expand Down
54 changes: 46 additions & 8 deletions tests/test_autoimp.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,20 +521,58 @@ def foo():
assert unused == []


def test_annotation_inside_class():
code = dedent(
"""
class A:
param1: str
param2: B
class B:
param1: str
"""
)
missing, unused = scan_for_import_issues(code, [{}])
assert missing == []
assert unused == []


@pytest.mark.xfail(
reason="Had to deactivate as part of https://github.com/deshaw/pyflyby/pull/269/files conflicting requirements"
)
def test_find_missing_imports_class_name_1():
code = dedent(
"""
class Corinne(object):
class Corinne:
pass
class Bobtail(object):
class Chippewa(object):
Bobtail
class Bobtail:
class Chippewa:
Bobtail # will be name error at runtime
Rockton = Passall, Corinne, Chippewa
""")
result = find_missing_imports(code, [{}])
result = _dilist2strlist(result)
expected = ['Bobtail', 'Passall']
# ^error, ^ok , ^ok
"""
)
result = find_missing_imports(code, [{}])
result = _dilist2strlist(result)
expected = ["Bobtail", "Passall"]
assert expected == result


def test_find_missing_imports_class_name_1b():
code = dedent(
"""
class Corinne:
pass
class Bobtail:
class Chippewa:
Bobtail # will be name error at runtime
Rockton = Passall, Corinne, Chippewa
# ^error, ^ok , ^ok
"""
)
result = find_missing_imports(code, [{}])
result = _dilist2strlist(result)
expected = ["Passall"]
assert expected == result


Expand Down
48 changes: 47 additions & 1 deletion tests/test_cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import tempfile
from textwrap import dedent

from pyflyby._util import EnvVarCtx
from pyflyby._util import EnvVarCtx, CwdCtx

import pytest

Expand Down Expand Up @@ -764,3 +764,49 @@ def test_tidy_imports_sorting():
sympy
""").strip().format(f=f)
assert result == expected


def test_tidy_imports_forward_references():
with tempfile.TemporaryDirectory() as temp_dir:
foo = os.path.join(temp_dir, "foo.py")
with open(foo, "w") as foo_fp:
foo_fp.write(dedent("""
from __future__ import annotations
class A:
param1: str
param2: B
class B:
param1: str
""").lstrip())
foo_fp.flush()

dot_pyflyby = os.path.join(temp_dir, ".pyflyby")
with open(dot_pyflyby, "w") as dot_pyflyby_fp:
dot_pyflyby_fp.write(dedent("""
from foo import A, B
""").lstrip())
dot_pyflyby_fp.flush()
with CwdCtx(temp_dir):
result = pipe(
[BIN_DIR + "/tidy-imports", foo_fp.name],
env={"PYFLYBY_PATH": dot_pyflyby},
)

expected = dedent(
"""
from __future__ import annotations
class A:
param1: str
param2: B
class B:
param1: str
"""
).strip()
assert result == expected

0 comments on commit 78d68d5

Please sign in to comment.