|
22 | 22 | from typing import List |
23 | 23 |
|
24 | 24 | from pynestml.cocos.co_co import CoCo |
| 25 | +from pynestml.meta_model.ast_declaration import ASTDeclaration |
| 26 | +from pynestml.meta_model.ast_external_variable import ASTExternalVariable |
25 | 27 | from pynestml.meta_model.ast_function_call import ASTFunctionCall |
26 | 28 | from pynestml.meta_model.ast_kernel import ASTKernel |
27 | 29 | from pynestml.meta_model.ast_model import ASTModel |
28 | 30 | from pynestml.meta_model.ast_node import ASTNode |
29 | 31 | from pynestml.meta_model.ast_variable import ASTVariable |
| 32 | +from pynestml.symbols.predefined_functions import PredefinedFunctions |
30 | 33 | from pynestml.symbols.symbol import SymbolKind |
31 | 34 | from pynestml.utils.logger import Logger, LoggingLevel |
32 | 35 | from pynestml.utils.messages import Messages |
@@ -89,24 +92,45 @@ def visit_variable(self, node: ASTNode): |
89 | 92 | if not (isinstance(node, ASTExternalVariable) and node.get_alternate_name()): |
90 | 93 | code, message = Messages.get_no_variable_found(kernelName) |
91 | 94 | Logger.log_message(node=self.__neuron_node, code=code, message=message, log_level=LoggingLevel.ERROR) |
| 95 | + |
92 | 96 | continue |
| 97 | + |
93 | 98 | if not symbol.is_kernel(): |
94 | 99 | continue |
| 100 | + |
95 | 101 | if node.get_complete_name() == kernelName: |
96 | | - parent = node.get_parent() |
97 | | - if parent is not None: |
| 102 | + parent = node |
| 103 | + correct = False |
| 104 | + while parent is not None and not isinstance(parent, ASTModel): |
| 105 | + parent = parent.get_parent() |
| 106 | + assert parent is not None |
| 107 | + |
| 108 | + if isinstance(parent, ASTDeclaration): |
| 109 | + for lhs_var in parent.get_variables(): |
| 110 | + print("Cjecking " + kernelName + " stwit " + lhs_var.get_complete_name()) |
| 111 | + if kernelName == lhs_var.get_complete_name(): |
| 112 | + # kernel name appears on lhs of declaration, assume it is initial state |
| 113 | + correct = True |
| 114 | + parent = None # break out of outer loop |
| 115 | + break |
| 116 | + |
98 | 117 | if isinstance(parent, ASTKernel): |
99 | | - continue |
100 | | - grandparent = parent.get_parent() |
101 | | - if grandparent is not None and isinstance(grandparent, ASTFunctionCall): |
102 | | - grandparent_func_name = grandparent.get_name() |
103 | | - if grandparent_func_name == 'convolve': |
104 | | - continue |
105 | | - code, message = Messages.get_kernel_outside_convolve(kernelName) |
106 | | - Logger.log_message(code=code, |
107 | | - message=message, |
108 | | - log_level=LoggingLevel.ERROR, |
109 | | - error_position=node.get_source_position()) |
| 118 | + # kernel name is used inside kernel definition, e.g. for a node ``g``, it appears in ``kernel g'' = -1/tau**2 * g - 2/tau * g'`` |
| 119 | + correct = True |
| 120 | + break |
| 121 | + |
| 122 | + if isinstance(parent, ASTFunctionCall): |
| 123 | + func_name = parent.get_name() |
| 124 | + if func_name == PredefinedFunctions.CONVOLVE: |
| 125 | + # kernel name is used inside convolve call |
| 126 | + correct = True |
| 127 | + |
| 128 | + if not correct: |
| 129 | + code, message = Messages.get_kernel_outside_convolve(kernelName) |
| 130 | + Logger.log_message(code=code, |
| 131 | + message=message, |
| 132 | + log_level=LoggingLevel.ERROR, |
| 133 | + error_position=node.get_source_position()) |
110 | 134 |
|
111 | 135 |
|
112 | 136 | class KernelCollectingVisitor(ASTVisitor): |
|
0 commit comments