diff --git a/src/latexify/ast_utils.py b/src/latexify/ast_utils.py index 718fc88..d9da518 100644 --- a/src/latexify/ast_utils.py +++ b/src/latexify/ast_utils.py @@ -96,6 +96,21 @@ def is_constant(node: ast.AST) -> bool: return isinstance(node, ast.Constant) +def is_str(node: ast.AST) -> bool: + """Checks if the node is a str constant. + + Args: + node: The node to examine. + + Returns: + True if the node is a str constant, False otherwise. + """ + if sys.version_info.minor < 8 and isinstance(node, ast.Str): + return True + + return isinstance(node, ast.Constant) and isinstance(node.value, str) + + def extract_int_or_none(node: ast.expr) -> int | None: """Extracts int constant from the given Constant node. diff --git a/src/latexify/ast_utils_test.py b/src/latexify/ast_utils_test.py index c922fd3..0e9cfa2 100644 --- a/src/latexify/ast_utils_test.py +++ b/src/latexify/ast_utils_test.py @@ -114,6 +114,38 @@ def test_is_constant(value: ast.AST, expected: bool) -> None: assert ast_utils.is_constant(value) is expected +@test_utils.require_at_most(7) +@pytest.mark.parametrize( + "value,expected", + [ + (ast.Bytes(s=b"foo"), False), + (ast.Constant("bar"), True), + (ast.Ellipsis(), False), + (ast.NameConstant(value=None), False), + (ast.Num(n=123), False), + (ast.Str(s="baz"), True), + (ast.Expr(value=ast.Num(456)), False), + (ast.Global(names=["qux"]), False), + ], +) +def test_is_str_legacy(value: ast.AST, expected: bool) -> None: + assert ast_utils.is_str(value) is expected + + +@test_utils.require_at_least(8) +@pytest.mark.parametrize( + "value,expected", + [ + (ast.Constant(value=123), False), + (ast.Constant(value="foo"), True), + (ast.Expr(value=ast.Constant(value="foo")), False), + (ast.Global(names=["foo"]), False), + ], +) +def test_is_str(value: ast.AST, expected: bool) -> None: + assert ast_utils.is_str(value) is expected + + def test_extract_int_or_none() -> None: assert ast_utils.extract_int_or_none(ast_utils.make_constant(-123)) == -123 assert ast_utils.extract_int_or_none(ast_utils.make_constant(0)) == 0 diff --git a/src/latexify/generate_latex.py b/src/latexify/generate_latex.py index dbcf0e8..0a7615d 100644 --- a/src/latexify/generate_latex.py +++ b/src/latexify/generate_latex.py @@ -56,6 +56,7 @@ def get_latex( if merged_config.identifiers is not None: tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree) if merged_config.reduce_assignments: + tree = transformers.DocstringRemover().visit(tree) tree = transformers.AssignmentReducer().visit(tree) if merged_config.expand_functions is not None: tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree) diff --git a/src/latexify/generate_latex_test.py b/src/latexify/generate_latex_test.py index bc61b97..6964306 100644 --- a/src/latexify/generate_latex_test.py +++ b/src/latexify/generate_latex_test.py @@ -54,6 +54,20 @@ def f(x): assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag +def test_get_latex_reduce_assignments_with_docstring() -> None: + def f(x): + """DocstringRemover is required.""" + y = 3 * x + return y + + latex_without_flag = r"\begin{array}{l} y = 3 x \\ f(x) = y \end{array}" + latex_with_flag = r"f(x) = 3 x" + + assert generate_latex.get_latex(f) == latex_without_flag + assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag + assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag + + def test_get_latex_reduce_assignments_with_aug_assign() -> None: def f(x): y = 3 diff --git a/src/latexify/transformers/__init__.py b/src/latexify/transformers/__init__.py index 79c9e21..920fcf9 100644 --- a/src/latexify/transformers/__init__.py +++ b/src/latexify/transformers/__init__.py @@ -2,6 +2,7 @@ from latexify.transformers.assignment_reducer import AssignmentReducer from latexify.transformers.aug_assign_replacer import AugAssignReplacer +from latexify.transformers.docstring_remover import DocstringRemover from latexify.transformers.function_expander import FunctionExpander from latexify.transformers.identifier_replacer import IdentifierReplacer from latexify.transformers.prefix_trimmer import PrefixTrimmer @@ -9,6 +10,7 @@ __all__ = [ "AssignmentReducer", "AugAssignReplacer", + "DocstringRemover", "FunctionExpander", "IdentifierReplacer", "PrefixTrimmer", diff --git a/src/latexify/transformers/docstring_remover.py b/src/latexify/transformers/docstring_remover.py new file mode 100644 index 0000000..e94d23d --- /dev/null +++ b/src/latexify/transformers/docstring_remover.py @@ -0,0 +1,20 @@ +"""Transformer to remove all docstrings.""" + +from __future__ import annotations + +import ast +from typing import Union + +from latexify import ast_utils + + +class DocstringRemover(ast.NodeTransformer): + """NodeTransformer to remove all docstrings. + + Docstrings here are detected as Expr nodes with a single string constant. + """ + + def visit_Expr(self, node: ast.Expr) -> Union[ast.Expr, None]: + if ast_utils.is_str(node.value): + return None + return node diff --git a/src/latexify/transformers/docstring_remover_test.py b/src/latexify/transformers/docstring_remover_test.py new file mode 100644 index 0000000..d65c524 --- /dev/null +++ b/src/latexify/transformers/docstring_remover_test.py @@ -0,0 +1,32 @@ +"""Tests for latexify.transformers.docstring_remover.""" + +import ast + +from latexify import ast_utils, parser, test_utils +from latexify.transformers.docstring_remover import DocstringRemover + + +def test_remove_docstrings() -> None: + def f(): + """Test docstring.""" + x = 42 + f() # This Expr should not be removed. + """This string constant should also be removed.""" + return x + + tree = parser.parse_function(f).body[0] + assert isinstance(tree, ast.FunctionDef) + + expected = ast.FunctionDef( + name="f", + body=[ + ast.Assign( + targets=[ast.Name(id="x", ctx=ast.Store())], + value=ast_utils.make_constant(42), + ), + ast.Expr(value=ast.Call(func=ast.Name(id="f", ctx=ast.Load()))), + ast.Return(value=ast.Name(id="x", ctx=ast.Load())), + ], + ) + transformed = DocstringRemover().visit(tree) + test_utils.assert_ast_equal(transformed, expected)