Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry-picked a handful of intrinsic related commits out of multi_sdfg branch. #1728

Merged
merged 20 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions dace/frontend/fortran/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(self, funcs=None):

from dace.frontend.fortran.intrinsics import FortranIntrinsics
self.excepted_funcs = [
"malloc", "exp", "pow", "sqrt", "cbrt", "max", "abs", "min", "__dace_sign", "tanh",
"malloc", "pow", "cbrt", "__dace_sign", "tanh", "atan2",
"__dace_epsilon", *FortranIntrinsics.function_names()
]

Expand Down Expand Up @@ -220,7 +220,7 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):

from dace.frontend.fortran.intrinsics import FortranIntrinsics
if not stop and node.name.name not in [
"malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()
"malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()
]:
self.nodes.append(node)
return self.generic_visit(node)
Expand All @@ -241,7 +241,7 @@ def __init__(self, count=0):
def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):

from dace.frontend.fortran.intrinsics import FortranIntrinsics
if node.name.name in ["malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()]:
if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()]:
return self.generic_visit(node)
if hasattr(node, "subroutine"):
if node.subroutine is True:
Expand All @@ -251,6 +251,11 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
else:
self.count = self.count + 1
tmp = self.count

for i, arg in enumerate(node.args):
# Ensure we allow to extract function calls from arguments
node.args[i] = self.visit(arg)

return ast_internal_classes.Name_Node(name="tmp_call_" + str(tmp - 1))

def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
Expand All @@ -263,9 +268,13 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
for i in res:
if i == child:
res.pop(res.index(i))
temp = self.count
if res is not None:
for i in range(0, len(res)):
# Variables are counted from 0...end, starting from main node, to all calls nested
# in main node arguments.
# However, we need to define nested ones first.
# We go in reverse order, counting from end-1 to 0.
temp = self.count + len(res) - 1
for i in reversed(range(0, len(res))):

newbody.append(
ast_internal_classes.Decl_Stmt_Node(vardecl=[
Expand All @@ -282,7 +291,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
type=res[i].type),
rval=res[i],
line_number=child.line_number))
temp = temp + 1
temp = temp - 1
if isinstance(child, ast_internal_classes.Call_Expr_Node):
new_args = []
if hasattr(child, "args"):
Expand Down Expand Up @@ -368,7 +377,8 @@ def __init__(self):
self.nodes: List[ast_internal_classes.Array_Subscript_Node] = []

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
if node.name.name in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh"]:
from dace.frontend.fortran.intrinsics import FortranIntrinsics
if node.name.name in ["pow", "atan2", "tanh", *FortranIntrinsics.retained_function_names()]:
return self.generic_visit(node)
else:
return
Expand Down Expand Up @@ -401,7 +411,8 @@ def __init__(self, ast: ast_internal_classes.FNode, normalize_offsets: bool = Fa
self.scope_vars.visit(ast)

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
if node.name.name in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh"]:
from dace.frontend.fortran.intrinsics import FortranIntrinsics
if node.name.name in ["pow", "atan2", "tanh", *FortranIntrinsics.retained_function_names()]:
return self.generic_visit(node)
else:
return node
Expand Down
12 changes: 8 additions & 4 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,8 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: Con
calls.visit(node)
if len(calls.nodes) == 1:
augmented_call = calls.nodes[0]
if augmented_call.name.name not in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh", "__dace_epsilon"]:
from dace.frontend.fortran.intrinsics import FortranIntrinsics
if augmented_call.name.name not in ["pow", "atan2", "tanh", "__dace_epsilon", *FortranIntrinsics.retained_function_names()]:
augmented_call.args.append(node.lval)
augmented_call.hasret = True
self.call2sdfg(augmented_call, sdfg, cfg)
Expand Down Expand Up @@ -1090,7 +1091,8 @@ def create_ast_from_string(
program = ast_transforms.ArrayToLoop(program).visit(program)

for transformation in own_ast.fortran_intrinsics().transformations():
program = transformation(program).visit(program)
transformation.initialize(program)
program = transformation.visit(program)

program = ast_transforms.ForDeclarer().visit(program)
program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program)
Expand Down Expand Up @@ -1126,7 +1128,8 @@ def create_sdfg_from_string(
program = ast_transforms.ArrayToLoop(program).visit(program)

for transformation in own_ast.fortran_intrinsics().transformations():
program = transformation(program).visit(program)
transformation.initialize(program)
program = transformation.visit(program)

program = ast_transforms.ForDeclarer().visit(program)
program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program)
Expand Down Expand Up @@ -1172,7 +1175,8 @@ def create_sdfg_from_fortran_file(source_string: str, use_experimental_cfg_block
program = ast_transforms.ArrayToLoop(program).visit(program)

for transformation in own_ast.fortran_intrinsics():
program = transformation(program).visit(program)
transformation.initialize(program)
program = transformation.visit(program)

program = ast_transforms.ForDeclarer().visit(program)
program = ast_transforms.IndexExtractor(program).visit(program)
Expand Down
Loading
Loading