Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DocstringRemover #197

Merged
merged 2 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/latexify/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ def is_constant(node: ast.AST) -> bool:
return isinstance(node, ast.Constant)


def is_str(node: ast.AST) -> bool:
"""Checks if the node is a str constant.

Args:
node: The node to examine.

Returns:
True if the node is a str constant, False otherwise.
"""
if sys.version_info.minor < 8 and isinstance(node, ast.Str):
return True

return isinstance(node, ast.Constant) and isinstance(node.value, str)


def extract_int_or_none(node: ast.expr) -> int | None:
"""Extracts int constant from the given Constant node.

Expand Down
32 changes: 32 additions & 0 deletions src/latexify/ast_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,38 @@ def test_is_constant(value: ast.AST, expected: bool) -> None:
assert ast_utils.is_constant(value) is expected


@test_utils.require_at_most(7)
@pytest.mark.parametrize(
"value,expected",
[
(ast.Bytes(s=b"foo"), False),
(ast.Constant("bar"), True),
(ast.Ellipsis(), False),
(ast.NameConstant(value=None), False),
(ast.Num(n=123), False),
(ast.Str(s="baz"), True),
(ast.Expr(value=ast.Num(456)), False),
(ast.Global(names=["qux"]), False),
],
)
def test_is_str_legacy(value: ast.AST, expected: bool) -> None:
assert ast_utils.is_str(value) is expected


@test_utils.require_at_least(8)
@pytest.mark.parametrize(
"value,expected",
[
(ast.Constant(value=123), False),
(ast.Constant(value="foo"), True),
(ast.Expr(value=ast.Constant(value="foo")), False),
(ast.Global(names=["foo"]), False),
],
)
def test_is_str(value: ast.AST, expected: bool) -> None:
assert ast_utils.is_str(value) is expected


def test_extract_int_or_none() -> None:
assert ast_utils.extract_int_or_none(ast_utils.make_constant(-123)) == -123
assert ast_utils.extract_int_or_none(ast_utils.make_constant(0)) == 0
Expand Down
1 change: 1 addition & 0 deletions src/latexify/generate_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def get_latex(
if merged_config.identifiers is not None:
tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree)
if merged_config.reduce_assignments:
tree = transformers.DocstringRemover().visit(tree)
tree = transformers.AssignmentReducer().visit(tree)
if merged_config.expand_functions is not None:
tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree)
Expand Down
14 changes: 14 additions & 0 deletions src/latexify/generate_latex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def f(x):
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_reduce_assignments_with_docstring() -> None:
def f(x):
"""DocstringRemover is required."""
y = 3 * x
return y

latex_without_flag = r"\begin{array}{l} y = 3 x \\ f(x) = y \end{array}"
latex_with_flag = r"f(x) = 3 x"

assert generate_latex.get_latex(f) == latex_without_flag
assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_reduce_assignments_with_aug_assign() -> None:
def f(x):
y = 3
Expand Down
2 changes: 2 additions & 0 deletions src/latexify/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from latexify.transformers.assignment_reducer import AssignmentReducer
from latexify.transformers.aug_assign_replacer import AugAssignReplacer
from latexify.transformers.docstring_remover import DocstringRemover
from latexify.transformers.function_expander import FunctionExpander
from latexify.transformers.identifier_replacer import IdentifierReplacer
from latexify.transformers.prefix_trimmer import PrefixTrimmer

__all__ = [
"AssignmentReducer",
"AugAssignReplacer",
"DocstringRemover",
"FunctionExpander",
"IdentifierReplacer",
"PrefixTrimmer",
Expand Down
20 changes: 20 additions & 0 deletions src/latexify/transformers/docstring_remover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Transformer to remove all docstrings."""

from __future__ import annotations

import ast
from typing import Union

from latexify import ast_utils


class DocstringRemover(ast.NodeTransformer):
"""NodeTransformer to remove all docstrings.

Docstrings here are detected as Expr nodes with a single string constant.
"""

def visit_Expr(self, node: ast.Expr) -> Union[ast.Expr, None]:
if ast_utils.is_str(node.value):
return None
return node
32 changes: 32 additions & 0 deletions src/latexify/transformers/docstring_remover_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Tests for latexify.transformers.docstring_remover."""

import ast

from latexify import ast_utils, parser, test_utils
from latexify.transformers.docstring_remover import DocstringRemover


def test_remove_docstrings() -> None:
def f():
"""Test docstring."""
x = 42
f() # This Expr should not be removed.
"""This string constant should also be removed."""
return x

tree = parser.parse_function(f).body[0]
assert isinstance(tree, ast.FunctionDef)

expected = ast.FunctionDef(
name="f",
body=[
ast.Assign(
targets=[ast.Name(id="x", ctx=ast.Store())],
value=ast_utils.make_constant(42),
),
ast.Expr(value=ast.Call(func=ast.Name(id="f", ctx=ast.Load()))),
ast.Return(value=ast.Name(id="x", ctx=ast.Load())),
],
)
transformed = DocstringRemover().visit(tree)
test_utils.assert_ast_equal(transformed, expected)