Skip to content

Commit

Permalink
fix wrappign around pow (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
odashi authored Nov 21, 2023
1 parent 1cfb8e7 commit 0cba4c9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/latexify/codegen/expression_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,16 +408,22 @@ def visit_Call(self, node: ast.Call) -> str:

if rule.is_unary and len(node.args) == 1:
# Unary function. Applies the same wrapping policy with the unary operators.
precedence = expression_rules.get_precedence(node)
arg = node.args[0]
# NOTE(odashi):
# Factorial "x!" is treated as a special case: it requires both inner/outer
# parentheses for correct interpretation.
precedence = expression_rules.get_precedence(node)
arg = node.args[0]
force_wrap = isinstance(arg, ast.Call) and (
force_wrap_factorial = isinstance(arg, ast.Call) and (
func_name == "factorial"
or ast_utils.extract_function_name_or_none(arg) == "factorial"
)
arg_latex = self._wrap_operand(arg, precedence, force_wrap)
# Note(odashi):
# Wrapping is also required if the argument is pow.
# https://github.com/google/latexify_py/issues/189
force_wrap_pow = isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.Pow)
arg_latex = self._wrap_operand(
arg, precedence, force_wrap_factorial or force_wrap_pow
)
elements = [rule.left, arg_latex, rule.right]
else:
arg_latex = ", ".join(self.visit(arg) for arg in node.args)
Expand Down Expand Up @@ -490,7 +496,7 @@ def _wrap_operand(
latex = self.visit(child)
child_prec = expression_rules.get_precedence(child)

if child_prec < parent_prec or force_wrap and child_prec == parent_prec:
if force_wrap or child_prec < parent_prec:
return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"

return latex
Expand Down
19 changes: 19 additions & 0 deletions src/latexify/codegen/expression_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,25 @@ def test_visit_call(code: str, latex: str) -> None:
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
"code,latex",
[
("log(x)**2", r"\mathopen{}\left( \log x \mathclose{}\right)^{2}"),
("log(x**2)", r"\log \mathopen{}\left( x^{2} \mathclose{}\right)"),
(
"log(x**2)**3",
r"\mathopen{}\left("
r" \log \mathopen{}\left( x^{2} \mathclose{}\right)"
r" \mathclose{}\right)^{3}",
),
],
)
def test_visit_call_with_pow(code: str, latex: str) -> None:
node = ast_utils.parse_expr(code)
assert isinstance(node, (ast.Call, ast.BinOp))
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
"src_suffix,dest_suffix",
[
Expand Down

0 comments on commit 0cba4c9

Please sign in to comment.