Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AugAssign support #193

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/latexify/generate_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def get_latex(
# Obtains the source AST.
tree = parser.parse_function(fn)

# Applies AST transformations.
# Mandatory AST Transformation.
tree = transformers.AugAssignReplacer().visit(tree)

# Conditional AST transformation.
if merged_config.prefixes is not None:
tree = transformers.PrefixTrimmer(merged_config.prefixes).visit(tree)
if merged_config.identifiers is not None:
Expand Down
14 changes: 14 additions & 0 deletions src/latexify/generate_latex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def f(x):
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_reduce_assignments_with_aug_assign() -> None:
def f(x):
y = 3
y *= x
return y

latex_without_flag = r"\begin{array}{l} y = 3 \\ y = y x \\ f(x) = y \end{array}"
latex_with_flag = r"f(x) = 3 x"

assert generate_latex.get_latex(f) == latex_without_flag
assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag
assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_use_math_symbols() -> None:
def f(alpha):
return alpha
Expand Down
2 changes: 2 additions & 0 deletions src/latexify/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Package latexify.transformers."""

from latexify.transformers.assignment_reducer import AssignmentReducer
from latexify.transformers.aug_assign_replacer import AugAssignReplacer
from latexify.transformers.function_expander import FunctionExpander
from latexify.transformers.identifier_replacer import IdentifierReplacer
from latexify.transformers.prefix_trimmer import PrefixTrimmer

__all__ = [
"AssignmentReducer",
"AugAssignReplacer",
"FunctionExpander",
"IdentifierReplacer",
"PrefixTrimmer",
Expand Down
20 changes: 20 additions & 0 deletions src/latexify/transformers/aug_assign_replacer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Transformer to replace AugAssign to Assign."""

from __future__ import annotations

import ast


class AugAssignReplacer(ast.NodeTransformer):
"""NodeTransformer to replace AugAssign to corresponding Assign.

AugAssign(target, op, value) => Assign([target], BinOp(target, op, value))

"""

def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign:
left_args = {**vars(node.target), "ctx": ast.Load()}
left = type(node.target)(**left_args)
return ast.Assign(
targets=[node.target], value=ast.BinOp(left, node.op, node.value)
)
24 changes: 24 additions & 0 deletions src/latexify/transformers/aug_assign_replacer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Tests for latexify.transformers.aug_assign_replacer."""

import ast

from latexify import test_utils
from latexify.transformers.aug_assign_replacer import AugAssignReplacer


def test_replace() -> None:
tree = ast.AugAssign(
target=ast.Name(id="x", ctx=ast.Store()),
op=ast.Add(),
value=ast.Name(id="y", ctx=ast.Load()),
)
expected = ast.Assign(
targets=[ast.Name(id="x", ctx=ast.Store())],
value=ast.BinOp(
left=ast.Name(id="x", ctx=ast.Load()),
op=ast.Add(),
right=ast.Name(id="y", ctx=ast.Load()),
),
)
transformed = AugAssignReplacer().visit(tree)
test_utils.assert_ast_equal(transformed, expected)