From 0daeb72147ba4f7ba8b12d0168b886b6cacb6928 Mon Sep 17 00:00:00 2001 From: Yuqi Gong Date: Wed, 14 Dec 2022 20:20:37 -0500 Subject: [PATCH] used _match_subject_stack as suggested --- src/latexify/codegen/function_codegen.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index e77afe3..b0f2424 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -25,6 +25,7 @@ def __init__( use_math_symbols: bool = False, use_signature: bool = True, use_set_symbols: bool = False, + _match_subject_stack: list[str] = [], ) -> None: """Initializer. @@ -34,6 +35,7 @@ def __init__( use_signature: Whether to add the function signature before the expression or not. use_set_symbols: Whether to use set symbols or not. + _match_subject_stack: a stack of subject names that are used in match """ self._expression_codegen = expression_codegen.ExpressionCodegen( use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols @@ -42,6 +44,7 @@ def __init__( use_math_symbols=use_math_symbols ) self._use_signature = use_signature + self._match_subject_stack = _match_subject_stack def generic_visit(self, node: ast.AST) -> str: raise exceptions.LatexifyNotSupportedError( @@ -141,6 +144,9 @@ def visit_If(self, node: ast.If) -> str: def visit_Match(self, node: ast.Match) -> str: """Visit a Match node.""" + subject_latex = self._expression_codegen.visit(node.subject) + self._match_subject_stack.append(subject_latex) + if not ( len(node.cases) >= 2 and isinstance(node.cases[-1].pattern, ast.MatchAs) @@ -162,8 +168,6 @@ def visit_Match(self, node: ast.Match) -> str: if i < len(node.cases) - 1: body_latex = self.visit(case.body[0]) cond_latex = self.visit(case.pattern) - # if case.guard is not None: - # cond_latex = self._expression_codegen.visit(case.guard) case_latexes.append(body_latex + r", & \mathrm{if} \ " + cond_latex) else: @@ -177,21 +181,14 @@ def visit_Match(self, node: ast.Match) -> str: + r" \end{array} \right." ) - latex_final = latex.replace("subject_name", subject_latex) - return latex_final + self._match_subject_stack.pop() + return latex def visit_MatchValue(self, node: ast.MatchValue) -> str: """Visit a MatchValue node.""" latex = self._expression_codegen.visit(node.value) - return "subject_name = " + latex + return self._match_subject_stack[-1] + " = " + latex def visit_MatchOr(self, node: ast.MatchOr) -> str: """Visit a MatchOr node.""" - # case_latexes = [] - # for i, pattern in enumerate(node.patterns): - # if i == 0: - # case_latexes.append(self.visit(pattern)) - # else: - # case_latexes.append(r" \lor " + self.visit(pattern)) - # return "".join(case_latexes) return r" \lor ".join(self.visit(p) for p in node.patterns)