Skip to content

Commit

Permalink
used _match_subject_stack as suggested
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuqi07 committed Dec 15, 2022
1 parent fd2ddee commit 0daeb72
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
use_math_symbols: bool = False,
use_signature: bool = True,
use_set_symbols: bool = False,
_match_subject_stack: list[str] = [],
) -> None:
"""Initializer.
Expand All @@ -34,6 +35,7 @@ def __init__(
use_signature: Whether to add the function signature before the expression
or not.
use_set_symbols: Whether to use set symbols or not.
_match_subject_stack: a stack of subject names that are used in match
"""
self._expression_codegen = expression_codegen.ExpressionCodegen(
use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols
Expand All @@ -42,6 +44,7 @@ def __init__(
use_math_symbols=use_math_symbols
)
self._use_signature = use_signature
self._match_subject_stack = _match_subject_stack

def generic_visit(self, node: ast.AST) -> str:
raise exceptions.LatexifyNotSupportedError(
Expand Down Expand Up @@ -141,6 +144,9 @@ def visit_If(self, node: ast.If) -> str:

def visit_Match(self, node: ast.Match) -> str:
"""Visit a Match node."""
subject_latex = self._expression_codegen.visit(node.subject)
self._match_subject_stack.append(subject_latex)

if not (
len(node.cases) >= 2
and isinstance(node.cases[-1].pattern, ast.MatchAs)
Expand All @@ -162,8 +168,6 @@ def visit_Match(self, node: ast.Match) -> str:
if i < len(node.cases) - 1:
body_latex = self.visit(case.body[0])
cond_latex = self.visit(case.pattern)
# if case.guard is not None:
# cond_latex = self._expression_codegen.visit(case.guard)

case_latexes.append(body_latex + r", & \mathrm{if} \ " + cond_latex)
else:
Expand All @@ -177,21 +181,14 @@ def visit_Match(self, node: ast.Match) -> str:
+ r" \end{array} \right."
)

latex_final = latex.replace("subject_name", subject_latex)
return latex_final
self._match_subject_stack.pop()
return latex

def visit_MatchValue(self, node: ast.MatchValue) -> str:
"""Visit a MatchValue node."""
latex = self._expression_codegen.visit(node.value)
return "subject_name = " + latex
return self._match_subject_stack[-1] + " = " + latex

def visit_MatchOr(self, node: ast.MatchOr) -> str:
"""Visit a MatchOr node."""
# case_latexes = []
# for i, pattern in enumerate(node.patterns):
# if i == 0:
# case_latexes.append(self.visit(pattern))
# else:
# case_latexes.append(r" \lor " + self.visit(pattern))
# return "".join(case_latexes)
return r" \lor ".join(self.visit(p) for p in node.patterns)

0 comments on commit 0daeb72

Please sign in to comment.