diff --git a/circuitpython_build_tools/munge.py b/circuitpython_build_tools/munge.py index c08c693..4026efc 100644 --- a/circuitpython_build_tools/munge.py +++ b/circuitpython_build_tools/munge.py @@ -90,14 +90,15 @@ def process_statement(node): break return elif isinstance(node, ast.If): + node_test = ast.unparse(node.test) # return the statements in the 'if' branch of 'if sys.implementation...: ...' - if ast.unparse(node.test) == sys_implementation_is_circuitpython: + if node_test == sys_implementation_is_circuitpython: replace(node.lineno, 'if 1:') # return the statements in the 'else' branch of 'if sys.implementation...: ...' - if ast.unparse(node.test) == sys_implementation_not_circuitpython or ast.unparse(node.test) == sys_implementation_not_circuitpython2: + elif node_test == sys_implementation_not_circuitpython or node_test == sys_implementation_not_circuitpython2: replace(node.lineno, 'if 0:') # return the statements in the else branch of 'if TYPE_CHECKING: ...' - elif ast.unparse(node.test) == 'TYPE_CHECKING': + elif node_test == 'TYPE_CHECKING': replace(node.lineno, 'if 0:') elif isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name) and node.targets[0].id == '__version__': replace(node.lineno, f"__version__ = \"{version_str}\"") diff --git a/tests/test_munge.py b/tests/test_munge.py index acc0c9c..48e95f2 100644 --- a/tests/test_munge.py +++ b/tests/test_munge.py @@ -14,9 +14,9 @@ def test_munge(test_path): result_content = munge(test_path, "1.2.3") result_path.write_text(result_content, encoding="utf-8") - expected = test_path.with_suffix(".exp") - expected_content = expected.read_text(encoding="utf-8") + expected_path = test_path.with_suffix(".exp") + expected_content = expected_path.read_text(encoding="utf-8") - assert result == expected + assert result_content == expected_content result_path.unlink()