Skip to content

Commit

Permalink
Add AugAssign support
Browse files Browse the repository at this point in the history
  • Loading branch information
odashi committed Nov 16, 2023
1 parent f68638a commit c3cebed
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/latexify/generate_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def get_latex(
# Obtains the source AST.
tree = parser.parse_function(fn)

# Applies AST transformations.
# Mandatory AST Transformation.
tree = transformers.AugAssignReplacer().visit(tree)

# Conditional AST transformation.
if merged_config.prefixes is not None:
tree = transformers.PrefixTrimmer(merged_config.prefixes).visit(tree)
if merged_config.identifiers is not None:
Expand Down
14 changes: 14 additions & 0 deletions src/latexify/generate_latex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def f(x):
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_reduce_assignments_with_aug_assign() -> None:
def f(x):
y = 3
y *= x
return y

latex_without_flag = r"\begin{array}{l} y = 3 \\ y = y x \\ f(x) = y \end{array}"
latex_with_flag = r"f(x) = 3 x"

assert generate_latex.get_latex(f) == latex_without_flag
assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_use_math_symbols() -> None:
def f(alpha):
return alpha
Expand Down
2 changes: 2 additions & 0 deletions src/latexify/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Package latexify.transformers."""

from latexify.transformers.assignment_reducer import AssignmentReducer
from latexify.transformers.aug_assign_replacer import AugAssignReplacer
from latexify.transformers.function_expander import FunctionExpander
from latexify.transformers.identifier_replacer import IdentifierReplacer
from latexify.transformers.prefix_trimmer import PrefixTrimmer

__all__ = [
"AssignmentReducer",
"AugAssignReplacer",
"FunctionExpander",
"IdentifierReplacer",
"PrefixTrimmer",
Expand Down
20 changes: 20 additions & 0 deletions src/latexify/transformers/aug_assign_replacer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Transformer to replace AugAssign to Assign."""

from __future__ import annotations

import ast


class AugAssignReplacer(ast.NodeTransformer):
"""NodeTransformer to replace AugAssign to corresponding Assign.
AugAssign(target, op, value) => Assign([target], BinOp(target, op, value))
"""

def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign:
left_args = {**vars(node.target), "ctx": ast.Load()}
left = type(node.target)(**left_args)
return ast.Assign(
targets=[node.target], value=ast.BinOp(left, node.op, node.value)
)
24 changes: 24 additions & 0 deletions src/latexify/transformers/aug_assign_replacer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Tests for latexify.transformers.aug_assign_replacer."""

import ast

from latexify import test_utils
from latexify.transformers.aug_assign_replacer import AugAssignReplacer


def test_replace() -> None:
tree = ast.AugAssign(
target=ast.Name(id="x", ctx=ast.Store()),
op=ast.Add(),
value=ast.Name(id="y", ctx=ast.Load()),
)
expected = ast.Assign(
targets=[ast.Name(id="x", ctx=ast.Store())],
value=ast.BinOp(
left=ast.Name(id="x", ctx=ast.Load()),
op=ast.Add(),
right=ast.Name(id="y", ctx=ast.Load()),
),
)
transformed = AugAssignReplacer().visit(tree)
test_utils.assert_ast_equal(transformed, expected)

0 comments on commit c3cebed

Please sign in to comment.