Skip to content

Commit

Permalink
Try to fix false positive undefined variable in match cases.
Browse files Browse the repository at this point in the history
That won't be perfect as the way pyflyback work the scope management is
a bit complex.
  • Loading branch information
Carreau committed Jul 8, 2024
1 parent 143fc01 commit 27bfad5
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 8 deletions.
47 changes: 43 additions & 4 deletions lib/python/pyflyby/_autoimp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pyflyby._log import logger
from pyflyby._modules import ModuleHandle
from pyflyby._parse import (PythonBlock, _is_ast_str,
infer_compile_mode)
infer_compile_mode, MatchAs)

from six import reraise
import sys
Expand Down Expand Up @@ -491,8 +491,11 @@ def visit(self, node):
logger.debug(
"_MissingImportFinder has no method %r, using generic_visit", method
)

visitor = getattr(self, method, self.generic_visit)
if hasattr(self, method):
visitor = getattr(self, method)
else:
logger.debug("No method `%s`, using `generic_visit`", method)
visitor = self.generic_visit
return visitor(node)
else:
raise TypeError("unexpected %s" % (type(node).__name__,))
Expand Down Expand Up @@ -827,6 +830,42 @@ def visit_alias(self, node, modulename=None):
self._visit_StoreImport(node, modulename)
self.generic_visit(node)

def visit_match_case(self, node:ast.match_case):
logger.debug("visit_match_case(%r)", node)
return self.generic_visit(node)

def visit_Call(self, node:ast.Call):
logger.debug("visit_Call(%r)", node)
return self.generic_visit(node)

def visit_Pass(self, node:ast.Pass):
logger.debug("visit_Pass(%r)", node)
return self.generic_visit(node)

def visit_Constant(self, node:ast.Constant):
logger.debug("visit_Constant(%r)", node)
return self.generic_visit(node)

def visit_Module(self, node:ast.Module):
logger.debug("visit_Module(%r)", node)
return self.generic_visit(node)

def visit_Match(self, node:ast.Match):
logger.debug("visit_Match(%r)", node)
return self.generic_visit(node)

def visit_MatchMapping(self, node:ast.MatchMapping):
logger.debug("visit_MatchMapping(%r)", node)
return self.generic_visit(node)

def visit_Expr(self, node:ast.Expr):
logger.debug("visit_Expr(%r)", node)
return self.generic_visit(node)

def visit_MatchAs(self, node:MatchAs):
logger.debug("visit_MatchAs(%r)", node)
return self._visit_Store(node.name)

def visit_Name(self, node):
logger.debug("visit_Name(%r)", node.id)
self._visit_fullname(node.id, node.ctx)
Expand Down Expand Up @@ -884,7 +923,7 @@ def _visit_StoreImport(self, node, modulename):
value = _UseChecker(name, imp, self._lineno)
self._visit_Store(name, value)

def _visit_Store(self, fullname, value=None):
def _visit_Store(self, fullname:str, value=None):
logger.debug("_visit_Store(%r)", fullname)
if fullname is None:
return
Expand Down
1 change: 1 addition & 0 deletions lib/python/pyflyby/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def _iter_child_nodes_in_order_internal_1(node):
yield node.name,
elif isinstance(node, MatchMapping):
for k, p in zip(node.keys, node.patterns):
pass
yield k, p
else:
# Default behavior.
Expand Down
49 changes: 46 additions & 3 deletions tests/test_autoimp.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def foo(self):
assert expected == result


@pytest.mark.xfail
@pytest.mark.xfail(strict=True)
def test_find_missing_import_xfail_after_pr_152():
code = dedent(
"""
Expand Down Expand Up @@ -538,7 +538,8 @@ class B:


@pytest.mark.xfail(
reason="Had to deactivate as part of https://github.com/deshaw/pyflyby/pull/269/files conflicting requirements"
reason="Had to deactivate as part of https://github.com/deshaw/pyflyby/pull/269/files conflicting requirements",
strict=True,
)
def test_find_missing_imports_class_name_1():
code = dedent(
Expand Down Expand Up @@ -1146,6 +1147,48 @@ def test_find_missing_imports_true_false_none_1():
assert expected == result


@pytest.mark.skipif(sys.version_info < (3, 10), reason='No pattern matching before 3.10')
def test_find_missing_imports_pattern_match_1():
code = dedent("""
match {"foo": 1, "bar": 2}:
case {
"foo": the_foo_value,
"bar": the_bar_value,
**rest,
}:
print(the_foo_value)
case _:
pass
""")
result = find_missing_imports(code, [{}])
result = _dilist2strlist(result)
expected = []
assert expected == result

@pytest.mark.xfail(reason='''The way the scope work in pyflyby it is hard to define a variable...
only in one case I believe. We would need a scope stack in `def
visit_match_case`, but that would remove the variable definition
when leaving the match statement.
''',strict=True)
@pytest.mark.skipif(sys.version_info < (3, 10), reason='No pattern matching before 3.10')
def test_find_missing_imports_pattern_match_2():
code = dedent("""
match {"foo": 1, "bar": 2}:
case {
"foo": the_foo_value,
"bar": the_bar_value,
**rest,
}:
print(the_foo_value)
case _:
print('here the_x_value might be unknown', the_foo_value)
""")
result = find_missing_imports(code, [{}])
result = _dilist2strlist(result)
expected = [DottedIdentifier('the_foo_value')]
assert expected == result


def test_find_missing_imports_matmul_1():
code = dedent("""
a@b
Expand Down Expand Up @@ -1572,7 +1615,7 @@ def test_scan_for_import_issues_comprehension_attribute_1():
assert unused == []


@pytest.mark.xfail
@pytest.mark.xfail(strict=True)
def test_scan_for_import_issues_comprehension_attribute_missing_1():
code = dedent("""
[123 for xx.yy in []]
Expand Down
13 changes: 12 additions & 1 deletion tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,18 @@ def http_error(status):
case _:
return "Something's wrong with the Internet"
"""
""",
"""
match {"foo": 1, "bar": 2}:
case {
"foo": foo,
"bar": bar,
**rest,
}:
print(foo)
case _:
pass
""",
]
]
)
Expand Down

0 comments on commit 27bfad5

Please sign in to comment.