Skip to content

Commit a03bf3f

Browse files
authored
Refactor NESTML printer (#1181)
1 parent 5258356 commit a03bf3f

12 files changed

+373
-180
lines changed

pynestml/codegeneration/printers/cpp_variable_printer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def _print_cpp_name(cls, variable_name: str) -> str:
3434
:param variable_name: a single name.
3535
:return: a string representation
3636
"""
37-
differential_order = variable_name.count("\"")
37+
differential_order = variable_name.count("'")
3838
if differential_order > 0:
39-
return variable_name.replace("\"", "").replace("$", "__DOLLAR") + "__" + "d" * differential_order
39+
return variable_name.replace("'", "").replace("$", "__DOLLAR") + "__" + "d" * differential_order
4040

4141
return variable_name.replace("$", "__DOLLAR")
4242

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# nestml_expression_printer.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
from pynestml.codegeneration.printers.expression_printer import ExpressionPrinter
23+
from pynestml.meta_model.ast_arithmetic_operator import ASTArithmeticOperator
24+
from pynestml.meta_model.ast_comparison_operator import ASTComparisonOperator
25+
from pynestml.meta_model.ast_expression import ASTExpression
26+
from pynestml.meta_model.ast_logical_operator import ASTLogicalOperator
27+
from pynestml.meta_model.ast_node import ASTNode
28+
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
29+
from pynestml.meta_model.ast_unary_operator import ASTUnaryOperator
30+
31+
32+
class NESTMLExpressionPrinter(ExpressionPrinter):
33+
r"""
34+
Printer for ``ASTExpression`` nodes in NESTML syntax.
35+
"""
36+
37+
def print(self, node: ASTNode) -> str:
38+
if isinstance(node, ASTExpression):
39+
if node.get_implicit_conversion_factor() and not node.get_implicit_conversion_factor() == 1:
40+
return "(" + str(node.get_implicit_conversion_factor()) + " * (" + self.print_expression(node) + "))"
41+
42+
return self.print_expression(node)
43+
44+
if isinstance(node, ASTArithmeticOperator):
45+
return self.print_arithmetic_operator(node)
46+
47+
if isinstance(node, ASTUnaryOperator):
48+
return self.print_unary_operator(node)
49+
50+
if isinstance(node, ASTComparisonOperator):
51+
return self.print_comparison_operator(node)
52+
53+
if isinstance(node, ASTLogicalOperator):
54+
return self.print_logical_operator(node)
55+
56+
return self._simple_expression_printer.print(node)
57+
58+
def print_logical_operator(self, node: ASTLogicalOperator) -> str:
59+
if node.is_logical_and:
60+
return " and "
61+
62+
if node.is_logical_or:
63+
return " or "
64+
65+
raise Exception("Unknown logical operator")
66+
67+
def print_comparison_operator(self, node: ASTComparisonOperator) -> str:
68+
if node.is_lt:
69+
return " < "
70+
71+
if node.is_le:
72+
return " <= "
73+
74+
if node.is_eq:
75+
return " == "
76+
77+
if node.is_ne:
78+
return " != "
79+
80+
if node.is_ne2:
81+
return " <> "
82+
83+
if node.is_ge:
84+
return " >= "
85+
86+
if node.is_gt:
87+
return " > "
88+
89+
raise RuntimeError("Type of comparison operator not specified!")
90+
91+
def print_unary_operator(self, node: ASTUnaryOperator) -> str:
92+
if node.is_unary_plus:
93+
return "+"
94+
95+
if node.is_unary_minus:
96+
return "-"
97+
98+
if node.is_unary_tilde:
99+
return "~"
100+
101+
raise RuntimeError("Type of unary operator not specified!")
102+
103+
def print_arithmetic_operator(self, node: ASTArithmeticOperator) -> str:
104+
if node.is_times_op:
105+
return " * "
106+
107+
if node.is_div_op:
108+
return " / "
109+
110+
if node.is_modulo_op:
111+
return " % "
112+
113+
if node.is_plus_op:
114+
return " + "
115+
116+
if node.is_minus_op:
117+
return " - "
118+
119+
if node.is_pow_op:
120+
return " ** "
121+
122+
raise RuntimeError("Arithmetic operator not specified.")
123+
124+
def print_expression(self, node: ASTExpression) -> str:
125+
ret = ""
126+
if node.is_expression():
127+
if node.is_encapsulated:
128+
ret += "("
129+
130+
if node.is_logical_not:
131+
ret += "not "
132+
133+
if node.is_unary_operator():
134+
ret += self.print_unary_operator(node.get_unary_operator())
135+
136+
if isinstance(node.get_expression(), ASTExpression):
137+
ret += self.print_expression(node.get_expression())
138+
elif isinstance(node.get_expression(), ASTSimpleExpression):
139+
ret += self._simple_expression_printer.print_simple_expression(node.get_expression())
140+
else:
141+
raise RuntimeError("Unknown node type")
142+
143+
if node.is_encapsulated:
144+
ret += ")"
145+
146+
elif node.is_compound_expression():
147+
if isinstance(node.get_lhs(), ASTExpression):
148+
ret += self.print_expression(node.get_lhs())
149+
elif isinstance(node.get_lhs(), ASTSimpleExpression):
150+
ret += self._simple_expression_printer.print_simple_expression(node.get_lhs())
151+
else:
152+
raise RuntimeError("Unknown node type")
153+
154+
ret += self.print(node.get_binary_operator())
155+
156+
if isinstance(node.get_rhs(), ASTExpression):
157+
ret += self.print_expression(node.get_rhs())
158+
elif isinstance(node.get_rhs(), ASTSimpleExpression):
159+
ret += self._simple_expression_printer.print_simple_expression(node.get_rhs())
160+
else:
161+
raise RuntimeError("Unknown node type")
162+
163+
elif node.is_ternary_operator():
164+
ret += self.print(node.get_condition()) + "?" + self.print(
165+
node.get_if_true()) + ":" + self.print(node.get_if_not())
166+
167+
return ret
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# nestml_function_call_printer.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
from pynestml.codegeneration.printers.function_call_printer import FunctionCallPrinter
23+
from pynestml.meta_model.ast_function_call import ASTFunctionCall
24+
25+
26+
class NESTMLFunctionCallPrinter(FunctionCallPrinter):
27+
r"""
28+
Printer for ASTFunctionCall in C++ syntax.
29+
"""
30+
31+
def print_function_call(self, node: ASTFunctionCall) -> str:
32+
ret = str(node.get_name()) + "("
33+
for i in range(0, len(node.get_args())):
34+
ret += self._expression_printer.print(node.get_args()[i])
35+
if i < len(node.get_args()) - 1: # in the case that it is not the last arg, print also a comma
36+
ret += ","
37+
38+
ret += ")"
39+
40+
return ret

0 commit comments

Comments
 (0)