From dce94211af7e184aed7b33c128bf0e5490cc3093 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 20 Dec 2024 08:08:47 +0100 Subject: [PATCH 01/12] Import frontend patches from multi_sdfg --- dace/frontend/fortran/ast_components.py | 1267 ++++++- dace/frontend/fortran/ast_desugaring.py | 2098 +++++++++++ dace/frontend/fortran/ast_internal_classes.py | 626 +++- dace/frontend/fortran/ast_transforms.py | 2999 ++++++++++++++-- dace/frontend/fortran/ast_utils.py | 467 ++- dace/frontend/fortran/fortran_parser.py | 3182 +++++++++++++++-- .../fortran/icon_config_propagation.py | 230 ++ dace/frontend/fortran/intrinsics.py | 762 +++- 8 files changed, 10540 insertions(+), 1091 deletions(-) create mode 100644 dace/frontend/fortran/ast_desugaring.py create mode 100644 dace/frontend/fortran/icon_config_propagation.py diff --git a/dace/frontend/fortran/ast_components.py b/dace/frontend/fortran/ast_components.py index ab0aa9c777..294f490d39 100644 --- a/dace/frontend/fortran/ast_components.py +++ b/dace/frontend/fortran/ast_components.py @@ -1,17 +1,22 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
-from fparser.two import Fortran2008 as f08 +from typing import Any, List, Optional, Type, TypeVar, Union, overload, TYPE_CHECKING, Dict + +import networkx as nx from fparser.two import Fortran2003 as f03 -from fparser.two import symbol_table +from fparser.two import Fortran2008 as f08 +from fparser.two.Fortran2003 import Function_Subprogram, Function_Stmt, Prefix, Intrinsic_Type_Spec, \ + Assignment_Stmt, Logical_Literal_Constant, Real_Literal_Constant, Signed_Real_Literal_Constant, \ + Int_Literal_Constant, Signed_Int_Literal_Constant, Hex_Constant, Function_Reference -import copy from dace.frontend.fortran import ast_internal_classes -from dace.frontend.fortran.ast_internal_classes import FNode, Name_Node -from typing import Any, List, Tuple, Type, TypeVar, Union, overload, TYPE_CHECKING +from dace.frontend.fortran.ast_internal_classes import Name_Node, Program_Node, Decl_Stmt_Node, Var_Decl_Node +from dace.frontend.fortran.ast_transforms import StructLister, StructDependencyLister, Structures +from dace.frontend.fortran.ast_utils import singular if TYPE_CHECKING: from dace.frontend.fortran.intrinsics import FortranIntrinsics -#We rely on fparser to provide an initial AST and convert to a version that is more suitable for our purposes +# We rely on fparser to provide an initial AST and convert to a version that is more suitable for our purposes # The following class is used to translate the fparser AST to our own AST of Fortran # the supported_fortran_types dictionary is used to determine which types are supported by our compiler @@ -50,6 +55,8 @@ def get_child(node: Union[FASTNode, List[FASTNode]], child_type: Union[str, Type if len(children_of_type) == 1: return children_of_type[0] + # Temporary workaround to allow feature list to be generated + return None raise ValueError('Expected only one child of type {} but found {}'.format(child_type, children_of_type)) @@ -104,26 +111,31 @@ class InternalFortranAst: for each entry in the dictionary, the key is the name 
of the class in the fparser AST and the value is the name of the function that will be used to translate the fparser AST to our AST """ - def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): + + def __init__(self): """ Initialization of the AST converter - :param ast: the fparser AST - :param tables: the symbol table of the fparser AST - """ - self.ast = ast - self.tables = tables + self.to_parse_list = {} + self.unsupported_fortran_syntax = {} + self.current_ast = None self.functions_and_subroutines = [] self.symbols = {} + self.intrinsics_list = [] + self.placeholders = {} + self.placeholders_offsets = {} self.types = { - "LOGICAL": "BOOL", + "LOGICAL": "LOGICAL", "CHARACTER": "CHAR", "INTEGER": "INTEGER", "INTEGER4": "INTEGER", + "INTEGER8": "INTEGER8", "REAL4": "REAL", "REAL8": "DOUBLE", "DOUBLE PRECISION": "DOUBLE", "REAL": "REAL", + "CLASS": "CLASS", + "Unknown": "REAL", } from dace.frontend.fortran.intrinsics import FortranIntrinsics self.intrinsic_handler = FortranIntrinsics() @@ -136,10 +148,14 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "End_Program_Stmt": self.end_program_stmt, "Subroutine_Subprogram": self.subroutine_subprogram, "Function_Subprogram": self.function_subprogram, + "Module_Subprogram_Part": self.module_subprogram_part, + "Internal_Subprogram_Part": self.internal_subprogram_part, "Subroutine_Stmt": self.subroutine_stmt, "Function_Stmt": self.function_stmt, + "Prefix": self.prefix_stmt, "End_Subroutine_Stmt": self.end_subroutine_stmt, "End_Function_Stmt": self.end_function_stmt, + "Rename": self.rename, "Module": self.module, "Module_Stmt": self.module_stmt, "End_Module_Stmt": self.end_module_stmt, @@ -158,6 +174,8 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "Loop_Control": self.loop_control, "Block_Nonlabel_Do_Construct": self.block_nonlabel_do_construct, "Real_Literal_Constant": self.real_literal_constant, + "Signed_Real_Literal_Constant": 
self.real_literal_constant, + "Char_Literal_Constant": self.char_literal_constant, "Subscript_Triplet": self.subscript_triplet, "Section_Subscript_List": self.section_subscript_list, "Explicit_Shape_Spec_List": self.explicit_shape_spec_list, @@ -166,6 +184,7 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "Attr_Spec": self.attr_spec, "Intent_Spec": self.intent_spec, "Access_Spec": self.access_spec, + "Access_Stmt": self.access_stmt, "Allocatable_Stmt": self.allocatable_stmt, "Asynchronous_Stmt": self.asynchronous_stmt, "Bind_Stmt": self.bind_stmt, @@ -189,7 +208,6 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "Assignment_Stmt": self.assignment_stmt, "Pointer_Assignment_Stmt": self.pointer_assignment_stmt, "Where_Stmt": self.where_stmt, - "Forall_Stmt": self.forall_stmt, "Where_Construct": self.where_construct, "Where_Construct_Stmt": self.where_construct_stmt, "Masked_Elsewhere_Stmt": self.masked_elsewhere_stmt, @@ -217,6 +235,8 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "End_Do_Stmt": self.end_do_stmt, "Interface_Block": self.interface_block, "Interface_Stmt": self.interface_stmt, + "Procedure_Name_List": self.procedure_name_list, + "Procedure_Stmt": self.procedure_stmt, "End_Interface_Stmt": self.end_interface_stmt, "Generic_Spec": self.generic_spec, "Name": self.name, @@ -225,12 +245,16 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "Intrinsic_Type_Spec": self.intrinsic_type_spec, "Entity_Decl_List": self.entity_decl_list, "Int_Literal_Constant": self.int_literal_constant, + "Signed_Int_Literal_Constant": self.int_literal_constant, + "Hex_Constant": self.hex_constant, "Logical_Literal_Constant": self.logical_literal_constant, "Actual_Arg_Spec_List": self.actual_arg_spec_list, + "Actual_Arg_Spec": self.actual_arg_spec, "Attr_Spec_List": self.attr_spec_list, "Initialization": self.initialization, "Procedure_Declaration_Stmt": 
self.procedure_declaration_stmt, "Type_Bound_Procedure_Part": self.type_bound_procedure_part, + "Data_Pointer_Object": self.data_pointer_object, "Contains_Stmt": self.contains_stmt, "Call_Stmt": self.call_stmt, "Return_Stmt": self.return_stmt, @@ -241,6 +265,7 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "Equiv_Operand": self.level_2_expr, "Level_3_Expr": self.level_2_expr, "Level_4_Expr": self.level_2_expr, + "Level_5_Expr": self.level_2_expr, "Add_Operand": self.level_2_expr, "Or_Operand": self.level_2_expr, "And_Operand": self.level_2_expr, @@ -248,6 +273,7 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "Mult_Operand": self.power_expr, "Parenthesis": self.parenthesis_expr, "Intrinsic_Name": self.intrinsic_handler.replace_function_name, + "Suffix": self.suffix, "Intrinsic_Function_Reference": self.intrinsic_function_reference, "Only_List": self.only_list, "Structure_Constructor": self.structure_constructor, @@ -259,20 +285,79 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "Allocation": self.allocation, "Allocate_Shape_Spec": self.allocate_shape_spec, "Allocate_Shape_Spec_List": self.allocate_shape_spec_list, + "Derived_Type_Def": self.derived_type_def, + "Derived_Type_Stmt": self.derived_type_stmt, + "Component_Part": self.component_part, + "Data_Component_Def_Stmt": self.data_component_def_stmt, + "End_Type_Stmt": self.end_type_stmt, + "Data_Ref": self.data_ref, + "Cycle_Stmt": self.cycle_stmt, + "Deferred_Shape_Spec": self.deferred_shape_spec, + "Deferred_Shape_Spec_List": self.deferred_shape_spec_list, + "Component_Initialization": self.component_initialization, + "Case_Selector": self.case_selector, + "Case_Value_Range_List": self.case_value_range_list, + "Procedure_Designator": self.procedure_designator, + "Specific_Binding": self.specific_binding, + "Enum_Def_Stmt": self.enum_def_stmt, + "Enumerator_Def_Stmt": self.enumerator_def_stmt, + "Enumerator_List": 
self.enumerator_list, + "Enumerator": self.enumerator, + "End_Enum_Stmt": self.end_enum_stmt, + "Exit_Stmt": self.exit_stmt, + "Enum_Def": self.enum_def, + "Connect_Spec": self.connect_spec, + "Namelist_Stmt": self.namelist_stmt, + "Namelist_Group_Object_List": self.namelist_group_object_list, + "Open_Stmt": self.open_stmt, + "Connect_Spec_List": self.connect_spec_list, + "Association": self.association, + "Association_List": self.association_list, + "Associate_Stmt": self.associate_stmt, + "End_Associate_Stmt": self.end_associate_stmt, + "Associate_Construct": self.associate_construct, + "Subroutine_Body": self.subroutine_body, + "Function_Reference": self.function_reference, + "Binding_Name_List": self.binding_name_list, + "Generic_Binding": self.generic_binding, + "Private_Components_Stmt": self.private_components_stmt, + "Stop_Code": self.stop_code, + "Error_Stop_Stmt": self.error_stop_stmt, + "Pointer_Object_List": self.pointer_object_list, + "Nullify_Stmt": self.nullify_stmt, + "Deallocate_Stmt": self.deallocate_stmt, + "Proc_Component_Ref": self.proc_component_ref, + "Component_Spec": self.component_spec, + "Allocate_Object_List": self.allocate_object_list, + "Read_Stmt": self.read_stmt, + "Close_Stmt": self.close_stmt, + "Io_Control_Spec": self.io_control_spec, + "Io_Control_Spec_List": self.io_control_spec_list, + "Close_Spec_List": self.close_spec_list, + "Close_Spec": self.close_spec, + + # "Component_Decl_List": self.component_decl_list, + # "Component_Decl": self.component_decl, } + self.type_arbitrary_array_variable_count = 0 def fortran_intrinsics(self) -> "FortranIntrinsics": return self.intrinsic_handler - def list_tables(self): - for i in self.tables._symbol_tables: - print(i) + def data_pointer_object(self, node: FASTNode): + children = self.create_children(node) + if node.children[1] == "%": + return ast_internal_classes.Data_Ref_Node(parent_ref=children[0], part_ref=children[2], type="VOID") + else: + raise NotImplementedError("Data pointer 
object not supported yet") def create_children(self, node: FASTNode): - return [self.create_ast(child) - for child in node] if isinstance(node, - (list, - tuple)) else [self.create_ast(child) for child in node.children] + return [self.create_ast(child) for child in node] \ + if isinstance(node, (list, tuple)) else [self.create_ast(child) for child in node.children] + + def cycle_stmt(self, node: FASTNode): + line = get_line(node) + return ast_internal_classes.Continue_Node(line_number=line) def create_ast(self, node=None): """ @@ -280,31 +365,309 @@ def create_ast(self, node=None): :param node: FASTNode :note: this is a recursive function, and relies on the dictionary of supported syntax to call the correct converter functions """ - if node is not None: - if isinstance(node, (list, tuple)): - return [self.create_ast(child) for child in node] - return self.supported_fortran_syntax[type(node).__name__](node) + if not node: + return None + if isinstance(node, (list, tuple)): + return [self.create_ast(child) for child in node] + if type(node).__name__ in self.supported_fortran_syntax: + handler = self.supported_fortran_syntax[type(node).__name__] + return handler(node) + + if type(node).__name__ == "Intrinsic_Name": + if node not in self.intrinsics_list: + self.intrinsics_list.append(node) + if self.unsupported_fortran_syntax.get(self.current_ast) is None: + self.unsupported_fortran_syntax[self.current_ast] = [] + if type(node).__name__ not in self.unsupported_fortran_syntax[self.current_ast]: + if type(node).__name__ not in self.unsupported_fortran_syntax[self.current_ast]: + self.unsupported_fortran_syntax[self.current_ast].append(type(node).__name__) + for i in node.children: + self.create_ast(i) + print("Unsupported syntax: ", type(node).__name__, node.string) + return None + + def finalize_ast(self, prog: Program_Node): + structs_lister = StructLister() + structs_lister.visit(prog) + struct_dep_graph = nx.DiGraph() + for i, name in zip(structs_lister.structs, 
structs_lister.names): + if name not in struct_dep_graph.nodes: + struct_dep_graph.add_node(name) + struct_deps_finder = StructDependencyLister(structs_lister.names) + struct_deps_finder.visit(i) + struct_deps = struct_deps_finder.structs_used + # print(struct_deps) + for j, pointing, point_name in zip(struct_deps, struct_deps_finder.is_pointer, + struct_deps_finder.pointer_names): + if j not in struct_dep_graph.nodes: + struct_dep_graph.add_node(j) + struct_dep_graph.add_edge(name, j, pointing=pointing, point_name=point_name) + prog.structures = Structures(structs_lister.structs) + prog.placeholders = self.placeholders + prog.placeholders_offsets = self.placeholders_offsets + + def suffix(self, node: FASTNode): + children = self.create_children(node) + name = children[0] + return ast_internal_classes.Suffix_Node(name=name) + + def data_ref(self, node: FASTNode): + children = self.create_children(node) + idx = len(children) - 1 + parent = children[idx - 1] + part_ref = children[idx] + part_ref.isStructMember = True + # parent.isStructMember=True + idx = idx - 1 + current = ast_internal_classes.Data_Ref_Node(parent_ref=parent, part_ref=part_ref, type="VOID") + + while idx > 0: + parent = children[idx - 1] + current = ast_internal_classes.Data_Ref_Node(parent_ref=parent, part_ref=current, type="VOID") + idx = idx - 1 + return current + + def end_type_stmt(self, node: FASTNode): + return None + + def access_stmt(self, node: FASTNode): + return None + + def generic_binding(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Generic_Binding_Node(name=children[1], binding=children[2]) + + def private_components_stmt(self, node: FASTNode): return None + def deallocate_stmt(self, node: FASTNode): + children = self.create_children(node) + line = get_line(node) + return ast_internal_classes.Deallocate_Stmt_Node(list=children[0].list, line_number=line) + + def proc_component_ref(self, node: FASTNode): + children = 
self.create_children(node) + return ast_internal_classes.Data_Ref_Node(parent_ref=children[0], part_ref=children[2], type="VOID") + + def component_spec(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Actual_Arg_Spec_Node(arg_name=children[0], arg=children[1], type="VOID") + + def allocate_object_list(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Allocate_Object_List_Node(list=children) + + def read_stmt(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Read_Stmt_Node(args=children[0], line_number=node.item.span) + + def close_stmt(self, node: FASTNode): + children = self.create_children(node) + if node.item is None: + line = '-1' + else: + line = node.item.span + return ast_internal_classes.Close_Stmt_Node(args=children[0], line_number=line) + + def io_control_spec(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.IO_Control_Spec_Node(name=children[0], args=children[1]) + + def io_control_spec_list(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.IO_Control_Spec_List_Node(list=children) + + def close_spec_list(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Close_Spec_List_Node(list=children) + + def close_spec(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Close_Spec_Node(name=children[0], args=children[1]) + + def stop_code(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Stop_Stmt_Node(code=node.string) + + def error_stop_stmt(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Error_Stmt_Node(error=children[1]) + + def pointer_object_list(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Pointer_Object_List_Node(list=children) 
+ + def nullify_stmt(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Nullify_Stmt_Node(list=children[1].list) + + def binding_name_list(self, node: FASTNode): + children = self.create_children(node) + return children + + def connect_spec(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Connect_Spec_Node(type=children[0], args=children[1]) + + def connect_spec_list(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Connect_Spec_List_Node(list=children) + + def open_stmt(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Open_Stmt_Node(args=children[1].list, line_number=node.item.span) + + def namelist_stmt(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Namelist_Stmt_Node(name=children[0][0], list=children[0][1]) + + def namelist_group_object_list(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Namelist_Group_Object_List_Node(list=children) + + def associate_stmt(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Associate_Stmt_Node(args=children[1].list) + + def association(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Association_Node(name=children[0], expr=children[2]) + + def association_list(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Association_List_Node(list=children) + + def subroutine_body(self, node: FASTNode): + children = self.create_children(node) + return children + + def function_reference(self, node: Function_Reference): + name, args = self.create_children(node) + line = get_line(node) + return ast_internal_classes.Call_Expr_Node(name=name, + args=args.args if args else [], + type="VOID", subroutine=False, + line_number=line) + + def 
end_associate_stmt(self, node: FASTNode): + return None + + def associate_construct(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Associate_Construct_Node(associate=children[0], body=children[1]) + + def enum_def_stmt(self, node: FASTNode): + children = self.create_children(node) + return None + + def enumerator(self, node: FASTNode): + children = self.create_children(node) + return children + + def enumerator_def_stmt(self, node: FASTNode): + children = self.create_children(node) + return children[1] + + def enumerator_list(self, node: FASTNode): + children = self.create_children(node) + return children + + def end_enum_stmt(self, node: FASTNode): + return None + + def enum_def(self, node: FASTNode): + children = self.create_children(node) + return children[1:-1] + + def exit_stmt(self, node: FASTNode): + line = get_line(node) + return ast_internal_classes.Exit_Node(line_number=line) + + def deferred_shape_spec(self, node: FASTNode): + return ast_internal_classes.Defer_Shape_Node() + + def deferred_shape_spec_list(self, node: FASTNode): + children = self.create_children(node) + return children + + def component_initialization(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Component_Initialization_Node(init=children[1]) + + def procedure_designator(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Procedure_Separator_Node(parent_ref=children[0], part_ref=children[2]) + + def derived_type_def(self, node: FASTNode): + children = self.create_children(node) + name = children[0].name + component_part = get_child(children, ast_internal_classes.Component_Part_Node) + procedure_part = get_child(children, ast_internal_classes.Bound_Procedures_Node) + from dace.frontend.fortran.ast_transforms import PartialRenameVar + if component_part is not None: + component_part = PartialRenameVar(oldname="__f2dace_A", 
newname="__f2dace_SA").visit(component_part) + component_part = PartialRenameVar(oldname="__f2dace_OA", newname="__f2dace_SOA").visit(component_part) + new_placeholder = {} + new_placeholder_offsets = {} + for k, v in self.placeholders.items(): + if "__f2dace_A" in k: + new_placeholder[k.replace("__f2dace_A", "__f2dace_SA")] = self.placeholders[k] + else: + new_placeholder[k] = self.placeholders[k] + self.placeholders = new_placeholder + for k, v in self.placeholders_offsets.items(): + if "__f2dace_OA" in k: + new_placeholder_offsets[k.replace("__f2dace_OA", "__f2dace_SOA")] = self.placeholders_offsets[k] + else: + new_placeholder_offsets[k] = self.placeholders_offsets[k] + self.placeholders_offsets = new_placeholder_offsets + return ast_internal_classes.Derived_Type_Def_Node(name=name, component_part=component_part, + procedure_part=procedure_part) + + def derived_type_stmt(self, node: FASTNode): + children = self.create_children(node) + name = get_child(children, ast_internal_classes.Type_Name_Node) + return ast_internal_classes.Derived_Type_Stmt_Node(name=name) + + def component_part(self, node: FASTNode): + children = self.create_children(node) + component_def_stmts = [i for i in children if isinstance(i, ast_internal_classes.Data_Component_Def_Stmt_Node)] + return ast_internal_classes.Component_Part_Node(component_def_stmts=component_def_stmts) + + def data_component_def_stmt(self, node: FASTNode): + children = self.type_declaration_stmt(node) + return ast_internal_classes.Data_Component_Def_Stmt_Node(vars=children) + + def component_decl_list(self, node: FASTNode): + children = self.create_children(node) + component_decls = [i for i in children if isinstance(i, ast_internal_classes.Component_Decl_Node)] + return ast_internal_classes.Component_Decl_List_Node(component_decls=component_decls) + + def component_decl(self, node: FASTNode): + children = self.create_children(node) + name = get_child(children, ast_internal_classes.Name_Node) + return 
ast_internal_classes.Component_Decl_Node(name=name) + def write_stmt(self, node: FASTNode): - children = self.create_children(node.children[1]) + # children=[] + # if node.children[0] is not None: + # children = self.create_children(node.children[0]) + # if node.children[1] is not None: + # children = self.create_children(node.children[1]) line = get_line(node) - return ast_internal_classes.Write_Stmt_Node(args=children, line_number=line) + return ast_internal_classes.Write_Stmt_Node(args=node.string, line_number=line) def program(self, node: FASTNode): children = self.create_children(node) - main_program = get_child(children, ast_internal_classes.Main_Program_Node) - function_definitions = [i for i in children if isinstance(i, ast_internal_classes.Function_Subprogram_Node)] - subroutine_definitions = [i for i in children if isinstance(i, ast_internal_classes.Subroutine_Subprogram_Node)] modules = [node for node in children if isinstance(node, ast_internal_classes.Module_Node)] - return ast_internal_classes.Program_Node(main_program=main_program, function_definitions=function_definitions, subroutine_definitions=subroutine_definitions, - modules=modules) + modules=modules, + module_declarations={}) def main_program(self, node: FASTNode): children = self.create_children(node) @@ -323,19 +686,35 @@ def program_stmt(self, node: FASTNode): return ast_internal_classes.Program_Stmt_Node(name=name, line_number=node.item.span) def subroutine_subprogram(self, node: FASTNode): + children = self.create_children(node) name = get_child(children, ast_internal_classes.Subroutine_Stmt_Node) specification_part = get_child(children, ast_internal_classes.Specification_Part_Node) execution_part = get_child(children, ast_internal_classes.Execution_Part_Node) + internal_subprogram_part = get_child(children, ast_internal_classes.Internal_Subprogram_Part_Node) return_type = ast_internal_classes.Void + + optional_args_count = 0 + if specification_part is not None: + for j in 
specification_part.specifications: + for k in j.vardecl: + if k.optional: + optional_args_count += 1 + mandatory_args_count = len(name.args) - optional_args_count + return ast_internal_classes.Subroutine_Subprogram_Node( name=name.name, args=name.args, + optional_args_count=optional_args_count, + mandatory_args_count=mandatory_args_count, specification_part=specification_part, execution_part=execution_part, + internal_subprogram_part=internal_subprogram_part, type=return_type, line_number=name.line_number, + elemental=name.elemental, + ) def end_program_stmt(self, node: FASTNode): @@ -344,16 +723,86 @@ def end_program_stmt(self, node: FASTNode): def only_list(self, node: FASTNode): children = self.create_children(node) names = [i for i in children if isinstance(i, ast_internal_classes.Name_Node)] - return ast_internal_classes.Only_List_Node(names=names) + renames = [i for i in children if isinstance(i, ast_internal_classes.Rename_Node)] + return ast_internal_classes.Only_List_Node(names=names, renames=renames) + + def prefix_stmt(self, prefix: Prefix): + if 'recursive' in prefix.string.lower(): + print("recursive found") + props: Dict[str, bool] = { + 'elemental': False, + 'recursive': False, + 'pure': False, + } + type = 'VOID' + for c in prefix.children: + if c.string.lower() in props.keys(): + props[c.string.lower()] = True + elif isinstance(c, Intrinsic_Type_Spec): + type = c.string + return ast_internal_classes.Prefix_Node(type=type, + elemental=props['elemental'], + recursive=props['recursive'], + pure=props['pure']) + + def function_subprogram(self, node: Function_Subprogram): + children = self.create_children(node) - def function_subprogram(self, node: FASTNode): - raise NotImplementedError("Function subprograms are not supported yet") + specification_part = get_child(children, ast_internal_classes.Specification_Part_Node) + execution_part = get_child(children, ast_internal_classes.Execution_Part_Node) + + name = get_child(children, 
ast_internal_classes.Function_Stmt_Node) + return_var: Name_Node = name.ret.name if name.ret else name.name + return_type: str = name.type + if name.type == 'VOID': + assert specification_part + var_decls: List[Var_Decl_Node] = [v + for c in specification_part.specifications if + isinstance(c, Decl_Stmt_Node) + for v in c.vardecl] + return_type = singular(v.type for v in var_decls if v.name == return_var.name) + + return ast_internal_classes.Function_Subprogram_Node( + name=name.name, + args=name.args, + ret=return_var, + specification_part=specification_part, + execution_part=execution_part, + type=return_type, + line_number=name.line_number, + elemental=name.elemental, + ) + + def function_stmt(self, node: Function_Stmt): + children = self.create_children(node) + name = get_child(children, ast_internal_classes.Name_Node) + args = get_child(children, ast_internal_classes.Arg_List_Node) + prefix = get_child(children, ast_internal_classes.Prefix_Node) + + type, elemental = (prefix.type, prefix.elemental) if prefix else ('VOID', False) + if prefix is not None and prefix.recursive: + print("recursive found " + name.name) + + ret = get_child(children, ast_internal_classes.Suffix_Node) + ret_args = args.args if args else [] + return ast_internal_classes.Function_Stmt_Node( + name=name, args=ret_args, line_number=node.item.span, ret=ret, elemental=elemental, type=ret) def subroutine_stmt(self, node: FASTNode): + # print(self.name_list) children = self.create_children(node) name = get_child(children, ast_internal_classes.Name_Node) args = get_child(children, ast_internal_classes.Arg_List_Node) - return ast_internal_classes.Subroutine_Stmt_Node(name=name, args=args.args, line_number=node.item.span) + prefix = get_child(children, ast_internal_classes.Prefix_Node) + elemental = prefix.elemental if prefix else False + if prefix is not None and prefix.recursive: + print("recursive found " + name.name) + if args is None: + ret_args = [] + else: + ret_args = args.args + return 
ast_internal_classes.Subroutine_Stmt_Node(name=name, args=ret_args, line_number=node.item.span, + elemental=elemental) def ac_value_list(self, node: FASTNode): children = self.create_children(node) @@ -362,20 +811,26 @@ def ac_value_list(self, node: FASTNode): def power_expr(self, node: FASTNode): children = self.create_children(node) line = get_line(node) - #child 0 is the base, child 2 is the exponent - #child 1 is "**" + # child 0 is the base, child 2 is the exponent + # child 1 is "**" return ast_internal_classes.Call_Expr_Node(name=ast_internal_classes.Name_Node(name="pow"), args=[children[0], children[2]], - line_number=line) + line_number=line, type="REAL", subroutine=False) def array_constructor(self, node: FASTNode): children = self.create_children(node) value_list = get_child(children, ast_internal_classes.Ac_Value_List_Node) - return ast_internal_classes.Array_Constructor_Node(value_list=value_list.value_list) + return ast_internal_classes.Array_Constructor_Node(value_list=value_list.value_list, type="VOID") def allocate_stmt(self, node: FASTNode): children = self.create_children(node) - return ast_internal_classes.Allocate_Stmt_Node(allocation_list=children[1]) + if isinstance(children[0], ast_internal_classes.Name_Node): + print(children[0].name) + if isinstance(children[0], ast_internal_classes.Data_Ref_Node): + print(children[0].parent_ref.name + "." 
+ children[0].part_ref.name) + + line = get_line(node) + return ast_internal_classes.Allocate_Stmt_Node(name=children[0], allocation_list=children[1], line_number=line) def allocation_list(self, node: FASTNode): children = self.create_children(node) @@ -383,9 +838,13 @@ def allocation_list(self, node: FASTNode): def allocation(self, node: FASTNode): children = self.create_children(node) - name = get_child(children, ast_internal_classes.Name_Node) + name = children[0] + # if isinstance(children[0], ast_internal_classes.Name_Node): + # print(children[0].name) + # if isinstance(children[0], ast_internal_classes.Data_Ref_Node): + # print(children[0].parent_ref.name+"."+children[0].part_ref.name) shape = get_child(children, ast_internal_classes.Allocate_Shape_Spec_List) - return ast_internal_classes.Allocation_Node(name=name, shape=shape) + return ast_internal_classes.Allocation_Node(name=children[0], shape=shape) def allocate_shape_spec_list(self, node: FASTNode): children = self.create_children(node) @@ -399,21 +858,25 @@ def allocate_shape_spec(self, node: FASTNode): def structure_constructor(self, node: FASTNode): children = self.create_children(node) + line = get_line(node) name = get_child(children, ast_internal_classes.Type_Name_Node) args = get_child(children, ast_internal_classes.Component_Spec_List_Node) - return ast_internal_classes.Structure_Constructor_Node(name=name, args=args.args, type=None) + if args == None: + ret_args = [] + else: + ret_args = args.args + return ast_internal_classes.Structure_Constructor_Node(name=name, args=ret_args, type=None, line_number=line) def intrinsic_function_reference(self, node: FASTNode): children = self.create_children(node) line = get_line(node) name = get_child(children, ast_internal_classes.Name_Node) args = get_child(children, ast_internal_classes.Arg_List_Node) - return self.intrinsic_handler.replace_function_reference(name, args, line) - def function_stmt(self, node: FASTNode): - raise NotImplementedError( - 
"Function statements are not supported yet - at least not if defined this way. Not encountered in code yet." - ) + if name is None: + return Name_Node(name="Error! " + node.children[0].string, type='VOID') + node = self.intrinsic_handler.replace_function_reference(name, args, line, self.symbols) + return node def end_subroutine_stmt(self, node: FASTNode): return node @@ -425,19 +888,67 @@ def parenthesis_expr(self, node: FASTNode): children = self.create_children(node) return ast_internal_classes.Parenthesis_Expr_Node(expr=children[1]) + def module_subprogram_part(self, node: FASTNode): + children = self.create_children(node) + function_definitions = [i for i in children if isinstance(i, ast_internal_classes.Function_Subprogram_Node)] + subroutine_definitions = [i for i in children if isinstance(i, ast_internal_classes.Subroutine_Subprogram_Node)] + return ast_internal_classes.Module_Subprogram_Part_Node(function_definitions=function_definitions, + subroutine_definitions=subroutine_definitions) + + def internal_subprogram_part(self, node: FASTNode): + children = self.create_children(node) + function_definitions = [i for i in children if isinstance(i, ast_internal_classes.Function_Subprogram_Node)] + subroutine_definitions = [i for i in children if isinstance(i, ast_internal_classes.Subroutine_Subprogram_Node)] + return ast_internal_classes.Internal_Subprogram_Part_Node(function_definitions=function_definitions, + subroutine_definitions=subroutine_definitions) + + def interface_block(self, node: FASTNode): + children = self.create_children(node) + + name = get_child(children, ast_internal_classes.Interface_Stmt_Node) + stmts = get_children(children, ast_internal_classes.Procedure_Statement_Node) + subroutines = [] + + for i in stmts: + + for child in i.namelists: + subroutines.extend(child.subroutines) + + # Ignore other implementations of an interface block with overloaded procedures + if name is None or len(subroutines) == 0: + return node + + return 
ast_internal_classes.Interface_Block_Node(name=name.name, subroutines=subroutines) + def module(self, node: FASTNode): children = self.create_children(node) + name = get_child(children, ast_internal_classes.Module_Stmt_Node) + module_subprogram_part = get_child(children, ast_internal_classes.Module_Subprogram_Part_Node) specification_part = get_child(children, ast_internal_classes.Specification_Part_Node) function_definitions = [i for i in children if isinstance(i, ast_internal_classes.Function_Subprogram_Node)] subroutine_definitions = [i for i in children if isinstance(i, ast_internal_classes.Subroutine_Subprogram_Node)] + + interface_blocks = {} + if specification_part is not None: + for iblock in specification_part.interface_blocks: + interface_blocks[iblock.name] = [x.name for x in iblock.subroutines] + + # add here to definitions + if module_subprogram_part is not None: + for i in module_subprogram_part.function_definitions: + function_definitions.append(i) + for i in module_subprogram_part.subroutine_definitions: + subroutine_definitions.append(i) + return ast_internal_classes.Module_Node( name=name.name, specification_part=specification_part, function_definitions=function_definitions, subroutine_definitions=subroutine_definitions, + interface_blocks=interface_blocks, line_number=name.line_number, ) @@ -453,7 +964,9 @@ def use_stmt(self, node: FASTNode): children = self.create_children(node) name = get_child(children, ast_internal_classes.Name_Node) only_list = get_child(children, ast_internal_classes.Only_List_Node) - return ast_internal_classes.Use_Stmt_Node(name=name.name, list=only_list.names) + if only_list is None: + return ast_internal_classes.Use_Stmt_Node(name=name.name, list=[], list_all=True) + return ast_internal_classes.Use_Stmt_Node(name=name.name, list=only_list.names, list_all=False) def implicit_part(self, node: FASTNode): return node @@ -472,7 +985,7 @@ def declaration_construct(self, node: FASTNode): return node def 
declaration_type_spec(self, node: FASTNode): - raise NotImplementedError("Declaration type spec is not supported yet") + # raise NotImplementedError("Declaration type spec is not supported yet") return node def assumed_shape_spec_list(self, node: FASTNode): @@ -485,56 +998,181 @@ def parse_shape_specification(self, dim: f03.Explicit_Shape_Spec, size: List[FAS # handle size definition if len(dim_expr) == 1: dim_expr = dim_expr[0] - #now to add the dimension to the size list after processing it if necessary + # now to add the dimension to the size list after processing it if necessary size.append(self.create_ast(dim_expr)) offset.append(1) + # Here we support arrays that have size declaration - with initial offset. elif len(dim_expr) == 2: # extract offets - for expr in dim_expr: - if not isinstance(expr, f03.Int_Literal_Constant): - raise TypeError("Array offsets must be constant expressions!") - offset.append(int(dim_expr[0].tostr())) + if isinstance(dim_expr[0], f03.Int_Literal_Constant): + # raise TypeError("Array offsets must be constant expressions!") + offset.append(int(dim_expr[0].tostr())) + else: + expr = self.create_ast(dim_expr[0]) + offset.append(expr) + + fortran_size = ast_internal_classes.BinOp_Node( + lval=self.create_ast(dim_expr[1]), + rval=self.create_ast(dim_expr[0]), + op="-", + type="INTEGER" + ) + size.append(ast_internal_classes.BinOp_Node( + lval=fortran_size, + rval=ast_internal_classes.Int_Literal_Node(value=str(1)), + op="+", + type="INTEGER") + ) + else: + raise TypeError("Array dimension must be at most two expressions") + + def assumed_array_shape(self, var, array_name: Optional[str], linenumber): + + # We do not know the array size. 
Thus, we insert symbols + # to mark its size + shape = get_children(var, "Assumed_Shape_Spec_List") - fortran_size = int(dim_expr[1].tostr()) - int(dim_expr[0].tostr()) + 1 - fortran_ast_size = f03.Int_Literal_Constant(str(fortran_size)) + if shape is None or len(shape) == 0: + shape = get_children(var, "Deferred_Shape_Spec_List") + + if shape is None: + return None, [] + + # this is based on structures observed in Fortran codes + # I don't know why the shape is an array + if len(shape) > 0: + dims_count = len(shape[0].items) + size = [] + vardecls = [] + + processed_array_names = [] + if array_name is not None: + if isinstance(array_name, str): + processed_array_names = [array_name] + else: + processed_array_names = [j.children[0].string for j in array_name] + else: + raise NotImplementedError("Assumed array shape not supported yet if array name missing") - size.append(self.create_ast(fortran_ast_size)) + sizes = [] + offsets = [] + for actual_array in processed_array_names: + + size = [] + offset = [] + for i in range(dims_count): + name = f'__f2dace_A_{actual_array}_d_{i}_s_{self.type_arbitrary_array_variable_count}' + offset_name = f'__f2dace_OA_{actual_array}_d_{i}_s_{self.type_arbitrary_array_variable_count}' + self.type_arbitrary_array_variable_count += 1 + self.placeholders[name] = [actual_array, i, self.type_arbitrary_array_variable_count] + self.placeholders_offsets[name] = [actual_array, i, self.type_arbitrary_array_variable_count] + + var = ast_internal_classes.Symbol_Decl_Node(name=name, + type='INTEGER', + alloc=False, + sizes=None, + offsets=None, + init=None, + kind=None, + line_number=linenumber) + var2 = ast_internal_classes.Symbol_Decl_Node(name=offset_name, + type='INTEGER', + alloc=False, + sizes=None, + offsets=None, + init=None, + kind=None, + line_number=linenumber) + size.append(ast_internal_classes.Name_Node(name=name)) + offset.append(ast_internal_classes.Name_Node(name=offset_name)) + + self.symbols[name] = None + vardecls.append(var) + 
vardecls.append(var2) + sizes.append(size) + offsets.append(offset) + + return sizes, vardecls, offsets else: - raise TypeError("Array dimension must be at most two expressions") + return None, [], None def type_declaration_stmt(self, node: FASTNode): - #decide if its a intrinsic variable type or a derived type + # decide if it's an intrinsic variable type or a derived type type_of_node = get_child(node, [f03.Intrinsic_Type_Spec, f03.Declaration_Type_Spec]) - + # if node.children[2].children[0].children[0].string.lower() =="BOUNDARY_MISSVAL".lower(): + # print("found boundary missval") if isinstance(type_of_node, f03.Intrinsic_Type_Spec): derived_type = False basetype = type_of_node.items[0] elif isinstance(type_of_node, f03.Declaration_Type_Spec): - derived_type = True - basetype = type_of_node.items[1].string + if type_of_node.items[0].lower() == "class": + basetype = "CLASS" + basetype = type_of_node.items[1].string + derived_type = True + else: + derived_type = True + basetype = type_of_node.items[1].string else: raise TypeError("Type of node must be either Intrinsic_Type_Spec or Declaration_Type_Spec") kind = None + size_later = False if len(type_of_node.items) >= 2: if type_of_node.items[1] is not None: if not derived_type: - kind = type_of_node.items[1].items[1].string - if self.symbols[kind] is not None: - if basetype == "REAL": - if self.symbols[kind].value == "8": - basetype = "REAL8" - elif basetype == "INTEGER": - if self.symbols[kind].value == "4": - basetype = "INTEGER" - else: - raise TypeError("Derived type not supported") + if basetype == "CLASS": + kind = "CLASS" + elif basetype == "CHARACTER": + kind = type_of_node.items[1].items[1].string.lower() + if kind == "*": + size_later = True else: - raise TypeError("Derived type not supported") - if derived_type: - raise TypeError("Derived type not supported") + if isinstance(type_of_node.items[1].items[1], f03.Int_Literal_Constant): + kind = type_of_node.items[1].items[1].string.lower() + if basetype 
== "REAL": + if kind == "8": + basetype = "REAL8" + else: + raise TypeError("Real kind not supported") + elif basetype == "INTEGER": + if kind == "4": + basetype = "INTEGER" + elif kind == "1": + # TODO: support for 1 byte integers /chars would be useful + basetype = "INTEGER" + + elif kind == "2": + # TODO: support for 2 byte integers would be useful + basetype = "INTEGER" + + elif kind == "8": + # TODO: support for 8 byte integers would be useful + basetype = "INTEGER" + else: + raise TypeError("Integer kind not supported") + else: + raise TypeError("Derived type not supported") + + else: + kind = type_of_node.items[1].items[1].string.lower() + if self.symbols[kind] is not None: + if basetype == "REAL": + while hasattr(self.symbols[kind], "name"): + kind = self.symbols[kind].name + if self.symbols[kind].value == "8": + basetype = "REAL8" + elif basetype == "INTEGER": + while hasattr(self.symbols[kind], "name"): + kind = self.symbols[kind].name + if self.symbols[kind].value == "4": + basetype = "INTEGER" + else: + raise TypeError("Derived type not supported") + + # if derived_type: + # raise TypeError("Derived type not supported") if not derived_type: testtype = self.types[basetype] else: @@ -544,26 +1182,43 @@ def type_declaration_stmt(self, node: FASTNode): # get the names of the variables being defined names_list = get_child(node, ["Entity_Decl_List", "Component_Decl_List"]) - #get the names out of the name list + # get the names out of the name list names = get_children(names_list, [f03.Entity_Decl, f03.Component_Decl]) - #get the attributes of the variables being defined + # get the attributes of the variables being defined # alloc relates to whether it is statically (False) or dynamically (True) allocated - # parameter means its a constant, so we should transform it into a symbol + # parameter means it's a constant, so we should transform it into a symbol attributes = get_children(node, "Attr_Spec_List") + comp_attributes = get_children(node, 
"Component_Attr_Spec_List") + if len(attributes) != 0 and len(comp_attributes) != 0: + raise TypeError("Attributes must be either in Attr_Spec_List or Component_Attr_Spec_List not both") alloc = False symbol = False + optional = False attr_size = None attr_offset = None - for i in attributes: + assumed_vardecls = [] + for i in attributes + comp_attributes: + if i.string.lower() == "allocatable": alloc = True if i.string.lower() == "parameter": symbol = True + if i.string.lower() == "pointer": + alloc = True + if i.string.lower() == "optional": + optional = True if isinstance(i, f08.Attr_Spec_List): + specification = get_children(i, "Attr_Spec") + for spec in specification: + if spec.string.lower() == "optional": + optional = True + if spec.string.lower() == "allocatable": + alloc = True + dimension_spec = get_children(i, "Dimension_Attr_Spec") if len(dimension_spec) == 0: continue @@ -571,68 +1226,138 @@ def type_declaration_stmt(self, node: FASTNode): attr_size = [] attr_offset = [] sizes = get_child(dimension_spec[0], ["Explicit_Shape_Spec_List"]) - - for shape_spec in get_children(sizes, [f03.Explicit_Shape_Spec]): - self.parse_shape_specification(shape_spec, attr_size, attr_offset) - vardecls = [] + if sizes is not None: + for shape_spec in get_children(sizes, [f03.Explicit_Shape_Spec]): + self.parse_shape_specification(shape_spec, attr_size, attr_offset) + # we expect a list of lists, where each element correspond to list of symbols for each array name + attr_size = [attr_size] * len(names) + attr_offset = [attr_offset] * len(names) + else: + attr_size, assumed_vardecls, attr_offset = self.assumed_array_shape(dimension_spec[0], names, + node.item.span) + + if attr_size is None: + raise RuntimeError("Couldn't parse the dimension attribute specification!") + + if isinstance(i, f08.Component_Attr_Spec_List): + + specification = get_children(i, "Component_Attr_Spec") + for spec in specification: + if spec.string.lower() == "optional": + optional = True + if 
spec.string.lower() == "allocatable": + alloc = True + + dimension_spec = get_children(i, "Dimension_Component_Attr_Spec") + if len(dimension_spec) == 0: + continue + + attr_size = [] + attr_offset = [] + sizes = get_child(dimension_spec[0], ["Explicit_Shape_Spec_List"]) + # if sizes is None: + # sizes = get_child(dimension_spec[0], ["Deferred_Shape_Spec_List"]) + + if sizes is not None: + for shape_spec in get_children(sizes, [f03.Explicit_Shape_Spec]): + self.parse_shape_specification(shape_spec, attr_size, attr_offset) + # we expect a list of lists, where each element correspond to list of symbols for each array name + attr_size = [attr_size] * len(names) + attr_offset = [attr_offset] * len(names) + else: + attr_size, assumed_vardecls, attr_offset = self.assumed_array_shape(dimension_spec[0], names, + node.item.span) + if attr_size is None: + raise RuntimeError("Couldn't parse the dimension attribute specification!") + + vardecls = [*assumed_vardecls] - for var in names: - #first handle dimensions + for idx, var in enumerate(names): + # print(self.name_list) + # first handle dimensions size = None offset = None var_components = self.create_children(var) array_sizes = get_children(var, "Explicit_Shape_Spec_List") actual_name = get_child(var_components, ast_internal_classes.Name_Node) + # if actual_name.name not in self.name_list: + # return if len(array_sizes) == 1: array_sizes = array_sizes[0] size = [] offset = [] for dim in array_sizes.children: - #sanity check + # sanity check if isinstance(dim, f03.Explicit_Shape_Spec): self.parse_shape_specification(dim, size, offset) - #handle initializiation + + # handle initializiation init = None initialization = get_children(var, f03.Initialization) if len(initialization) == 1: initialization = initialization[0] - #if there is an initialization, the actual expression is in the second child, with the first being the equals sign + # if there is an initialization, the actual expression is in the second child, with the 
first being the equals sign if len(initialization.children) < 2: raise ValueError("Initialization must have an expression") raw_init = initialization.children[1] init = self.create_ast(raw_init) - + else: + comp_init = get_children(var, "Component_Initialization") + if len(comp_init) == 1: + raw_init = comp_init[0].children[1] + init = self.create_ast(raw_init) + # if size_later: + # size.append(len(init)) + if testtype != "INTEGER": symbol = False if symbol == False: if attr_size is None: + + if size is None: + + size, assumed_vardecls, offset = self.assumed_array_shape(var, actual_name.name, node.item.span) + if size is None: + offset = None + else: + # only one array + size = size[0] + offset = offset[0] + # offset = [1] * len(size) + vardecls.extend(assumed_vardecls) + vardecls.append( ast_internal_classes.Var_Decl_Node(name=actual_name.name, - type=testtype, - alloc=alloc, - sizes=size, - offsets=offset, - kind=kind, - line_number=node.item.span)) + type=testtype, + alloc=alloc, + sizes=size, + offsets=offset, + kind=kind, + init=init, + optional=optional, + line_number=node.item.span)) else: vardecls.append( ast_internal_classes.Var_Decl_Node(name=actual_name.name, - type=testtype, - alloc=alloc, - sizes=attr_size, - offsets=attr_offset, - kind=kind, - line_number=node.item.span)) + type=testtype, + alloc=alloc, + sizes=attr_size[idx], + offsets=attr_offset[idx], + kind=kind, + init=init, + optional=optional, + line_number=node.item.span)) else: if size is None and attr_size is None: self.symbols[actual_name.name] = init vardecls.append( ast_internal_classes.Symbol_Decl_Node(name=actual_name.name, type=testtype, + sizes=None, alloc=alloc, init=init, - line_number=node.item.span)) + optional=optional)) elif attr_size is not None: vardecls.append( ast_internal_classes.Symbol_Array_Decl_Node(name=actual_name.name, @@ -642,6 +1367,7 @@ def type_declaration_stmt(self, node: FASTNode): offsets=attr_offset, kind=kind, init=init, + optional=optional, 
line_number=node.item.span)) else: vardecls.append( @@ -652,8 +1378,9 @@ def type_declaration_stmt(self, node: FASTNode): offsets=offset, kind=kind, init=init, + optional=optional, line_number=node.item.span)) - return ast_internal_classes.Decl_Stmt_Node(vardecl=vardecls, line_number=node.item.span) + return ast_internal_classes.Decl_Stmt_Node(vardecl=vardecls) def entity_decl(self, node: FASTNode): raise NotImplementedError("Entity decl is not supported yet") @@ -679,7 +1406,8 @@ def intent_spec(self, node: FASTNode): return node def access_spec(self, node: FASTNode): - raise NotImplementedError("Access spec is not supported yet") + print("access spec. Fix me") + # raise NotImplementedError("Access spec is not supported yet") return node def allocatable_stmt(self, node: FASTNode): @@ -691,7 +1419,8 @@ def asynchronous_stmt(self, node: FASTNode): return node def bind_stmt(self, node: FASTNode): - raise NotImplementedError("Bind stmt is not supported yet") + print("bind stmt. Fix me") + # raise NotImplementedError("Bind stmt is not supported yet") return node def common_stmt(self, node: FASTNode): @@ -699,7 +1428,8 @@ def common_stmt(self, node: FASTNode): return node def data_stmt(self, node: FASTNode): - raise NotImplementedError("Data stmt is not supported yet") + print("data stmt! 
fix me!") + # raise NotImplementedError("Data stmt is not supported yet") return node def dimension_stmt(self, node: FASTNode): @@ -723,6 +1453,7 @@ def parameter_stmt(self, node: FASTNode): return node def pointer_stmt(self, node: FASTNode): + raise NotImplementedError("Pointer stmt is not supported yet") return node def protected_stmt(self, node: FASTNode): @@ -741,7 +1472,7 @@ def volatile_stmt(self, node: FASTNode): return node def execution_part(self, node: FASTNode): - children = self.create_children(node) + children = [child for child in self.create_children(node) if child is not None] return ast_internal_classes.Execution_Part_Node(execution=children) def execution_part_construct(self, node: FASTNode): @@ -753,42 +1484,76 @@ def action_stmt(self, node: FASTNode): def level_2_expr(self, node: FASTNode): children = self.create_children(node) line = get_line(node) + if children[1] == "==": + type = "LOGICAL" + else: + type = "VOID" + if hasattr(children[0], "type"): + type = children[0].type if len(children) == 3: - return ast_internal_classes.BinOp_Node(lval=children[0], op=children[1], rval=children[2], line_number=line) + return ast_internal_classes.BinOp_Node(lval=children[0], op=children[1], rval=children[2], line_number=line, + type=type) else: - return ast_internal_classes.UnOp_Node(lval=children[1], op=children[0], line_number=line) + return ast_internal_classes.UnOp_Node(lval=children[1], op=children[0], line_number=line, + type=children[1].type) - def assignment_stmt(self, node: FASTNode): + def assignment_stmt(self, node: Assignment_Stmt): children = self.create_children(node) line = get_line(node) if len(children) == 3: - return ast_internal_classes.BinOp_Node(lval=children[0], op=children[1], rval=children[2], line_number=line) + return ast_internal_classes.BinOp_Node(lval=children[0], op=children[1], rval=children[2], line_number=line, + type=children[0].type) else: - return ast_internal_classes.UnOp_Node(lval=children[1], op=children[0], 
line_number=line) + return ast_internal_classes.UnOp_Node(lval=children[1], op=children[0], line_number=line, + type=children[1].type) def pointer_assignment_stmt(self, node: FASTNode): - return node + children = self.create_children(node) + line = get_line(node) + return ast_internal_classes.Pointer_Assignment_Stmt_Node(name_pointer=children[0], + name_target=children[2], + line_number=line) def where_stmt(self, node: FASTNode): return node - def forall_stmt(self, node: FASTNode): - return node - def where_construct(self, node: FASTNode): - return node + children = self.create_children(node) + line = children[0].line_number + cond = children[0] + body = children[1] + current = 2 + body_else = None + elifs_cond = [] + elifs_body = [] + while children[current] is not None: + if isinstance(children[current], str) and children[current].lower() == "elsewhere": + body_else = children[current + 1] + current += 2 + else: + elifs_cond.append(children[current]) + elifs_body.append(children[current + 1]) + current += 2 + return ast_internal_classes.Where_Construct_Node(body=body, cond=cond, body_else=body_else, + elifs_cond=elifs_cond, elifs_body=elifs_cond, line_number=line) def where_construct_stmt(self, node: FASTNode): - return node + children = self.create_children(node) + return children[0] def masked_elsewhere_stmt(self, node: FASTNode): - return node + children = self.create_children(node) + return children[0] def elsewhere_stmt(self, node: FASTNode): - return node + children = self.create_children(node) + return children[0] def end_where_stmt(self, node: FASTNode): + return None + + def forall_stmt(self, node: FASTNode): return node def forall_construct(self, node: FASTNode): @@ -814,6 +1579,8 @@ def if_stmt(self, node: FASTNode): line = get_line(node) cond = children[0] body = children[1:] + # !THIS IS HACK + body = [i for i in body if i is not None] return ast_internal_classes.If_Stmt_Node(cond=cond, body=ast_internal_classes.Execution_Part_Node(execution=body), 
body_else=ast_internal_classes.Execution_Part_Node(execution=[]), @@ -831,6 +1598,8 @@ def if_construct(self, node: FASTNode): toplevelIf = ast_internal_classes.If_Stmt_Node(cond=cond, line_number=line) currentIf = toplevelIf for i in children[1:-1]: + if i is None: + continue if isinstance(i, ast_internal_classes.Else_If_Stmt_Node): newif = ast_internal_classes.If_Stmt_Node(cond=i.cond, line_number=i.line_number) currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body) @@ -844,6 +1613,7 @@ def if_construct(self, node: FASTNode): if else_mode: body_else.append(i) else: + body.append(i) currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body) currentIf.body_else = ast_internal_classes.Execution_Part_Node(execution=body_else) @@ -868,13 +1638,95 @@ def end_if_stmt(self, node: FASTNode): return node def case_construct(self, node: FASTNode): - return node + children = self.create_children(node) + cond_start = children[0] + cond_end = children[1] + body = [] + body_else = [] + else_mode = False + line = get_line(node) + if line is None: + line = "Unknown:TODO" + cond = ast_internal_classes.BinOp_Node(op=cond_end.op[0], lval=cond_start, rval=cond_end.cond[0], + line_number=line) + for j in range(1, len(cond_end.op)): + cond_add = ast_internal_classes.BinOp_Node(op=cond_end.op[j], lval=cond_start, rval=cond_end.cond[j], + line_number=line) + cond = ast_internal_classes.BinOp_Node(op=".OR.", lval=cond, rval=cond_add, line_number=line) + + toplevelIf = ast_internal_classes.If_Stmt_Node(cond=cond, line_number=line) + currentIf = toplevelIf + for i in children[2:-1]: + if i is None: + continue + if isinstance(i, ast_internal_classes.Case_Cond_Node): + cond = ast_internal_classes.BinOp_Node(op=i.op[0], lval=cond_start, rval=i.cond[0], line_number=line) + for j in range(1, len(i.op)): + cond_add = ast_internal_classes.BinOp_Node(op=i.op[j], lval=cond_start, rval=i.cond[j], + line_number=line) + cond = ast_internal_classes.BinOp_Node(op=".OR.", 
lval=cond, rval=cond_add, line_number=line) + + newif = ast_internal_classes.If_Stmt_Node(cond=cond, line_number=line) + currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body) + currentIf.body_else = ast_internal_classes.Execution_Part_Node(execution=[newif]) + currentIf = newif + body = [] + continue + if isinstance(i, str) and i == "__default__": + else_mode = True + continue + if else_mode: + body_else.append(i) + else: + + body.append(i) + currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body) + currentIf.body_else = ast_internal_classes.Execution_Part_Node(execution=body_else) + return toplevelIf def select_case_stmt(self, node: FASTNode): - return node + children = self.create_children(node) + if len(children) != 1: + raise ValueError("CASE should have only 1 child") + return children[0] def case_stmt(self, node: FASTNode): - return node + children = self.create_children(node) + children = [i for i in children if i is not None] + if len(children) == 1: + return children[0] + elif len(children) == 0: + return "__default__" + else: + raise ValueError("Can't parse case statement") + + def case_selector(self, node: FASTNode): + children = self.create_children(node) + if len(children) == 1: + if children[0] is None: + return None + returns = ast_internal_classes.Case_Cond_Node(op=[], cond=[]) + + for i in children[0]: + returns.op.append(i[0]) + returns.cond.append(i[1]) + return returns + else: + raise ValueError("Can't parse case selector") + + def case_value_range_list(self, node: FASTNode): + children = self.create_children(node) + if len(children) == 1: + return [[".EQ.", children[0]]] + if len(children) == 2: + return [[".EQ.", children[0]], [".EQ.", children[1]]] + else: + retlist = [] + for i in children: + retlist.append([".EQ.", i]) + return retlist + # else: + # raise ValueError("Can't parse case range list") def end_select_stmt(self, node: FASTNode): return node @@ -888,6 +1740,12 @@ def label_do_stmt(self, node: 
FASTNode): def nonlabel_do_stmt(self, node: FASTNode): children = self.create_children(node) loop_control = get_child(children, ast_internal_classes.Loop_Control_Node) + if loop_control is None: + if node.string == "DO": + return ast_internal_classes.While_True_Control(name=node.item.name, line_number=node.item.span) + else: + while_control = get_child(children, ast_internal_classes.While_Control) + return ast_internal_classes.While_Control(cond=while_control.cond, line_number=node.item.span) return ast_internal_classes.Nonlabel_Do_Stmt_Node(iter=loop_control.iter, cond=loop_control.cond, init=loop_control.init, @@ -896,23 +1754,44 @@ def nonlabel_do_stmt(self, node: FASTNode): def end_do_stmt(self, node: FASTNode): return node - def interface_block(self, node: FASTNode): - return node - def interface_stmt(self, node: FASTNode): - return node + children = self.create_children(node) + name = get_child(children, ast_internal_classes.Name_Node) + if name is not None: + return ast_internal_classes.Interface_Stmt_Node(name=name.name) + else: + return node def end_interface_stmt(self, node: FASTNode): return node + def procedure_name_list(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Procedure_Name_List_Node(subroutines=children) + + def procedure_stmt(self, node: FASTNode): + # ignore the procedure statement - just return the name list + children = self.create_children(node) + namelists = get_children(children, ast_internal_classes.Procedure_Name_List_Node) + if namelists is not None: + return ast_internal_classes.Procedure_Statement_Node(namelists=namelists) + else: + return node + def generic_spec(self, node: FASTNode): + children = self.create_children(node) return node def procedure_declaration_stmt(self, node: FASTNode): return node + def specific_binding(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Specific_Binding_Node(name=children[3], args=children[0:2] + 
[children[4]]) + def type_bound_procedure_part(self, node: FASTNode): - return node + children = self.create_children(node) + return ast_internal_classes.Bound_Procedures_Node(procedures=children[1:]) def contains_stmt(self, node: FASTNode): return node @@ -920,14 +1799,33 @@ def contains_stmt(self, node: FASTNode): def call_stmt(self, node: FASTNode): children = self.create_children(node) name = get_child(children, ast_internal_classes.Name_Node) + arg_addition = None + if name is None: + proc_ref = get_child(children, ast_internal_classes.Procedure_Separator_Node) + name = proc_ref.part_ref + arg_addition = proc_ref.parent_ref + args = get_child(children, ast_internal_classes.Arg_List_Node) - return ast_internal_classes.Call_Expr_Node(name=name, args=args.args, type=None, line_number=node.item.span) + if args is None: + ret_args = [] + else: + ret_args = args.args + if arg_addition is not None: + ret_args.insert(0, arg_addition) + line_number = get_line(node) + # if node.item is None: + # line_number = 42 + # else: + # line_number = node.item.span + return ast_internal_classes.Call_Expr_Node(name=name, args=ret_args, type="VOID", subroutine=True, + line_number=line_number) def return_stmt(self, node: FASTNode): - return node + return None def stop_stmt(self, node: FASTNode): - return node + return ast_internal_classes.Call_Expr_Node(name=ast_internal_classes.Name_Node(name="__dace_exit"), args=[], + type="VOID", subroutine=False, line_number=node.item.span) def dummy_arg_list(self, node: FASTNode): children = self.create_children(node) @@ -945,15 +1843,14 @@ def part_ref(self, node: FASTNode): line = get_line(node) name = get_child(children, ast_internal_classes.Name_Node) args = get_child(children, ast_internal_classes.Section_Subscript_List_Node) - return ast_internal_classes.Call_Expr_Node( - name=name, - args=args.list, - line=line, - ) + return ast_internal_classes.Array_Subscript_Node(name=name, type="VOID", indices=args.list, + line_number=line) def 
loop_control(self, node: FASTNode): children = self.create_children(node) - #Structure of loop control is: + # Structure of loop control is: + if children[1] is None: + return ast_internal_classes.While_Control(cond=children[0], line_number=node.parent.item.span) # child[1]. Loop control variable # child[1][0] Loop start # child[1][1] Loop end @@ -964,23 +1861,40 @@ def loop_control(self, node: FASTNode): loop_step = children[1][1][2] else: loop_step = ast_internal_classes.Int_Literal_Node(value="1") - init_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op="=", rval=loop_start) + init_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op="=", rval=loop_start, type="INTEGER") if isinstance(loop_step, ast_internal_classes.UnOp_Node): if loop_step.op == "-": - cond_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op=">=", rval=loop_end) + cond_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op=">=", rval=loop_end, + type="INTEGER") else: - cond_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op="<=", rval=loop_end) + cond_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op="<=", rval=loop_end, type="INTEGER") iter_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op="=", rval=ast_internal_classes.BinOp_Node(lval=iteration_variable, op="+", - rval=loop_step)) + rval=loop_step, + type="INTEGER"), + type="INTEGER") return ast_internal_classes.Loop_Control_Node(init=init_expr, cond=cond_expr, iter=iter_expr) def block_nonlabel_do_construct(self, node: FASTNode): children = self.create_children(node) do = get_child(children, ast_internal_classes.Nonlabel_Do_Stmt_Node) body = children[1:-1] + body = [i for i in body if i is not None] + if do is None: + while_true_header = get_child(children, ast_internal_classes.While_True_Control) + if while_true_header is not None: + return ast_internal_classes.While_Stmt_Node(name=while_true_header.name, + 
body=ast_internal_classes.Execution_Part_Node( + execution=body), + line_number=while_true_header.line_number) + while_header = get_child(children, ast_internal_classes.While_Control) + if while_header is not None: + return ast_internal_classes.While_Stmt_Node(cond=while_header.cond, + body=ast_internal_classes.Execution_Part_Node( + execution=body), + line_number=while_header.line_number) return ast_internal_classes.For_Stmt_Node(init=do.init, cond=do.cond, iter=do.iter, @@ -998,31 +1912,45 @@ def section_subscript_list(self, node: FASTNode): return ast_internal_classes.Section_Subscript_List_Node(list=children) def specification_part(self, node: FASTNode): - #TODO this can be refactored to consider more fortran declaration options. Currently limited to what is encountered in code. + + # TODO this can be refactored to consider more fortran declaration options. Currently limited to what is encountered in code. others = [self.create_ast(i) for i in node.children if not isinstance(i, f08.Type_Declaration_Stmt)] decls = [self.create_ast(i) for i in node.children if isinstance(i, f08.Type_Declaration_Stmt)] - + enums = [self.create_ast(i) for i in node.children if isinstance(i, f03.Enum_Def)] + # decls = list(filter(lambda x: x is not None, decls)) uses = [self.create_ast(i) for i in node.children if isinstance(i, f03.Use_Stmt)] tmp = [self.create_ast(i) for i in node.children] - typedecls = [i for i in tmp if isinstance(i, ast_internal_classes.Type_Decl_Node)] + typedecls = [ + i for i in tmp if isinstance(i, ast_internal_classes.Type_Decl_Node) + or isinstance(i, ast_internal_classes.Derived_Type_Def_Node) + ] symbols = [] + iblocks = [] for i in others: if isinstance(i, list): symbols.extend(j for j in i if isinstance(j, ast_internal_classes.Symbol_Array_Decl_Node)) if isinstance(i, ast_internal_classes.Decl_Stmt_Node): symbols.extend(j for j in i.vardecl if isinstance(j, ast_internal_classes.Symbol_Array_Decl_Node)) + if isinstance(i, 
ast_internal_classes.Interface_Block_Node): + iblocks.append(i) + for i in decls: if isinstance(i, list): symbols.extend(j for j in i if isinstance(j, ast_internal_classes.Symbol_Array_Decl_Node)) + symbols.extend(j for j in i if isinstance(j, ast_internal_classes.Symbol_Decl_Node)) if isinstance(i, ast_internal_classes.Decl_Stmt_Node): symbols.extend(j for j in i.vardecl if isinstance(j, ast_internal_classes.Symbol_Array_Decl_Node)) + symbols.extend(j for j in i.vardecl if isinstance(j, ast_internal_classes.Symbol_Decl_Node)) names_filtered = [] for j in symbols: for i in decls: names_filtered.extend(ii.name for ii in i.vardecl if j.name == ii.name) decl_filtered = [] + for i in decls: + if i is None: + continue # NOTE: Assignment/named expressions (walrus operator) works with Python 3.8 and later. # if vardecl_filtered := [ii for ii in i.vardecl if ii.name not in names_filtered]: vardecl_filtered = [ii for ii in i.vardecl if ii.name not in names_filtered] @@ -1030,8 +1958,10 @@ def specification_part(self, node: FASTNode): decl_filtered.append(ast_internal_classes.Decl_Stmt_Node(vardecl=vardecl_filtered)) return ast_internal_classes.Specification_Part_Node(specifications=decl_filtered, symbols=symbols, + interface_blocks=iblocks, uses=uses, - typedecls=typedecls) + typedecls=typedecls, + enums=enums) def intrinsic_type_spec(self, node: FASTNode): return node @@ -1039,18 +1969,43 @@ def intrinsic_type_spec(self, node: FASTNode): def entity_decl_list(self, node: FASTNode): return node - def int_literal_constant(self, node: FASTNode): - return ast_internal_classes.Int_Literal_Node(value=node.string) + def int_literal_constant(self, node: Union[Int_Literal_Constant, Signed_Int_Literal_Constant]): + value = node.string + if value.find("_") != -1: + x = value.split("_") + value = x[0] + return ast_internal_classes.Int_Literal_Node(value=value, type="INTEGER") - def logical_literal_constant(self, node: FASTNode): + def hex_constant(self, node: Hex_Constant): + return 
ast_internal_classes.Int_Literal_Node(value=str(int(node.string[2:-1], 16)), type="INTEGER") + + def logical_literal_constant(self, node: Logical_Literal_Constant): if node.string in [".TRUE.", ".true.", ".True."]: return ast_internal_classes.Bool_Literal_Node(value="True") if node.string in [".FALSE.", ".false.", ".False."]: return ast_internal_classes.Bool_Literal_Node(value="False") raise ValueError("Unknown logical literal constant") - def real_literal_constant(self, node: FASTNode): - return ast_internal_classes.Real_Literal_Node(value=node.string) + def real_literal_constant(self, node: Union[Real_Literal_Constant, Signed_Real_Literal_Constant]): + value = node.children[0].lower() + if len(node.children) == 2 and node.children[1] is not None and node.children[1].lower() == "wp": + return ast_internal_classes.Double_Literal_Node(value=value, type="DOUBLE") + if value.find("_") != -1: + x = value.split("_") + value = x[0] + print(x[1]) + if x[1] == "wp": + return ast_internal_classes.Double_Literal_Node(value=value, type="DOUBLE") + return ast_internal_classes.Real_Literal_Node(value=value, type="REAL") + + def char_literal_constant(self, node: FASTNode): + return ast_internal_classes.Char_Literal_Node(value=node.string, type="CHAR") + + def actual_arg_spec(self, node: FASTNode): + children = self.create_children(node) + if len(children) != 2: + raise ValueError("Actual arg spec must have two children") + return ast_internal_classes.Actual_Arg_Spec_Node(arg_name=children[0], arg=children[1], type="VOID") def actual_arg_spec_list(self, node: FASTNode): children = self.create_children(node) @@ -1060,10 +2015,14 @@ def initialization(self, node: FASTNode): return node def name(self, node: FASTNode): - return ast_internal_classes.Name_Node(name=node.string) + return ast_internal_classes.Name_Node(name=node.string.lower(), type="VOID") + + def rename(self, node: FASTNode): + return ast_internal_classes.Rename_Node(oldname=node.children[2].string.lower(), + 
newname=node.children[1].string.lower()) def type_name(self, node: FASTNode): - return ast_internal_classes.Type_Name_Node(name=node.string) + return ast_internal_classes.Type_Name_Node(name=node.string.lower()) def tuple_node(self, node: FASTNode): return node diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py new file mode 100644 index 0000000000..60ab6c9034 --- /dev/null +++ b/dace/frontend/fortran/ast_desugaring.py @@ -0,0 +1,2098 @@ +import math +import operator +import re +import sys +from dataclasses import dataclass +from typing import Union, Tuple, Dict, Optional, List, Iterable, Set, Type, Any + +import networkx as nx +import numpy as np +from fparser.api import get_reader +from fparser.two.Fortran2003 import Program_Stmt, Module_Stmt, Function_Stmt, Subroutine_Stmt, Derived_Type_Stmt, \ + Component_Decl, Entity_Decl, Specific_Binding, Generic_Binding, Interface_Stmt, Main_Program, Subroutine_Subprogram, \ + Function_Subprogram, Name, Program, Use_Stmt, Rename, Part_Ref, Data_Ref, Intrinsic_Type_Spec, \ + Declaration_Type_Spec, Initialization, Intrinsic_Function_Reference, Int_Literal_Constant, Length_Selector, \ + Kind_Selector, Derived_Type_Def, Type_Name, Module, Function_Reference, Structure_Constructor, Call_Stmt, \ + Intrinsic_Name, Access_Stmt, Enum_Def, Expr, Enumerator, Real_Literal_Constant, Signed_Real_Literal_Constant, \ + Signed_Int_Literal_Constant, Char_Literal_Constant, Logical_Literal_Constant, Section_Subscript, Actual_Arg_Spec, \ + Level_2_Unary_Expr, And_Operand, Parenthesis, Level_2_Expr, Level_3_Expr, Array_Constructor, Execution_Part, \ + Specification_Part, Interface_Block, Association, Procedure_Designator, Type_Bound_Procedure_Part, \ + Associate_Construct, Subscript_Triplet, End_Function_Stmt, End_Subroutine_Stmt, Module_Subprogram_Part, \ + Enumerator_List, Actual_Arg_Spec_List, Only_List, Section_Subscript_List, Char_Selector, Data_Pointer_Object, \ + Explicit_Shape_Spec, 
Component_Initialization, Subroutine_Body, Function_Body, If_Then_Stmt, Else_If_Stmt, \ + Else_Stmt, If_Construct, Level_4_Expr, Level_5_Expr, Hex_Constant, Add_Operand, Mult_Operand, Assignment_Stmt, \ + Loop_Control +from fparser.two.Fortran2008 import Procedure_Stmt, Type_Declaration_Stmt +from fparser.two.utils import Base, walk, BinaryOpBase, UnaryOpBase + +from dace.frontend.fortran.ast_utils import singular, children_of_type, atmost_one + +ENTRY_POINT_OBJECT_TYPES = Union[Main_Program, Subroutine_Subprogram, Function_Subprogram] +SCOPE_OBJECT_TYPES = Union[ + Main_Program, Module, Function_Subprogram, Subroutine_Subprogram, Derived_Type_Def, Interface_Block, + Subroutine_Body, Function_Body] +NAMED_STMTS_OF_INTEREST_TYPES = Union[ + Program_Stmt, Module_Stmt, Function_Stmt, Subroutine_Stmt, Derived_Type_Stmt, Component_Decl, Entity_Decl, + Specific_Binding, Generic_Binding, Interface_Stmt] +SPEC = Tuple[str, ...] +SPEC_TABLE = Dict[SPEC, NAMED_STMTS_OF_INTEREST_TYPES] + + +class TYPE_SPEC: + NO_ATTRS = '' + + def __init__(self, + spec: Union[str, SPEC], + attrs: str = NO_ATTRS): + if isinstance(spec, str): + spec = (spec,) + self.spec: SPEC = spec + self.shape: Tuple[str, ...] 
= self._parse_shape(attrs) + self.optional: bool = 'OPTIONAL' in attrs + self.inp: bool = 'INTENT(IN)' in attrs or 'INTENT(INOUT)' in attrs + self.out: bool = 'INTENT(OUT)' in attrs or 'INTENT(INOUT)' in attrs + self.const: bool = 'PARAMETER' in attrs + self.keyword: Optional[str] = None + + @staticmethod + def _parse_shape(attrs: str) -> Tuple[str, ...]: + if 'DIMENSION' not in attrs: + return tuple() + dims: re.Match = re.search(r'DIMENSION\(([^)]*)\)', attrs, re.IGNORECASE) + assert dims + dims: str = dims.group(1) + return tuple(p.strip().lower() for p in dims.split(',')) + + def __repr__(self): + attrs = [] + if self.shape: + attrs.append(f"shape={self.shape}") + if self.optional: + attrs.append("optional") + if not attrs: + return f"{self.spec}" + return f"{self.spec}[{' | '.join(attrs)}]" + + +def find_name_of_stmt(node: NAMED_STMTS_OF_INTEREST_TYPES) -> Optional[str]: + """Find the name of the statement if it has one. For anonymous blocks, return `None`.""" + if isinstance(node, Specific_Binding): + # Ref: https://github.com/stfc/fparser/blob/8c870f84edbf1a24dfbc886e2f7226d1b158d50b/src/fparser/two/Fortran2003.py#L2504 + _, _, _, bname, _ = node.children + name = bname + elif isinstance(node, Interface_Stmt): + name, = node.children + else: + # TODO: Test out other type specific ways of finding names. + name = singular(children_of_type(node, Name)) + if name: + assert isinstance(name, Name) + name = name.string + return name + + +def find_name_of_node(node: Base) -> Optional[str]: + """Find the name of the general node if it has one. 
def find_scope_ancestor(node: Base) -> Optional[SCOPE_OBJECT_TYPES]:
    """Return the nearest ancestor of `node` that is a scope-defining object, or `None`."""
    anc = node.parent
    while anc and not isinstance(anc, SCOPE_OBJECT_TYPES):
        anc = anc.parent
    return anc


def find_named_ancestor(node: Base) -> Optional[NAMED_STMTS_OF_INTEREST_TYPES]:
    """Return the naming statement of the nearest scope ancestor of `node`, or `None`."""
    anc = find_scope_ancestor(node)
    if not anc:
        return None
    return atmost_one(children_of_type(anc, NAMED_STMTS_OF_INTEREST_TYPES))


def lineage(anc: Base, des: Base) -> Optional[Tuple[Base, ...]]:
    """Return the chain of nodes from `anc` down to `des` (inclusive), or `None` if `des`
    is not a descendant of `anc`."""
    if anc == des:
        return (anc,)
    if not des.parent:
        return None
    lin = lineage(anc, des.parent)
    if not lin:
        return None
    return lin + (des,)


def search_scope_spec(node: Base) -> Optional[SPEC]:
    """Return the spec (identifier tuple) of the scope that `node` lives in, or `None`
    when `node` does not belong to an identifiable scope (e.g. it is a call keyword)."""
    scope = find_scope_ancestor(node)
    if not scope:
        return None
    lin = lineage(scope, node)
    assert lin
    par = node.parent
    # TODO: How many other such cases can there be?
    if (isinstance(scope, Derived_Type_Def)
            and any(
                isinstance(x, (Explicit_Shape_Spec, Component_Initialization, Kind_Selector, Char_Selector))
                for x in lin)):
        # We're using `node` to describe a shape, an initialization etc. inside a type def.
        # So, `node` must have been defined earlier, i.e. in the scope surrounding the type def.
        return search_scope_spec(scope)
    elif isinstance(par, Actual_Arg_Spec):
        kw, _ = par.children
        if kw == node:
            # We're describing a keyword, which is not really an identifiable object.
            return None
    stmt = singular(children_of_type(scope, NAMED_STMTS_OF_INTEREST_TYPES))
    if not find_name_of_stmt(stmt):
        # If this is an anonymous object, the scope has to be outside.
        return search_scope_spec(scope.parent)
    return ident_spec(stmt)


def find_scope_spec(node: Base) -> SPEC:
    """Like `search_scope_spec()`, but a missing scope is considered an error."""
    spec = search_scope_spec(node)
    assert spec, f"cannot find scope for: ```\n{node.tofortran()}```"
    return spec


def ident_spec(node: NAMED_STMTS_OF_INTEREST_TYPES) -> SPEC:
    """Construct the unique identifier tuple (spec) of `node` by walking its named ancestors."""

    def _ident_spec(_node: NAMED_STMTS_OF_INTEREST_TYPES) -> SPEC:
        """
        Construct a list of identifier strings that can uniquely determine it through the entire AST.
        """
        ident_base = (find_name_of_stmt(_node),)
        # Find the next named ancestor.
        anc = find_named_ancestor(_node.parent)
        if not anc:
            return ident_base
        assert isinstance(anc, NAMED_STMTS_OF_INTEREST_TYPES)
        return _ident_spec(anc) + ident_base

    spec = _ident_spec(node)
    # The last part of the spec cannot be nothing, because we cannot refer to the anonymous blocks.
    assert spec and spec[-1]
    # For the rest, the anonymous blocks puts their content onto their parents.
    spec = tuple(c for c in spec if c)
    return spec


def search_local_alias_spec(node: Name) -> Optional[SPEC]:
    """Return the spec under which this `Name` occurrence could be aliased in its local scope,
    or `None` when the occurrence is not an aliasable identifier (component refs, reserved
    words inside kind/char selectors, call keywords)."""
    name, par = node.string, node.parent
    scope_spec = search_scope_spec(node)
    if scope_spec is None:
        return None
    if isinstance(par, (Part_Ref, Data_Ref, Data_Pointer_Object)):
        # If we are in a data-ref then we need to get to the root.
        while isinstance(par.parent, Data_Ref):
            par = par.parent
        while isinstance(par, Data_Ref):
            # TODO: Add ref.
            par, _ = par.children[0], par.children[1:]
        if isinstance(par, (Part_Ref, Data_Pointer_Object)):
            # TODO: Add ref.
            par, _ = par.children[0], par.children[1:]
        assert isinstance(par, Name)
        if par != node:
            # Components do not really have a local alias.
            return None
    elif isinstance(par, Kind_Selector):
        # Reserved name in this context.
        if name.upper() == 'KIND':
            return None
    elif isinstance(par, Char_Selector):
        # Reserved name in this context.
        if name.upper() in {'KIND', 'LEN'}:
            return None
    elif isinstance(par, Actual_Arg_Spec):
        # Keywords cannot be aliased.
        kw, _ = par.children
        if kw == node:
            return None
    return scope_spec + (name,)


def search_real_local_alias_spec_from_spec(loc: SPEC, alias_map: SPEC_TABLE) -> Optional[SPEC]:
    """Resolve `loc` against `alias_map`, retrying in successively outer scopes; `None` if absent."""
    while len(loc) > 1 and loc not in alias_map:
        # The name is not immediately available in the current scope, but may be it is in the parent's scope.
        loc = loc[:-2] + (loc[-1],)
    return loc if loc in alias_map else None


def search_real_local_alias_spec(node: Name, alias_map: SPEC_TABLE) -> Optional[SPEC]:
    """Resolve the `Name` occurrence to an existing entry of `alias_map`, or `None`."""
    loc = search_local_alias_spec(node)
    if not loc:
        return None
    return search_real_local_alias_spec_from_spec(loc, alias_map)


def identifier_specs(ast: Program) -> SPEC_TABLE:
    """
    Maps each identifier of interest in `ast` to its associated node that defines it.
    """
    ident_map: SPEC_TABLE = {}
    for stmt in walk(ast, NAMED_STMTS_OF_INTEREST_TYPES):
        assert isinstance(stmt, NAMED_STMTS_OF_INTEREST_TYPES)
        if isinstance(stmt, Interface_Stmt) and not find_name_of_stmt(stmt):
            # There can be anonymous blocks, e.g., interface blocks, which cannot be identified.
            continue
        spec = ident_spec(stmt)
        assert spec not in ident_map, f"{spec} / {stmt.parent.parent.parent.parent} / {ident_map[spec].parent.parent.parent.parent}"
        ident_map[spec] = stmt
    return ident_map


def alias_specs(ast: Program):
    """
    Maps each "alias-type" identifier of interest in `ast` to its associated node that defines it.
    Builds on `identifier_specs()` and additionally resolves USE statements (with or without
    ONLY lists, including renames) to the definitions in the used modules.
    """
    ident_map = identifier_specs(ast)
    alias_map: SPEC_TABLE = {k: v for k, v in ident_map.items()}

    for stmt in walk(ast, Use_Stmt):
        mod_name = singular(children_of_type(stmt, Name)).string
        mod_spec = (mod_name,)

        scope_spec = find_scope_spec(stmt)
        use_spec = scope_spec + (mod_name,)

        assert mod_spec in ident_map
        # The module's name cannot be used as an identifier in this scope anymore, so just point to the module.
        alias_map[use_spec] = ident_map[mod_spec]

        olist = atmost_one(children_of_type(stmt, Only_List))
        if not olist:
            # If there is no only list, all the top level (public) symbols are considered aliased.
            alias_updates: SPEC_TABLE = {}
            for k, v in alias_map.items():
                if len(k) != len(mod_spec) + 1 or k[:len(mod_spec)] != mod_spec:
                    continue
                alias_spec = scope_spec + k[-1:]
                alias_updates[alias_spec] = v
            alias_map.update(alias_updates)
        else:
            # Otherwise, only specific identifiers are aliased.
            for c in olist.children:
                assert isinstance(c, (Name, Rename))
                if isinstance(c, Name):
                    src, tgt = c, c
                elif isinstance(c, Rename):
                    _, src, tgt = c.children
                src, tgt = src.string, tgt.string
                src_spec, tgt_spec = scope_spec + (src,), mod_spec + (tgt,)
                # `tgt_spec` must have already been resolved if we have sorted the modules properly.
                assert tgt_spec in alias_map, f"{src_spec} => {tgt_spec}"
                alias_map[src_spec] = alias_map[tgt_spec]

    assert set(ident_map.keys()).issubset(alias_map.keys())
    return alias_map


def search_real_ident_spec(ident: str, in_spec: SPEC, alias_map: SPEC_TABLE) -> Optional[SPEC]:
    """Find the defining spec of `ident` as seen from scope `in_spec`, widening outward; `None` if absent."""
    k = in_spec + (ident,)
    if k in alias_map:
        return ident_spec(alias_map[k])
    if not in_spec:
        return None
    return search_real_ident_spec(ident, in_spec[:-1], alias_map)


def find_real_ident_spec(ident: str, in_spec: SPEC, alias_map: SPEC_TABLE) -> SPEC:
    """Like `search_real_ident_spec()`, but a missing identifier is considered an error."""
    spec = search_real_ident_spec(ident, in_spec, alias_map)
    assert spec, f"cannot find {ident} / {in_spec}"
    return spec


def _find_type_decl_node(node: Entity_Decl):
    """Return the ancestor declaration statement that carries the type spec of this entity decl."""
    anc = node.parent
    while anc and not atmost_one(
            children_of_type(anc, (Intrinsic_Type_Spec, Declaration_Type_Spec))):
        anc = anc.parent
    return anc
def _eval_selected_int_kind(p: np.int32) -> int:
    """Evaluate Fortran's SELECTED_INT_KIND(p): the smallest supported integer kind
    (in bytes) that can hold `p` decimal digits."""
    # Copied logic from `replace_int_kind()` elsewhere in the project.
    # avoid int overflow in numpy 2.0
    p = int(p)
    kind = int(math.ceil((math.log2(10 ** p) + 1) / 8))
    assert kind <= 8
    if kind <= 2:
        return kind
    elif kind <= 4:
        return 4
    return 8


def _eval_selected_real_kind(p: int, r: int) -> int:
    """Evaluate Fortran's SELECTED_REAL_KIND(p, r) to a byte-kind (2, 4 or 8)."""
    # Copied logic from `replace_real_kind()` elsewhere in the project.
    if p >= 9 or r > 126:
        return 8
    elif p >= 3 or r > 14:
        return 4
    return 2


def _const_eval_int(expr: Base, alias_map: SPEC_TABLE) -> Optional[int]:
    """Best-effort constant evaluation of `expr` to a Python int.

    Handles named constants (via their initialization), SELECTED_REAL_KIND /
    SELECTED_INT_KIND intrinsics and plain integer literals. Returns `None` when
    the expression cannot be statically evaluated here.
    """
    if isinstance(expr, Name):
        scope_spec = find_scope_spec(expr)
        spec = find_real_ident_spec(expr.string, scope_spec, alias_map)
        decl = alias_map[spec]
        assert isinstance(decl, Entity_Decl)
        # TODO: Verify that it is a constant expression.
        init = atmost_one(children_of_type(decl, Initialization))
        # TODO: Add ref.
        _, iexpr = init.children
        return _const_eval_int(iexpr, alias_map)
    elif isinstance(expr, Intrinsic_Function_Reference):
        intr, args = expr.children
        if args:
            args = args.children
        if intr.string == 'SELECTED_REAL_KIND':
            assert len(args) == 2
            p, r = args
            p, r = _const_eval_int(p, alias_map), _const_eval_int(r, alias_map)
            assert p is not None and r is not None
            return _eval_selected_real_kind(p, r)
        elif intr.string == 'SELECTED_INT_KIND':
            assert len(args) == 1
            p, = args
            p = _const_eval_int(p, alias_map)
            assert p is not None
            return _eval_selected_int_kind(p)
    elif isinstance(expr, Int_Literal_Constant):
        return int(expr.tofortran())

    # TODO: Add other evaluations.
    return None


def _cdiv(x, y):
    """Division used for Fortran constant folding: integer division for integer
    operands, true division otherwise.

    NOTE(review): `floordiv` rounds toward -inf while Fortran integer division
    truncates toward zero — these differ for negative operands; confirm intended.
    """
    return operator.floordiv(x, y) \
        if (isinstance(x, (np.int8, np.int16, np.int32, np.int64))
            and isinstance(y, (np.int8, np.int16, np.int32, np.int64))) \
        else operator.truediv(x, y)


# Fortran unary operator -> Python implementation.
UNARY_OPS = {
    '.NOT.': np.logical_not,
    '-': operator.neg,
}

# Fortran binary operator -> Python implementation.
# FIX: the four comparison operators were pairwise swapped with their
# strict/non-strict counterparts ('<' mapped to `le`, '<=' to `lt`, '>' to `ge`,
# '>=' to `gt`), which made constant folding of comparisons wrong at the
# boundary values. They now match Fortran semantics.
BINARY_OPS = {
    '<': operator.lt,
    '>': operator.gt,
    '==': operator.eq,
    '/=': operator.ne,
    '<=': operator.le,
    '>=': operator.ge,
    '+': operator.add,
    '-': operator.sub,
    '*': operator.mul,
    '/': _cdiv,
    '.OR.': np.logical_or,
    '.AND.': np.logical_and,
    '**': operator.pow,
}

NUMPY_INTS = Union[np.int8, np.int16, np.int32, np.int64]
NUMPY_REALS = Union[np.float32, np.float64]
NUMPY_TYPES = Union[NUMPY_INTS, NUMPY_REALS, np.bool_]


def _count_bytes(t: Type[NUMPY_TYPES]) -> int:
    """Return the size in bytes of the given numpy scalar type."""
    if t is np.int8:
        return 1
    elif t is np.int16:
        return 2
    elif t is np.int32:
        return 4
    elif t is np.int64:
        return 8
    elif t is np.float32:
        return 4
    elif t is np.float64:
        return 8
    elif t is np.bool_:
        return 1
    raise ValueError(f"{t} is not an expected type; expected {NUMPY_TYPES}")


def _eval_int_literal(x: Union[Signed_Int_Literal_Constant, Int_Literal_Constant], alias_map: SPEC_TABLE) -> NUMPY_INTS:
    """Evaluate an integer literal (with an optional kind suffix, literal or named)
    to the numpy scalar of the matching width."""
    num, kind = x.children
    if kind is None:
        kind = 4
    elif kind in {'1', '2', '4', '8'}:
        kind = np.int32(kind)
    else:
        # The kind is a named constant (e.g. `1_ik4`); resolve and fold it.
        kind_spec = search_real_local_alias_spec_from_spec(find_scope_spec(x) + (kind,), alias_map)
        if kind_spec:
            kind_decl = alias_map[kind_spec]
            kind_node, _, _, _ = kind_decl.children
            kind = _const_eval_basic_type(kind_node, alias_map)
            assert isinstance(kind, np.int32)
    assert kind in {1, 2, 4, 8}
    if kind == 1:
        return np.int8(num)
    elif kind == 2:
        return np.int16(num)
    elif kind == 4:
        return np.int32(num)
    elif kind == 8:
        return np.int64(num)
def _eval_real_literal(x: Union[Signed_Real_Literal_Constant, Real_Literal_Constant],
                       alias_map: SPEC_TABLE) -> NUMPY_REALS:
    """Evaluate a real literal (with an optional kind suffix, literal or named)
    to the numpy scalar of the matching width."""
    num, kind = x.children
    if kind is None:
        # A `D` exponent (e.g. `1.0D0`) implies double precision.
        if 'D' in num:
            num = num.replace('D', 'e')
            kind = 8
        else:
            kind = 4
    elif kind in {'4', '8'}:
        # Literal kind suffix such as `1.0_8`; mirrors `_eval_int_literal`.
        kind = np.int32(kind)
    else:
        # The kind is a named constant (e.g. `1.0_wp`); resolve and fold it.
        kind_spec = search_real_local_alias_spec_from_spec(find_scope_spec(x) + (kind,), alias_map)
        if kind_spec:
            kind_decl = alias_map[kind_spec]
            kind_node, _, _, _ = kind_decl.children
            kind = _const_eval_basic_type(kind_node, alias_map)
            assert isinstance(kind, np.int32)
    assert kind in {4, 8}
    if kind == 4:
        return np.float32(num)
    elif kind == 8:
        return np.float64(num)


def _const_eval_basic_type(expr: Base, alias_map: SPEC_TABLE) -> Optional[NUMPY_TYPES]:
    """Best-effort constant evaluation of `expr` to a numpy scalar of a basic type.

    Handles named constants (by folding their initialization and casting to the
    declared type), a handful of intrinsics (EPSILON, SELECTED_*_KIND, INT, REAL),
    literals, unary/binary operators, parentheses and hex constants. Returns
    `None` when the expression cannot be statically evaluated.
    """
    if isinstance(expr, (Part_Ref, Data_Ref)):
        return None
    elif isinstance(expr, Name):
        spec = search_real_local_alias_spec(expr, alias_map)
        if not spec:
            # Does not even have a valid identifier.
            return None
        decl = alias_map[spec]
        if not isinstance(decl, Entity_Decl):
            # Is not even a data entity.
            return None
        typ = find_type_of_entity(decl, alias_map)
        if not typ or not typ.const or typ.shape:
            # Does not have a constant type.
            return None
        init = atmost_one(children_of_type(decl, Initialization))
        # TODO: Add ref.
        _, iexpr = init.children
        val = _const_eval_basic_type(iexpr, alias_map)
        assert val is not None
        # Cast the folded value to the declared type of the constant.
        if typ.spec == ('INTEGER1',):
            val = np.int8(val)
        elif typ.spec == ('INTEGER2',):
            val = np.int16(val)
        elif typ.spec == ('INTEGER4',) or typ.spec == ('INTEGER',):
            val = np.int32(val)
        elif typ.spec == ('INTEGER8',):
            val = np.int64(val)
        elif typ.spec == ('REAL4',) or typ.spec == ('REAL',):
            val = np.float32(val)
        elif typ.spec == ('REAL8',):
            val = np.float64(val)
        elif typ.spec == ('LOGICAL',):
            val = np.bool_(val)
        else:
            raise ValueError(f"{expr}/{typ.spec} is not a basic type")
        return val
    elif isinstance(expr, Intrinsic_Function_Reference):
        intr, args = expr.children
        if args:
            args = args.children
        if intr.string == 'EPSILON':
            a, = args
            a = _const_eval_basic_type(a, alias_map)
            assert isinstance(a, (np.float32, np.float64))
            return type(a)(sys.float_info.epsilon)
        elif intr.string == 'SELECTED_REAL_KIND':
            p, r = args
            p, r = _const_eval_basic_type(p, alias_map), _const_eval_basic_type(r, alias_map)
            assert isinstance(p, np.int32) and isinstance(r, np.int32)
            return np.int32(_eval_selected_real_kind(p, r))
        elif intr.string == 'SELECTED_INT_KIND':
            p, = args
            p = _const_eval_basic_type(p, alias_map)
            assert isinstance(p, np.int32)
            return np.int32(_eval_selected_int_kind(p))
        elif intr.string == 'INT':
            if len(args) == 1:
                num, = args
                kind = 4
            else:
                num, kind = args
                kind = _const_eval_basic_type(kind, alias_map)
                assert kind is not None
            num = _const_eval_basic_type(num, alias_map)
            # FIX: `if not num` incorrectly bailed out when the argument folded
            # to a perfectly valid 0; only a failed fold (None) means "unknown".
            if num is None:
                return None
            return _eval_int_literal(Int_Literal_Constant(f"{num}_{kind}"), alias_map)
        elif intr.string == 'REAL':
            if len(args) == 1:
                num, = args
                kind = 4
            else:
                num, kind = args
                kind = _const_eval_basic_type(kind, alias_map)
                assert kind is not None
            num = _const_eval_basic_type(num, alias_map)
            # FIX: same truthiness bug as in the INT branch (0.0 is a valid value).
            if num is None:
                return None
            valstr = str(num)
            if kind == 8:
                # Spell the literal in double precision for `_eval_real_literal`.
                if 'e' in valstr:
                    valstr = valstr.replace('e', 'D')
                else:
                    valstr = f"{valstr}D0"
            return _eval_real_literal(Real_Literal_Constant(valstr), alias_map)
    elif isinstance(expr, (Int_Literal_Constant, Signed_Int_Literal_Constant)):
        return _eval_int_literal(expr, alias_map)
    elif isinstance(expr, Logical_Literal_Constant):
        return np.bool_(expr.tofortran().upper() == '.TRUE.')
    elif isinstance(expr, (Real_Literal_Constant, Signed_Real_Literal_Constant)):
        return _eval_real_literal(expr, alias_map)
    elif isinstance(expr, BinaryOpBase):
        lv, op, rv = expr.children
        if op in BINARY_OPS:
            lv = _const_eval_basic_type(lv, alias_map)
            rv = _const_eval_basic_type(rv, alias_map)
            if lv is None or rv is None:
                return None
            return BINARY_OPS[op](lv, rv)
    elif isinstance(expr, UnaryOpBase):
        op, val = expr.children
        if op in UNARY_OPS:
            val = _const_eval_basic_type(val, alias_map)
            if val is None:
                return None
            return UNARY_OPS[op](val)
    elif isinstance(expr, Parenthesis):
        _, x, _ = expr.children
        return _const_eval_basic_type(x, alias_map)
    elif isinstance(expr, Hex_Constant):
        x = expr.string
        assert x[:2] == 'Z"' and x[-1:] == '"'
        x = x[2:-1]
        return np.int32(int(x, 16))

    # TODO: Add other evaluations.
    return None


def find_type_of_entity(node: Entity_Decl, alias_map: SPEC_TABLE) -> Optional[TYPE_SPEC]:
    """Construct the TYPE_SPEC of an entity declaration, folding kind selectors and
    attaching shape information from the DIMENSION attribute or the decl itself."""
    anc = _find_type_decl_node(node)
    if not anc:
        return None
    # TODO: Add ref.
    typ, attrs, _ = anc.children
    assert isinstance(typ, (Intrinsic_Type_Spec, Declaration_Type_Spec))
    attrs = attrs.tofortran() if attrs else ''

    extra_dim = None
    if isinstance(typ, Intrinsic_Type_Spec):
        ACCEPTED_TYPES = {'INTEGER', 'REAL', 'DOUBLE PRECISION', 'LOGICAL', 'CHARACTER'}
        typ_name, kind = typ.children
        assert typ_name in ACCEPTED_TYPES, typ_name

        # TODO: How should we handle character lengths? Just treat it as an extra dimension?
        if isinstance(kind, Length_Selector):
            assert typ_name == 'CHARACTER'
            extra_dim = (':',)
        elif isinstance(kind, Kind_Selector):
            assert typ_name in {'INTEGER', 'REAL', 'LOGICAL'}
            _, kind, _ = kind.children
            # NOTE(review): `or 4` also kicks in when the kind folds to 0 — presumed impossible; confirm.
            kind = _const_eval_basic_type(kind, alias_map) or 4
            typ_name = f"{typ_name}{kind}"
        elif kind is None:
            # Default kinds: INTEGER/REAL are 4 bytes, DOUBLE PRECISION is REAL8.
            if typ_name in {'INTEGER', 'REAL'}:
                typ_name = f"{typ_name}4"
            elif typ_name in {'DOUBLE PRECISION'}:
                typ_name = f"REAL8"
        spec = (typ_name,)
    elif isinstance(typ, Declaration_Type_Spec):
        _, typ_name = typ.children
        spec = find_real_ident_spec(typ_name.string, ident_spec(node), alias_map)

    # TODO: This `attrs` manipulation is a hack. We should design the type specs better.
    # TODO: Add ref.
    attrs = [attrs] if attrs else []
    _, shape, _, _ = node.children
    if shape is not None:
        attrs.append(f"DIMENSION({shape.tofortran()})")
    attrs = ', '.join(attrs)
    tspec = TYPE_SPEC(spec, attrs)
    if extra_dim:
        tspec.shape += extra_dim
    return tspec
def _dataref_root(dref: Union[Name, Data_Ref], scope_spec: SPEC, alias_map: SPEC_TABLE):
    """Split a data-ref into its root and trailing component shards.

    Returns a triple `(root, root_type, rest)` where `root` is the leading node,
    `root_type` its resolved TYPE_SPEC, and `rest` the remaining component nodes
    (empty when `dref` is a plain `Name`).
    """
    if isinstance(dref, Name):
        root, rest = dref, []
    else:
        assert len(dref.children) >= 2
        root, rest = dref.children[0], dref.children[1:]

    if isinstance(root, Name):
        root_spec = find_real_ident_spec(root.string, scope_spec, alias_map)
        assert root_spec in alias_map, f"canont find: {root_spec} / {dref} in {scope_spec}"
        root_type = find_type_of_entity(alias_map[root_spec], alias_map)
    elif isinstance(root, Data_Ref):
        # Nested data-ref: resolve its type recursively.
        root_type = find_type_dataref(root, scope_spec, alias_map)
    assert root_type

    return root, root_type, rest


def find_dataref_component_spec(dref: Union[Name, Data_Ref], scope_spec: SPEC, alias_map: SPEC_TABLE) -> SPEC:
    """Return the spec of the final component of a data-ref (e.g. the `c` of `a%b%c`),
    resolving the type of every intermediate shard along the way."""
    # The root must have been a typed object.
    _, root_type, rest = _dataref_root(dref, scope_spec, alias_map)

    cur_type = root_type
    # All component shards except for the last one must have been type objects too.
    for comp in rest[:-1]:
        assert isinstance(comp, (Name, Part_Ref))
        if isinstance(comp, Part_Ref):
            part_name, _ = comp.children[0], comp.children[1:]
            comp_spec = find_real_ident_spec(part_name.string, cur_type.spec, alias_map)
        elif isinstance(comp, Name):
            comp_spec = find_real_ident_spec(comp.string, cur_type.spec, alias_map)
        assert comp_spec in alias_map, f"canont find: {comp_spec} / {dref} in {scope_spec}"
        # So, we get the type spec for those component shards.
        cur_type = find_type_of_entity(alias_map[comp_spec], alias_map)
        assert cur_type

    # For the last one, we just need the component spec.
    comp = rest[-1]
    assert isinstance(comp, (Name, Part_Ref))
    if isinstance(comp, Part_Ref):
        part_name, _ = comp.children[0], comp.children[1:]
        comp_spec = find_real_ident_spec(part_name.string, cur_type.spec, alias_map)
    elif isinstance(comp, Name):
        comp_spec = find_real_ident_spec(comp.string, cur_type.spec, alias_map)
    assert comp_spec in alias_map, f"canont find: {comp_spec} / {dref} in {scope_spec}"

    return comp_spec


def find_type_dataref(dref: Union[Name, Part_Ref, Data_Ref], scope_spec: SPEC, alias_map: SPEC_TABLE) -> TYPE_SPEC:
    """Return the TYPE_SPEC of the value a data-ref evaluates to, accounting for
    array subscripts that may reduce (or fully scalarize) the shape."""
    _, root_type, rest = _dataref_root(dref, scope_spec, alias_map)
    cur_type = root_type
    for comp in rest:
        assert isinstance(comp, (Name, Part_Ref))
        if isinstance(comp, Part_Ref):
            # TODO: Add ref.
            part_name, subsc = comp.children
            comp_spec = find_real_ident_spec(part_name.string, cur_type.spec, alias_map)
            assert comp_spec in alias_map, f"cannot find {comp_spec} / {dref} in {scope_spec}"
            cur_type = find_type_of_entity(alias_map[comp_spec], alias_map)
            if not cur_type.shape:
                # The object was not an array in the first place.
                assert not subsc, f"{cur_type} / {part_name}, {cur_type.spec}, {comp}"
            elif subsc:
                # TODO: This is a hack to deduce a array type instead of scalar.
                # We may have subscripted away all the dimensions.
                # Only subscripts containing ':' keep a dimension alive.
                cur_type.shape = tuple(s.tofortran() for s in subsc.children if ':' in s.tofortran())
        elif isinstance(comp, Name):
            comp_spec = find_real_ident_spec(comp.string, cur_type.spec, alias_map)
            assert comp_spec in alias_map, f"cannot find {comp_spec} / {dref} in {scope_spec}"
            cur_type = find_type_of_entity(alias_map[comp_spec], alias_map)
    assert cur_type
    return cur_type
def procedure_specs(ast: Program) -> Dict[SPEC, SPEC]:
    """Map each type-bound (specific) procedure spec to the spec of the subprogram
    that implements it."""
    bindings: Dict[SPEC, SPEC] = {}
    for binding in walk(ast, Specific_Binding):
        # Ref: https://github.com/stfc/fparser/blob/8c870f84edbf1a24dfbc886e2f7226d1b158d50b/src/fparser/two/Fortran2003.py#L2504
        _, _, _, bound_name, impl_name = binding.children
        impl = impl_name.string if impl_name else bound_name.string

        # The binding sits inside a derived-type definition; pick up the type's name.
        typedef: Derived_Type_Def = binding.parent.parent
        typedef_stmt: Derived_Type_Stmt = singular(children_of_type(typedef, Derived_Type_Stmt))
        type_name: str = singular(children_of_type(typedef_stmt, Type_Name)).string

        # TODO: Generalize.
        # We assume that the type is defined inside a module (i.e., not another subprogram).
        mod: Module = typedef.parent.parent
        mod_stmt: Module_Stmt = singular(children_of_type(mod, (Module_Stmt, Program_Stmt)))
        # TODO: Add ref.
        _, mod_name = mod_stmt.children

        # TODO: Is this assumption true?
        # We assume that the type and the bound function exist in the same scope (i.e., module, subprogram etc.).
        bindings[(mod_name.string, type_name, bound_name.string)] = (mod_name.string, impl)
    return bindings


def generic_specs(ast: Program) -> Dict[SPEC, Tuple[SPEC, ...]]:
    """Map each generic binding spec to the tuple of specific-procedure specs it
    resolves to (all assumed to live in the same scope as the binding)."""
    generics: Dict[SPEC, Tuple[SPEC, ...]] = {}
    for gb in walk(ast, Generic_Binding):
        # TODO: Add ref.
        _, generic_name, proc_list = gb.children
        member_names = proc_list.children if proc_list else []

        scope = find_scope_spec(gb)
        # TODO: Is this assumption true?
        # We assume that the type and the bound function exist in the same scope (i.e., module, subprogram etc.).
        generics[scope + (generic_name.string,)] = tuple(scope + (p.string,) for p in member_names)
    return generics


def interface_specs(ast: Program, alias_map: SPEC_TABLE) -> Dict[SPEC, Tuple[SPEC, ...]]:
    """Map interface specs to the specs of subprograms that can resolve them.

    Named interface blocks map their own name to every callable listed inside them.
    Anonymous interface blocks instead map each declared subprogram onto its
    resolution in the enclosing scope.
    """
    resolutions: Dict[SPEC, Tuple[SPEC, ...]] = {}

    # Pass 1: named interface blocks — the only ones that are directly callable.
    for ifstmt in walk(ast, Interface_Stmt):
        ifname = find_name_of_stmt(ifstmt)
        if not ifname:
            continue
        block = ifstmt.parent
        scope = find_scope_spec(block)
        key = scope + (ifname,)

        # Collect the names of everything callable declared in this block.
        candidates: List[str] = []
        for stmt in walk(block, (Function_Stmt, Subroutine_Stmt, Procedure_Stmt)):
            if isinstance(stmt, Procedure_Stmt):
                candidates.extend(nm.string for nm in walk(stmt, Name))
            else:
                candidates.append(find_name_of_stmt(stmt))

        resolved = tuple(find_real_ident_spec(c, scope, alias_map) for c in candidates)
        assert key not in resolved
        resolutions[key] = resolved

    # Pass 2: anonymous interface blocks resolve their content onto the parent scope.
    for ifstmt in walk(ast, Interface_Stmt):
        if find_name_of_stmt(ifstmt):
            continue
        block = ifstmt.parent
        scope = find_scope_spec(block)
        assert not walk(block, Procedure_Stmt)

        for stmt in walk(block, (Function_Stmt, Subroutine_Stmt)):
            stmt_name = find_name_of_stmt(stmt)
            key = ident_spec(stmt)
            search_in = scope
            target = find_real_ident_spec(stmt_name, search_in, alias_map)
            # If we are resolving the interface back to itself, we need to search a level above.
            while key == target:
                assert search_in
                search_in = search_in[:-1]
                target = find_real_ident_spec(stmt_name, search_in, alias_map)
            assert key != target
            resolutions[key] = (target,)

    return resolutions
+ for fn in walk(ib, (Function_Stmt, Subroutine_Stmt)): + fn_name = find_name_of_stmt(fn) + ifspec = ident_spec(fn) + cscope = scope_spec + fn_spec = find_real_ident_spec(fn_name, cscope, alias_map) + # If we are resolving the interface back to itself, we need to search a level above. + while ifspec == fn_spec: + assert cscope + cscope = cscope[:-1] + fn_spec = find_real_ident_spec(fn_name, cscope, alias_map) + assert ifspec != fn_spec + iface_map[ifspec] = (fn_spec,) + + return iface_map + + +def set_children(par: Base, children: Iterable[Base]): + assert hasattr(par, 'content') != hasattr(par, 'items') + if hasattr(par, 'items'): + par.items = tuple(children) + elif hasattr(par, 'content'): + par.content = list(children) + _reparent_children(par) + + +def replace_node(node: Base, subst: Union[Base, Iterable[Base]]): + # A lot of hacky stuff to make sure that the new nodes are not just the same objects over and over. + par = node.parent + only_child = bool([c for c in par.children if c == node]) + repls = [] + for c in par.children: + if c != node: + repls.append(c) + continue + if isinstance(subst, Base): + subst = [subst] + if not only_child: + subst = [Base.__new__(type(t), t.tofortran()) for t in subst] + repls.extend(subst) + if isinstance(par, Loop_Control) and isinstance(subst, Base): + _, cntexpr, _, _ = par.children + if cntexpr: + loopvar, looprange = cntexpr + for i in range(len(looprange)): + if looprange[i] == node: + looprange[i] = subst + subst.parent = par + set_children(par, repls) + + +def append_children(par: Base, children: Union[Base, List[Base]]): + if isinstance(children, Base): + children = [children] + set_children(par, list(par.children) + children) + + +def prepend_children(par: Base, children: Union[Base, List[Base]]): + if isinstance(children, Base): + children = [children] + set_children(par, children + list(par.children)) + + +def remove_children(par: Base, children: Union[Base, List[Base]]): + if isinstance(children, Base): + 
children = [children] + repl = [c for c in par.children if c not in children] + set_children(par, repl) + + +def remove_self(nodes: Union[Base, List[Base]]): + if isinstance(nodes, Base): + nodes = [nodes] + for n in nodes: + remove_children(n.parent, n) + + +def correct_for_function_calls(ast: Program): + """Look for function calls that may have been misidentified as array access and fix them.""" + alias_map = alias_specs(ast) + + # TODO: Looping over and over is not ideal. But `Function_Reference(...)` sometimes generate inner `Part_Ref`s. We + # should figure out a way to avoid this clutter. + changed = True + while changed: + changed = False + for pr in walk(ast, Part_Ref): + scope_spec = find_scope_spec(pr) + if isinstance(pr.parent, Data_Ref): + dref = pr.parent + comp_spec = find_dataref_component_spec(dref, scope_spec, alias_map) + comp_type_spec = find_type_of_entity(alias_map[comp_spec], alias_map) + if not comp_type_spec: + # Cannot find a type, so it must be a function call. + replace_node(dref, Function_Reference(dref.tofortran())) + changed = True + else: + pr_name, _ = pr.children + if isinstance(pr_name, Name): + pr_spec = search_real_local_alias_spec(pr_name, alias_map) + if pr_spec in alias_map and isinstance(alias_map[pr_spec], (Function_Stmt, Interface_Stmt)): + replace_node(pr, Function_Reference(pr.tofortran())) + changed = True + elif isinstance(pr_name, Data_Ref): + pr_type_spec = find_type_dataref(pr_name, scope_spec, alias_map) + if not pr_type_spec: + # Cannot find a type, so it must be a function call. + replace_node(pr, Function_Reference(pr.tofortran())) + changed = True + + for sc in walk(ast, Structure_Constructor): + scope_spec = find_scope_spec(sc) + + # TODO: Add ref. + sc_type, _ = sc.children + sc_type_spec = find_real_ident_spec(sc_type.string, scope_spec, alias_map) + if isinstance(alias_map[sc_type_spec], (Function_Stmt, Interface_Stmt)): + # Now we know that this identifier actually refers to a function. 
+ replace_node(sc, Function_Reference(sc.tofortran())) + + # These can also be intrinsic function calls. + for fref in walk(ast, (Function_Reference, Call_Stmt)): + scope_spec = find_scope_spec(fref) + + name, args = fref.children + name = name.string + if not Intrinsic_Name.match(name): + # There is no way this is an intrinsic call. + continue + fref_spec = scope_spec + (name,) + if fref_spec in alias_map: + # This is already an alias, so intrinsic object is shadowed. + continue + if isinstance(fref, Function_Reference): + # We need to replace with this exact node structure, and cannot rely on FParser to parse it right. + repl = Intrinsic_Function_Reference(fref.tofortran()) + # Set the arguments ourselves, just in case the parser messes it up. + repl.items = (Intrinsic_Name(name), args) + _reparent_children(repl) + replace_node(fref, repl) + else: + fref.items = (Intrinsic_Name(name), args) + _reparent_children(fref) + + return ast + + +def remove_access_statements(ast: Program): + """Look for public/private access statements and just remove them.""" + # TODO: This can get us into ambiguity and unintended shadowing. + + # We also remove any access statement that makes these interfaces public/private. + for acc in walk(ast, Access_Stmt): + # TODO: Add ref. + kind, alist = acc.children + assert kind.upper() in {'PUBLIC', 'PRIVATE'} + spec = acc.parent + remove_self(acc) + if not spec.children: + remove_self(spec) + + return ast + + +def sort_modules(ast: Program) -> Program: + TOPLEVEL = '__toplevel__' + + def _get_module(n: Base) -> str: + p = n + while p and not isinstance(p, (Module, Main_Program)): + p = p.parent + if not p: + return TOPLEVEL + else: + p = singular(children_of_type(p, (Module_Stmt, Program_Stmt))) + return find_name_of_stmt(p) + + g = nx.DiGraph() # An edge u->v means u should come before v, i.e., v depends on u. 
+ for c in ast.children:
+ g.add_node(_get_module(c))
+
+ for u in walk(ast, Use_Stmt):
+ u_name = singular(children_of_type(u, Name)).string
+ v_name = _get_module(u)
+ g.add_edge(u_name, v_name)
+
+ top_ord = {n: i for i, n in enumerate(nx.lexicographical_topological_sort(g))}
+ # We keep the top-level subroutines at the end. It is only a cosmetic choice and fortran accepts them anywhere.
+ top_ord[TOPLEVEL] = len(top_ord) + 1
+ assert all(_get_module(n) in top_ord for n in ast.children)
+ ast.content = sorted(ast.children, key=lambda x: top_ord[_get_module(x)])
+
+ return ast
+
+
+def deconstruct_enums(ast: Program) -> Program:
+ for en in walk(ast, Enum_Def):
+ en_dict: Dict[str, Expr] = {}
+ # We need these for automatic counting of enumerators without explicit values.
+ next_val = '0'
+ next_offset = 0
+ for el in walk(en, Enumerator_List):
+ for c in el.children:
+ if isinstance(c, Name):
+ c_name = c.string
+ elif isinstance(c, Enumerator):
+ # TODO: Add ref.
+ name, _, val = c.children
+ c_name = name.string
+ next_val = val.string
+ next_offset = 0
+ en_dict[c_name] = Expr(f"{next_val} + {next_offset}")
+ next_offset = next_offset + 1
+ type_decls = [Type_Declaration_Stmt(f"integer, parameter :: {k} = {v}") for k, v in en_dict.items()]
+ replace_node(en, [Type_Declaration_Stmt(f"integer, parameter :: {k} = {v}") for k, v in en_dict.items()])
+ return ast
+
+
+def _compute_argument_signature(args, scope_spec: SPEC, alias_map: SPEC_TABLE) -> Tuple[TYPE_SPEC, ...]:
+ if not args:
+ return tuple()
+
+ args_sig = []
+ for c in args.children:
+ def _deduct_type(x) -> TYPE_SPEC:
+ if isinstance(x, (Real_Literal_Constant, Signed_Real_Literal_Constant)):
+ return TYPE_SPEC('REAL')
+ elif isinstance(x, (Int_Literal_Constant, Signed_Int_Literal_Constant)):
+ val = _eval_int_literal(x, alias_map)
+ assert isinstance(val, NUMPY_INTS)
+ return TYPE_SPEC(f"INTEGER{_count_bytes(type(val))}")
+ elif isinstance(x, Char_Literal_Constant):
+ str_typ = TYPE_SPEC('CHARACTER', 'DIMENSION(:)')
+ return str_typ
+ elif 
isinstance(x, Logical_Literal_Constant): + return TYPE_SPEC('LOGICAL') + elif isinstance(x, Name): + x_spec = find_real_ident_spec(x.string, scope_spec, alias_map) + assert x_spec in alias_map, f"cannot find: {x_spec} / {x}" + x_type = find_type_of_entity(alias_map[x_spec], alias_map) + assert x_type, f"cannot find type for: {x_spec} / x" + # TODO: This is a hack to make the array etc. types different. + return x_type + elif isinstance(x, Data_Ref): + return find_type_dataref(x, scope_spec, alias_map) + elif isinstance(x, Part_Ref): + # TODO: Add ref. + part_name, subsc = x.children + orig_type = find_type_dataref(part_name, scope_spec, alias_map) + if not orig_type.shape: + # The object was not an array in the first place. + assert not subsc, f"{orig_type} / {part_name}, {scope_spec}, {x}" + return orig_type + if not subsc: + # No further subscription, so retain the original type of the object. + return orig_type + # TODO: This is a hack to deduce a array type instead of scalar. + # We may have subscripted away all the dimensions. + subsc = subsc.children + # TODO: Can we avoid padding the missing dimensions? This happens when the type itself is array-ish too. 
+ subsc = tuple([Section_Subscript(':')] * (len(orig_type.shape) - len(subsc))) + subsc + assert len(subsc) == len(orig_type.shape) + orig_type.shape = tuple(s.tofortran() for s in subsc if ':' in s.tofortran()) + return orig_type + elif isinstance(x, Actual_Arg_Spec): + kw, val = x.children + t = _deduct_type(val) + if isinstance(kw, Name): + t.keyword = kw.string + return t + elif isinstance(x, Intrinsic_Function_Reference): + fname, args = x.children + if args: + args = args.children + if fname.string in {'TRIM'}: + return TYPE_SPEC('CHARACTER', 'DIMENSION(:)') + elif fname.string in {'SIZE'}: + return TYPE_SPEC('INTEGER') + elif fname.string in {'REAL'}: + assert 1 <= len(args) <= 2 + kind = None + if len(args) == 2: + kind = _const_eval_int(args[-1], alias_map) + if kind: + return TYPE_SPEC(f"REAL{kind}") + else: + return TYPE_SPEC('REAL') + elif fname.string in {'INT'}: + assert 1 <= len(args) <= 2 + kind = None + if len(args) == 2: + kind = _const_eval_int(args[-1], alias_map) + if kind: + return TYPE_SPEC(f"INTEGER{kind}") + else: + return TYPE_SPEC('INTEGER') + # TODO: Figure out the actual type. + return MATCH_ALL + elif isinstance(x, (Level_2_Unary_Expr, And_Operand)): + op, dref = x.children + if op in {'+', '-', '.NOT.'}: + return _deduct_type(dref) + # TODO: Figure out the actual type. + return MATCH_ALL + elif isinstance(x, Parenthesis): + _, exp, _ = x.children + return _deduct_type(exp) + elif isinstance(x, (Level_2_Expr, Level_3_Expr)): + lval, op, rval = x.children + if op in {'+', '-'}: + tl, tr = _deduct_type(lval), _deduct_type(rval) + if len(tl.shape) < len(tr.shape): + return tr + else: + return tl + elif op in {'//'}: + return TYPE_SPEC('CHARACTER', 'DIMENSION(:)') + # TODO: Figure out the actual type. + return MATCH_ALL + elif isinstance(x, Array_Constructor): + b, items, e = x.children + items = items.children + # TODO: We are assuming there is an element. What if there isn't? 
+ t = _deduct_type(items[0]) + t.shape += (':',) + return t + else: + # TODO: Figure out the actual type. + return MATCH_ALL + + c_type = _deduct_type(c) + assert c_type, f"got: {c} / {type(c)}" + args_sig.append(c_type) + + return tuple(args_sig) + + +def _compute_candidate_argument_signature(args, cand_spec: SPEC, alias_map: SPEC_TABLE) -> Tuple[TYPE_SPEC, ...]: + cand_args_sig: List[TYPE_SPEC] = [] + for ca in args: + ca_decl = alias_map[cand_spec + (ca.string,)] + ca_type = find_type_of_entity(ca_decl, alias_map) + ca_type.keyword = ca.string + assert ca_type, f"got: {ca} / {type(ca)}" + cand_args_sig.append(ca_type) + return tuple(cand_args_sig) + + +def deconstruct_interface_calls(ast: Program) -> Program: + SUFFIX, COUNTER = 'deconiface', 0 + + alias_map = alias_specs(ast) + iface_map = interface_specs(ast, alias_map) + + for fref in walk(ast, (Function_Reference, Call_Stmt)): + scope_spec = find_scope_spec(fref) + name, args = fref.children + if isinstance(name, Intrinsic_Name): + continue + fref_spec = find_real_ident_spec(name.string, scope_spec, alias_map) + assert fref_spec in alias_map, f"cannot find: {fref_spec}" + if fref_spec not in iface_map: + # We are only interested in calls to interfaces here. + continue + + # Find the nearest execution and its correpsonding specification parts. + execution_part = fref.parent + while not isinstance(execution_part, Execution_Part): + execution_part = execution_part.parent + subprog = execution_part.parent + specification_part = atmost_one(children_of_type(subprog, Specification_Part)) + + ifc_spec = ident_spec(alias_map[fref_spec]) + args_sig: Tuple[TYPE_SPEC, ...] 
= _compute_argument_signature(args, scope_spec, alias_map) + all_cand_sigs: List[Tuple[SPEC, Tuple[TYPE_SPEC, ...]]] = [] + + conc_spec = None + for cand in iface_map[ifc_spec]: + assert cand in alias_map + cand_stmt = alias_map[cand] + assert isinstance(cand_stmt, (Function_Stmt, Subroutine_Stmt)) + + # However, this candidate could be inside an interface block, and this be just another level of indirection. + cand_spec = cand + if isinstance(cand_stmt.parent.parent, Interface_Block): + cand_spec = find_real_ident_spec(cand_spec[-1], cand_spec[:-2], alias_map) + assert cand_spec in alias_map + cand_stmt = alias_map[cand_spec] + assert isinstance(cand_stmt, (Function_Stmt, Subroutine_Stmt)) + + # TODO: Add ref. + _, _, cand_args, _ = cand_stmt.children + if cand_args: + cand_args_sig = _compute_candidate_argument_signature(cand_args.children, cand_spec, alias_map) + else: + cand_args_sig = tuple() + all_cand_sigs.append((cand_spec, cand_args_sig)) + + if _does_type_signature_match(args_sig, cand_args_sig): + conc_spec = cand_spec + break + if conc_spec not in alias_map: + print(f"{ifc_spec}/{conc_spec} / {args_sig}") + for c in all_cand_sigs: + print(f"...> {c}") + assert conc_spec and conc_spec in alias_map, f"[in: {fref_spec}] {ifc_spec}/{conc_spec} not found" + + # We are assumping that it's either a toplevel subprogram or a subprogram defined directly inside a module. + assert 1 <= len(conc_spec) <= 2 + if len(conc_spec) == 1: + mod, pname = None, conc_spec[0] + else: + mod, pname = conc_spec + + if mod is None or mod == scope_spec[0]: + # Since `pname` must have been already defined at either the top level or the module level, there is no need + # for aliasing. + pname_alias = pname + else: + # If we are importing it from a different module, we should create an alias to avoid name collision. 
+ pname_alias, COUNTER = f"{pname}_{SUFFIX}_{COUNTER}", COUNTER + 1 + if not specification_part: + append_children(subprog, Specification_Part(get_reader(f"use {mod}, only: {pname_alias} => {pname}"))) + else: + prepend_children(specification_part, Use_Stmt(f"use {mod}, only: {pname_alias} => {pname}")) + + # For both function and subroutine calls, replace `bname` with `pname_alias`, and add `dref` as the first arg. + replace_node(name, Name(pname_alias)) + + # TODO: Figure out a way without rebuilding here. + # Rebuild the maps because aliasing may have changed. + alias_map = alias_specs(ast) + + # At this point, we must have replaced all the interface calls with concrete calls. + for use in walk(ast, Use_Stmt): + mod_name = singular(children_of_type(use, Name)).string + mod_spec = (mod_name,) + olist = atmost_one(children_of_type(use, Only_List)) + if not olist: + # There is nothing directly referring to the interface. + continue + + survivors = [] + for c in olist.children: + assert isinstance(c, (Name, Rename)) + if isinstance(c, Name): + src, tgt = c, c + elif isinstance(c, Rename): + _, src, tgt = c.children + src, tgt = src.string, tgt.string + tgt_spec = find_real_ident_spec(tgt, mod_spec, alias_map) + assert tgt_spec in alias_map + if tgt_spec not in iface_map: + # Leave the non-interface usages alone. + survivors.append(c) + + if survivors: + olist.items = survivors + _reparent_children(olist) + else: + remove_self(use) + + # We also remove any access statement that makes these interfaces public/private. + for acc in walk(ast, Access_Stmt): + # TODO: Add ref. + kind, alist = acc.children + if not alist: + continue + scope_spec = find_scope_spec(acc) + + survivors = [] + for c in alist.children: + assert isinstance(c, Name) + c_spec = scope_spec + (c.string,) + assert c_spec in alias_map + if not isinstance(alias_map[c_spec], Interface_Stmt): + # Leave the non-interface usages alone. 
+ survivors.append(c) + + if survivors: + alist.items = survivors + _reparent_children(alist) + else: + remove_self(acc) + + # At this point, we must have replaced all references to the interfaces. + for k in iface_map.keys(): + assert k in alias_map + ib = None + if isinstance(alias_map[k], Interface_Stmt): + ib = alias_map[k].parent + elif isinstance(alias_map[k], (Function_Stmt, Subroutine_Stmt)): + ib = alias_map[k].parent.parent + assert isinstance(ib, Interface_Block) + remove_self(ib) + + return ast + + +MATCH_ALL = TYPE_SPEC(('*',), '') # TODO: Hacky; `_does_type_signature_match()` will match anything with this. + + +def _does_part_matches(g: TYPE_SPEC, c: TYPE_SPEC) -> bool: + if c == MATCH_ALL: + # Consider them matched. + return True + if len(g.shape) != len(c.shape): + # Both's ranks must match + return False + + def _real_num_type(t: str) -> Tuple[str, int]: + if t == 'DOUBLE PRECISION': + return 'REAL', 8 + elif t == 'REAL': + return 'REAL', 4 + elif t.startswith('REAL'): + w = int(t.removeprefix('REAL')) + return 'REAL', w + elif t == 'INTEGER': + return 'INTEGER', 4 + elif t.startswith('INTEGER'): + w = int(t.removeprefix('INTEGER')) + return 'INTEGER', w + return t, 1 + + def _subsumes(b: SPEC, s: SPEC) -> bool: + """If `b` subsumes `s`.""" + if b == s: + return True + if len(b) != 1 or len(s) != 1: + # TODO: We don't know how to evaluate this? + return False + b, s = b[0], s[0] + b, bw = _real_num_type(b) + s, sw = _real_num_type(s) + return b == s and bw >= sw + + return _subsumes(c.spec, g.spec) + + +def _does_type_signature_match(got_sig: Tuple[TYPE_SPEC, ...], cand_sig: Tuple[TYPE_SPEC, ...]): + # Assumptions (Fortran rules): + # 1. `got_sig` will not have any positional argument after keyworded arguments start. + # 2. `got_sig` may have keyworded arguments that are actually required arguments, and in different orders. + # 3. `got_sig` will not have any repeated keywords. 
+ + got_pos, got_kwd = tuple(x for x in got_sig if not x.keyword), {x.keyword: x for x in got_sig if x.keyword} + if len(got_sig) > len(cand_sig): + # Cannot have more arguments than needed. + return False + + cand_pos, cand_kwd = cand_sig[:len(got_pos)], {x.keyword: x for x in cand_sig[len(got_pos):]} + # Positional arguments are must all match in order. + for c, g in zip(cand_pos, got_pos): + if not _does_part_matches(g, c): + return False + # Now, we just need to check if `cand_kwd` matches `got_kwd`. + + # All the provided keywords must show up and match in the candidate list. + for k, g in got_kwd.items(): + if k not in cand_kwd or not _does_part_matches(g, cand_kwd[k]): + return False + # All the required candidates must have been provided as keywords. + for k, c in cand_kwd.items(): + if c.optional: + continue + if k not in got_kwd or not _does_part_matches(got_kwd[k], c): + return False + return True + + +def deconstruct_procedure_calls(ast: Program) -> Program: + SUFFIX, COUNTER = 'deconproc', 0 + + alias_map = alias_specs(ast) + proc_map = procedure_specs(ast) + genc_map = generic_specs(ast) + # We should have removed all the `association`s by now. + assert not walk(ast, Association), f"{walk(ast, Association)}" + + for pd in walk(ast, Procedure_Designator): + # Ref: https://github.com/stfc/fparser/blob/master/src/fparser/two/Fortran2003.py#L12530 + dref, op, bname = pd.children + + callsite = pd.parent + assert isinstance(callsite, (Function_Reference, Call_Stmt)) + + # Find out the module name. 
+ cmod = callsite.parent
+ while cmod and not isinstance(cmod, (Module, Main_Program)):
+ cmod = cmod.parent
+ if cmod:
+ stmt, _, _, _ = _get_module_or_program_parts(cmod)
+ cmod = singular(children_of_type(stmt, Name)).string.lower()
+ else:
+ subp = list(children_of_type(ast, Subroutine_Subprogram))
+ assert subp
+ stmt = singular(children_of_type(subp[0], Subroutine_Stmt))
+ cmod = singular(children_of_type(stmt, Name)).string.lower()
+
+ # Find the nearest execution and its corresponding specification parts.
+ execution_part = callsite.parent
+ while not isinstance(execution_part, Execution_Part):
+ execution_part = execution_part.parent
+ subprog = execution_part.parent
+ specification_part = atmost_one(children_of_type(subprog, Specification_Part))
+
+ scope_spec = find_scope_spec(callsite)
+ dref_type = find_type_dataref(dref, scope_spec, alias_map)
+ fnref = pd.parent
+ assert isinstance(fnref, (Function_Reference, Call_Stmt))
+ _, args = fnref.children
+ args_sig: Tuple[TYPE_SPEC, ...] = _compute_argument_signature(args, scope_spec, alias_map)
+ all_cand_sigs: List[Tuple[SPEC, Tuple[TYPE_SPEC, ...]]] = []
+
+ bspec = dref_type.spec + (bname.string,)
+ if bspec in genc_map and genc_map[bspec]:
+ for cand in genc_map[bspec]:
+ cand_stmt = alias_map[proc_map[cand]]
+ cand_spec = ident_spec(cand_stmt)
+ # TODO: Add ref.
+ _, _, cand_args, _ = cand_stmt.children
+ if cand_args:
+ cand_args_sig = _compute_candidate_argument_signature(cand_args.children[1:], cand_spec, alias_map)
+ else:
+ cand_args_sig = tuple()
+ all_cand_sigs.append((cand_spec, cand_args_sig))
+
+ if _does_type_signature_match(args_sig, cand_args_sig):
+ bspec = cand
+ break
+ if bspec not in proc_map:
+ print(f"{bspec} / {args_sig}")
+ for c in all_cand_sigs:
+ print(f"...> {c}")
+ assert bspec in proc_map, f"[in mod: {cmod}/{callsite}] {bspec} not found"
+ pname = proc_map[bspec]
+
+ # We are assuming that it's a subprogram defined directly inside a module.
+ assert len(pname) == 2 + mod, pname = pname + + if mod == cmod: + # Since `pname` must have been already defined at the module level, there is no need for aliasing. + pname_alias = pname + else: + # If we are importing it from a different module, we should create an alias to avoid name collision. + pname_alias, COUNTER = f"{pname}_{SUFFIX}_{COUNTER}", COUNTER + 1 + if not specification_part: + append_children(subprog, Specification_Part(get_reader(f"use {mod}, only: {pname_alias} => {pname}"))) + else: + prepend_children(specification_part, Use_Stmt(f"use {mod}, only: {pname_alias} => {pname}")) + + # For both function and subroutine calls, replace `bname` with `pname_alias`, and add `dref` as the first arg. + _, args = callsite.children + if args is None: + args = Actual_Arg_Spec_List(f"{dref}") + else: + args = Actual_Arg_Spec_List(f"{dref}, {args}") + callsite.items = (Name(pname_alias), args) + _reparent_children(callsite) + + for tbp in walk(ast, Type_Bound_Procedure_Part): + remove_self(tbp) + return ast + + +def _reparent_children(node: Base): + """Make `node` a parent of all its children, in case it isn't already.""" + for c in node.children: + if isinstance(c, Base): + c.parent = node + + +def prune_unused_objects(ast: Program, keepers: List[SPEC]) -> Program: + """ + Precondition: All the indirections have been taken out of the program. 
+ """ + PRUNABLE_OBJECT_TYPES = Union[Main_Program, Subroutine_Subprogram, Function_Subprogram, Derived_Type_Def] + + ident_map = identifier_specs(ast) + alias_map = alias_specs(ast) + survivors: Set[SPEC] = set() + keepers = [alias_map[k].parent for k in keepers] + assert all(isinstance(k, PRUNABLE_OBJECT_TYPES) for k in keepers) + + def _keep_from(node: Base): + for nm in walk(node, Name): + ob = nm.parent + sc_spec = search_scope_spec(ob) + if not sc_spec: + continue + + for j in reversed(range(len(sc_spec))): + anc = sc_spec[:j + 1] + if anc in survivors: + continue + survivors.add(anc) + anc_node = alias_map[anc].parent + if isinstance(anc_node, PRUNABLE_OBJECT_TYPES): + _keep_from(anc_node) + + to_keep = search_real_ident_spec(nm.string, sc_spec, alias_map) + if not to_keep or to_keep not in alias_map or to_keep in survivors: + # If we don't have a valid `to_keep` or `to_keep` is already kept, we move on. + continue + survivors.add(to_keep) + keep_node = alias_map[to_keep].parent + if isinstance(keep_node, PRUNABLE_OBJECT_TYPES): + _keep_from(keep_node) + + for k in keepers: + _keep_from(k) + + # We keep them sorted so that the parent scopes are handled earlier. + killed: Set[SPEC] = set() + for ns in list(sorted(set(ident_map.keys()) - survivors)): + ns_node = ident_map[ns].parent + if not isinstance(ns_node, PRUNABLE_OBJECT_TYPES): + continue + for i in range(len(ns) - 1): + anc_spec = ns[:i + 1] + if anc_spec in killed: + killed.add(ns) + break + if ns in killed: + continue + remove_self(ns_node) + killed.add(ns) + + # We also remove any access statement that makes the killed objects public/private. + for acc in walk(ast, Access_Stmt): + # TODO: Add ref. 
+ kind, alist = acc.children + if not alist: + continue + scope_spec = find_scope_spec(acc) + good_children = [] + for c in alist.children: + assert isinstance(c, Name) + c_spec = find_real_ident_spec(c.string, scope_spec, alias_map) + assert c_spec in ident_map + if c_spec not in killed: + good_children.append(c) + if good_children: + alist.items = good_children + _reparent_children(alist) + else: + remove_self(acc) + + return ast + + +def deconstruct_associations(ast: Program) -> Program: + for assoc in walk(ast, Associate_Construct): + # TODO: Add ref. + stmt, rest, _ = assoc.children[0], assoc.children[1:-1], assoc.children[-1] + # TODO: Add ref. + kw, assoc_list = stmt.children[0], stmt.children[1:] + if not assoc_list: + continue + + # Keep track of what to replace in the local scope. + local_map: Dict[str, Base] = {} + for al in assoc_list: + for a in al.children: + # TODO: Add ref. + src, _, tgt = a.children + local_map[src.string] = tgt + + for node in rest: + # Replace the data-ref roots as appropriate. + for dr in walk(node, Data_Ref): + # TODO: Add ref. + root, dr_rest = dr.children[0], dr.children[1:] + if root.string in local_map: + repl = local_map[root.string] + repl = type(repl)(repl.tofortran()) + dr.items = (repl, *dr_rest) + _reparent_children(dr) + # # Replace the part-ref roots as appropriate. + for pr in walk(node, Part_Ref): + if isinstance(pr.parent, (Data_Ref, Part_Ref)): + continue + # TODO: Add ref. + root, subsc = pr.children + if root.string in local_map: + repl = local_map[root.string] + repl = type(repl)(repl.tofortran()) + if isinstance(subsc, Section_Subscript_List) and isinstance(repl, (Data_Ref, Part_Ref)): + access = repl + while isinstance(access, (Data_Ref, Part_Ref)): + access = access.children[-1] + if isinstance(access, Section_Subscript_List): + # We cannot just chain accesses, so we need to combine them to produce a single access. + # TODO: Maybe `isinstance(c, Subscript_Triplet)` + offset manipulation? 
+ free_comps = [(i, c) for i, c in enumerate(access.children) if c == Subscript_Triplet(':')] + assert len(free_comps) >= len(subsc.children), \ + f"Free rank cannot increase, got {root}/{access} => {subsc}" + for i, c in enumerate(subsc.children): + idx, _ = free_comps[i] + free_comps[i] = (idx, c) + free_comps = {i: c for i, c in free_comps} + access.items = [free_comps.get(i, c) for i, c in enumerate(access.children)] + # Now replace the entire `pr` with `repl`. + replace_node(pr, repl) + continue + # Otherwise, just replace normally. + pr.items = (repl, subsc) + _reparent_children(pr) + # Replace all the other names. + for nm in walk(node, Name): + # TODO: This is hacky and can backfire if `nm` is not a standalone identifier. + par = nm.parent + # Avoid data refs as we have just processed them. + if isinstance(par, (Data_Ref, Part_Ref)): + continue + if nm.string not in local_map: + continue + replace_node(nm, local_map[nm.string]) + replace_node(assoc, rest) + + return ast + + +def assign_globally_unique_subprogram_names(ast: Program, keepers: Set[SPEC]) -> Program: + """ + Update the functions (and interchangeably, subroutines) to have globally unique names. + Precondition: + 1. All indirections are already removed from the program, except for the explicit renames. + 2. All public/private access statements were cleanly removed. + TODO: Make structure names unique too. + """ + SUFFIX, COUNTER = 'deconglobalfn', 0 + + ident_map = identifier_specs(ast) + alias_map = alias_specs(ast) + + # Make new unique names for the identifiers. + uident_map: Dict[SPEC, str] = {} + for k in ident_map.keys(): + if k in keepers: + continue + uname, COUNTER = f"{k[-1]}_{SUFFIX}_{COUNTER}", COUNTER + 1 + uident_map[k] = uname + + # PHASE 1.a: Remove all the places where any function is imported. 
+ for use in walk(ast, Use_Stmt): + mod_name = singular(children_of_type(use, Name)).string + mod_spec = (mod_name,) + olist = atmost_one(children_of_type(use, Only_List)) + if not olist: + continue + survivors = [] + for c in olist.children: + assert isinstance(c, (Name, Rename)) + if isinstance(c, Name): + src, tgt = c, c + elif isinstance(c, Rename): + _, src, tgt = c.children + src, tgt = src.string, tgt.string + tgt_spec = find_real_ident_spec(tgt, mod_spec, alias_map) + assert tgt_spec in ident_map + if not isinstance(ident_map[tgt_spec], (Function_Stmt, Subroutine_Stmt)): + survivors.append(c) + if survivors: + olist.items = survivors + _reparent_children(olist) + else: + par = use.parent + par.content = [c for c in par.children if c != use] + _reparent_children(par) + + # PHASE 1.b: Replaces all the function callsites. + for fref in walk(ast, (Function_Reference, Call_Stmt)): + scope_spec = find_scope_spec(fref) + + # TODO: Add ref. + name, _ = fref.children + if not isinstance(name, Name): + # Intrinsics are not to be renamed. + assert isinstance(name, Intrinsic_Name), f"{fref}" + continue + fspec = find_real_ident_spec(name.string, scope_spec, alias_map) + assert fspec in ident_map + assert isinstance(ident_map[fspec], (Function_Stmt, Subroutine_Stmt)) + if fspec not in uident_map: + # We have chosen to not rename it. + continue + uname = uident_map[fspec] + ufspec = fspec[:-1] + (uname,) + name.string = uname + + # Find the nearest execution and its correpsonding specification parts. + execution_part = fref.parent + while not isinstance(execution_part, Execution_Part): + execution_part = execution_part.parent + subprog = execution_part.parent + specification_part = atmost_one(children_of_type(subprog, Specification_Part)) + + # Find out the module name. 
+ cmod = fref.parent
+ while cmod and not isinstance(cmod, (Module, Main_Program)):
+ cmod = cmod.parent
+ if cmod:
+ stmt, _, _, _ = _get_module_or_program_parts(cmod)
+ cmod = singular(children_of_type(stmt, Name)).string.lower()
+ else:
+ subp = list(children_of_type(ast, Subroutine_Subprogram))
+ assert subp
+ stmt = singular(children_of_type(subp[0], Subroutine_Stmt))
+ cmod = singular(children_of_type(stmt, Name)).string.lower()
+
+ assert 1 <= len(ufspec)
+ if len(ufspec) == 1:
+ # Nothing to do for the toplevel subprograms. They are already available.
+ continue
+ mod = ufspec[0]
+ if mod == cmod:
+ # Since this function is already defined at the current module, there is nothing to import.
+ continue
+
+ if not specification_part:
+ append_children(subprog, Specification_Part(get_reader(f"use {mod}, only: {uname}")))
+ else:
+ prepend_children(specification_part, Use_Stmt(f"use {mod}, only: {uname}"))
+
+ # PHASE 1.d: Replaces actual function names.
+ for k, v in ident_map.items():
+ if not isinstance(v, (Function_Stmt, Subroutine_Stmt)):
+ continue
+ if k not in uident_map:
+ # We have chosen to not rename it.
+ continue
+ oname, uname = k[-1], uident_map[k]
+ singular(children_of_type(v, Name)).string = uname
+ # Fix the tail too.
+ fdef = v.parent
+ end_stmt = singular(children_of_type(fdef, (End_Function_Stmt, End_Subroutine_Stmt)))
+ singular(children_of_type(end_stmt, Name)).string = uname
+ # For functions, the function name is also available as a variable inside.
+ if isinstance(v, Function_Stmt):
+ vspec = atmost_one(children_of_type(fdef, Specification_Part))
+ vexec = atmost_one(children_of_type(fdef, Execution_Part))
+ for nm in walk([n for n in [vspec, vexec] if n], Name):
+ if nm.string != oname:
+ continue
+ local_spec = search_local_alias_spec(nm)
+ # We need to do a bit of surgery, since we have the `oname` inside the scope ending with `uname`.
+ local_spec = local_spec[:-2] + local_spec[-1:] + local_spec = tuple(x.split('_deconglobalfn_')[0] for x in local_spec) + assert local_spec in ident_map and ident_map[local_spec] == v + nm.string = uname + + return ast + + +def add_use_to_specification(scdef: SCOPE_OBJECT_TYPES, clause: str): + specification_part = atmost_one(children_of_type(scdef, Specification_Part)) + if not specification_part: + append_children(scdef, Specification_Part(get_reader(clause))) + else: + prepend_children(specification_part, Use_Stmt(clause)) + + +def assign_globally_unique_variable_names(ast: Program, keepers: Set[str]) -> Program: + """ + Update the variable declarations to have globally unique names. + Precondition: + 1. All indirections are already removed from the program, except for the explicit renames. + 2. All public/private access statements were cleanly removed. + """ + SUFFIX, COUNTER = 'deconglobalvar', 0 + + ident_map = identifier_specs(ast) + alias_map = alias_specs(ast) + + # Make new unique names for the identifiers. + uident_map: Dict[SPEC, str] = {} + for k in ident_map.keys(): + if k[-1] in keepers: + continue + uname, COUNTER = f"{k[-1]}_{SUFFIX}_{COUNTER}", COUNTER + 1 + uident_map[k] = uname + + # PHASE 1.a: Remove all the places where any variable is imported. 
+ for use in walk(ast, Use_Stmt): + mod_name = singular(children_of_type(use, Name)).string + mod_spec = (mod_name,) + olist = atmost_one(children_of_type(use, Only_List)) + if not olist: + continue + survivors = [] + for c in olist.children: + assert isinstance(c, (Name, Rename)) + if isinstance(c, Name): + src, tgt = c, c + elif isinstance(c, Rename): + _, src, tgt = c.children + src, tgt = src.string, tgt.string + tgt_spec = find_real_ident_spec(tgt, mod_spec, alias_map) + assert tgt_spec in ident_map + if not isinstance(ident_map[tgt_spec], Entity_Decl): + survivors.append(c) + if survivors: + olist.items = survivors + _reparent_children(olist) + else: + par = use.parent + par.content = [c for c in par.children if c != use] + _reparent_children(par) + + # PHASE 1.b: Replaces all the keywords when calling the functions. This must be done earlier than resolving other + # references, because otherwise we cannot distinguish the two `kw`s in `fn(kw=kw)`. + for kv in walk(ast, Actual_Arg_Spec): + fref = kv.parent.parent + if not isinstance(fref, (Function_Reference, Call_Stmt)): + # Not a user defined function, so we are not renaming its internal variables anyway. + continue + callee, _ = fref.children + if isinstance(callee, Intrinsic_Name): + # Not a user defined function, so we are not renaming its internal variables anyway. + continue + cspec = search_real_local_alias_spec(callee, alias_map) + cspec = ident_spec(alias_map[cspec]) + assert cspec + k, _ = kv.children + assert isinstance(k, Name) + kspec = find_real_ident_spec(k.string, cspec, alias_map) + if kspec not in uident_map: + # If we haven't planned to rename it, then skip. + continue + k.string = uident_map[kspec] + + # PHASE 1.c: Replaces all the direct references. + for vref in walk(ast, Name): + if isinstance(vref.parent, Entity_Decl): + # Do not change the variable declarations themselves just yet. 
+ continue
+ vspec = search_real_local_alias_spec(vref, alias_map)
+ if not vspec:
+ # It was not a valid alias (e.g., a structure component).
+ continue
+ if not isinstance(alias_map[vspec], Entity_Decl):
+ # Does not refer to a variable.
+ continue
+ edcl = alias_map[vspec]
+ fdef = find_scope_ancestor(edcl)
+ if isinstance(fdef, Function_Subprogram) and find_name_of_node(fdef) == find_name_of_node(edcl):
+ # Function return variables must retain their names.
+ continue
+
+ scope_spec = find_scope_spec(vref)
+ vspec = find_real_ident_spec(vspec[-1], scope_spec, alias_map)
+ assert vspec in ident_map
+ if vspec not in uident_map:
+ # We have chosen to not rename it.
+ continue
+ uname = uident_map[vspec]
+ vref.string = uname
+
+ if len(vspec) > 2:
+ # The variable is not defined in a toplevel object, so we're done.
+ continue
+ assert len(vspec) == 2
+ mod, _ = vspec
+ if not isinstance(alias_map[(mod,)], Module_Stmt):
+ # We can only import modules.
+ continue
+
+ # Find the nearest specification part (or lack thereof).
+ scdef = alias_map[scope_spec].parent
+ # Find out the current module name.
+ cmod = scdef
+ while not isinstance(cmod.parent, Program):
+ cmod = cmod.parent
+ if find_name_of_node(cmod) == mod:
+ # Since this variable is already defined at the current module, there is nothing to import.
+ continue
+ add_use_to_specification(scdef, f"use {mod}, only: {uname}")
+
+ # PHASE 1.d: Replaces all the literals where a variable can be used as a "kind".
+ for lit in walk(ast, Real_Literal_Constant):
+ val, kind = lit.children
+ if not kind:
+ continue
+ # Strangely, we get a plain `str` instead of a `Name`.
+ assert isinstance(kind, str)
+ scope_spec = find_scope_spec(lit)
+ kind_spec = search_real_ident_spec(kind, scope_spec, alias_map)
+ if not kind_spec or kind_spec not in uident_map:
+ continue
+ uname = uident_map[kind_spec]
+ lit.items = (val, uname)
+
+ if len(kind_spec) > 2:
+ # The variable is not defined in a toplevel object, so we're done.
+ continue
+ assert len(kind_spec) == 2
+ mod, _ = kind_spec
+ if not isinstance(alias_map[(mod,)], Module_Stmt):
+ # We can only import modules.
+ continue
+
+ # Find the nearest specification part (or lack thereof).
+ scdef = alias_map[scope_spec].parent
+ # Find out the current module name.
+ cmod = scdef
+ while not isinstance(cmod.parent, Program):
+ cmod = cmod.parent
+ if find_name_of_node(cmod) == mod:
+ # Since this variable is already defined at the current module, there is nothing to import.
+ continue
+ add_use_to_specification(scdef, f"use {mod}, only: {uname}")
+
+ # PHASE 1.e: Replaces actual variable names.
+ for k, v in ident_map.items():
+ if not isinstance(v, Entity_Decl):
+ continue
+ if k not in uident_map:
+ # We have chosen to not rename it.
+ continue
+ oname, uname = k[-1], uident_map[k]
+ fdef = find_scope_ancestor(v)
+ if isinstance(fdef, Function_Subprogram) and find_name_of_node(fdef) == oname:
+ # Function return variables must retain their names.
+ continue
+ singular(children_of_type(v, Name)).string = uname
+
+ return ast
+
+
+def _get_module_or_program_parts(mod: Union[Module, Main_Program]) \
+ -> Tuple[
+ Union[Module_Stmt, Program_Stmt],
+ Optional[Specification_Part],
+ Optional[Execution_Part],
+ Optional[Module_Subprogram_Part],
+ ]:
+ # There must exist a module statement.
+ stmt = singular(children_of_type(mod, Module_Stmt if isinstance(mod, Module) else Program_Stmt))
+ # There may or may not exist a specification part.
+ spec = list(children_of_type(mod, Specification_Part)) + assert len(spec) <= 1, f"A module/program cannot have more than one specification parts, found {spec} in {mod}" + spec = spec[0] if spec else None + # There may or may not exist an execution part. + exec = list(children_of_type(mod, Execution_Part)) + assert len(exec) <= 1, f"A module/program cannot have more than one execution parts, found {spec} in {mod}" + exec = exec[0] if exec else None + # There may or may not exist a subprogram part. + subp = list(children_of_type(mod, Module_Subprogram_Part)) + assert len(subp) <= 1, f"A module/program cannot have more than one subprogram parts, found {subp} in {mod}" + subp = subp[0] if subp else None + return stmt, spec, exec, subp + + +def consolidate_uses(ast: Program) -> Program: + for sp in reversed(walk(ast, Specification_Part)): + all_use: Set[str] = set() + use_map: Dict[str, Set[str]] = {} + # Build the table. + for u in children_of_type(sp, Use_Stmt): + name = singular(children_of_type(u, Name)).string + olist = atmost_one(children_of_type(u, Only_List)) + if not olist: + all_use.add(name) + else: + if name not in use_map: + use_map[name] = set() + use_map[name].update(c.tofortran() for c in olist.children) + # Build new use statements. + nuses: List[Use_Stmt] = [ + Use_Stmt(f"use {k}") if k in all_use else Use_Stmt(f"use {k}, only: {', '.join(use_map[k])}") + for k in use_map.keys() | all_use] + reuses: List[Use_Stmt] = [ + Use_Stmt(f"use {k}, only: {', '.join(r for r in use_map[k] if '=>' in r)}") + for k in use_map.keys() if any('=>' in r for r in use_map[k])] + # Remove the old ones, and prepend the new ones. 
+ sp.content = nuses + reuses + [c for c in sp.children if not isinstance(c, Use_Stmt)] + _reparent_children(sp) + return ast + + +def _prune_branches_in_ifblock(ib: If_Construct, alias_map: SPEC_TABLE): + ifthen = ib.children[0] + assert isinstance(ifthen, If_Then_Stmt) + cond, = ifthen.children + cval = _const_eval_basic_type(cond, alias_map) + if cval is None: + return + assert isinstance(cval, np.bool_) + + elifat = [idx for idx, c in enumerate(ib.children) if isinstance(c, (Else_If_Stmt, Else_Stmt))] + if cval: + cut = elifat[0] if elifat else -1 + actions = ib.children[1:cut] + replace_node(ib, actions) + return + elif not elifat: + remove_self(ib) + return + + cut = elifat[0] + cut_cond = ib.children[cut] + if isinstance(cut_cond, Else_Stmt): + actions = ib.children[cut + 1:-1] + replace_node(ib, actions) + return + + isinstance(cut_cond, Else_If_Stmt) + cut_cond, _ = cut_cond.children + remove_children(ib, ib.children[1:(cut + 1)]) + set_children(ifthen, (cut_cond,)) + _prune_branches_in_ifblock(ib, alias_map) + + +def prune_branches(ast: Program) -> Program: + alias_map = alias_specs(ast) + for ib in walk(ast, If_Construct): + _prune_branches_in_ifblock(ib, alias_map) + return ast + + +LITERAL_TYPES = Union[ + Real_Literal_Constant, Signed_Real_Literal_Constant, Int_Literal_Constant, Signed_Int_Literal_Constant, + Logical_Literal_Constant] + + +def numpy_type_to_literal(val: NUMPY_TYPES) -> Union[LITERAL_TYPES]: + if isinstance(val, np.bool_): + val = Logical_Literal_Constant('.true.' 
if val else '.false.') + elif isinstance(val, NUMPY_INTS): + bytez = _count_bytes(type(val)) + if val < 0: + val = Signed_Int_Literal_Constant(f"{val}" if bytez == 4 else f"{val}_{bytez}") + else: + val = Int_Literal_Constant(f"{val}" if bytez == 4 else f"{val}_{bytez}") + elif isinstance(val, NUMPY_REALS): + bytez = _count_bytes(type(val)) + valstr = str(val) + if bytez == 8: + if 'e' in valstr: + valstr = valstr.replace('e', 'D') + else: + valstr = f"{valstr}D0" + if val < 0: + val = Signed_Real_Literal_Constant(valstr) + else: + val = Real_Literal_Constant(valstr) + return val + + +def const_eval_nodes(ast: Program) -> Program: + EXPRESSION_TYPES = Union[ + LITERAL_TYPES, Expr, Add_Operand, Mult_Operand, Level_2_Expr, Level_3_Expr, Level_4_Expr, Level_5_Expr, + Intrinsic_Function_Reference] + + alias_map = alias_specs(ast) + + def _const_eval_node(n: Base) -> bool: + val = _const_eval_basic_type(n, alias_map) + if val is None: + return False + assert not np.isnan(val) + val = numpy_type_to_literal(val) + replace_node(n, val) + return True + + for asgn in reversed(walk(ast, Assignment_Stmt)): + lv, op, rv = asgn.children + assert op == '=' + _const_eval_node(rv) + for expr in reversed(walk(ast, EXPRESSION_TYPES)): + # Try to const-eval the expression. + if _const_eval_node(expr): + # If the node is successfully replaced, then nothing else to do. + continue + # Otherwise, try to at least replace the names with the literal values. 
+ for nm in reversed(walk(expr, Name)): + _const_eval_node(nm) + for knode in reversed(walk(ast, Kind_Selector)): + _, kind, _ = knode.children + _const_eval_node(kind) + + NON_EXPRESSION_TYPES = Union[ + Explicit_Shape_Spec, Loop_Control, Call_Stmt, Function_Reference, Initialization, Component_Initialization] + for node in reversed(walk(ast, NON_EXPRESSION_TYPES)): + for nm in reversed(walk(node, Name)): + _const_eval_node(nm) + + return ast + + +def lower_identifier_names(ast: Program) -> Program: + for nm in walk(ast, Name): + nm.string = nm.string.lower() + return ast diff --git a/dace/frontend/fortran/ast_internal_classes.py b/dace/frontend/fortran/ast_internal_classes.py index d1e68572de..2797e05d9d 100644 --- a/dace/frontend/fortran/ast_internal_classes.py +++ b/dace/frontend/fortran/ast_internal_classes.py @@ -1,5 +1,6 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -from typing import Any, List, Optional, Tuple, Type, TypeVar, Union, overload +from typing import List, Optional, Tuple, Union, Dict, Any + # The node class is the base class for all nodes in the AST. It provides attributes including the line number and fields. # Attributes are not used when walking the tree, but are useful for debugging and for code generation. 
@@ -7,58 +8,73 @@ class FNode(object): - def __init__(self, *args, **kwargs): # real signature unknown - self.integrity_exceptions = [] - self.read_vars = [] - self.written_vars = [] - self.parent: Optional[ - Union[ - Subroutine_Subprogram_Node, - Function_Subprogram_Node, - Main_Program_Node, - Module_Node - ] + def __init__(self, line_number: int = -1, **kwargs): # real signature unknown + self.line_number = line_number + self.parent: Union[ + None, + Subroutine_Subprogram_Node, + Function_Subprogram_Node, + Main_Program_Node, + Module_Node ] = None for k, v in kwargs.items(): setattr(self, k, v) - _attributes = ("line_number", ) - _fields = () - integrity_exceptions: List - read_vars: List - written_vars: List + _attributes: Tuple[str, ...] = ("line_number",) + _fields: Tuple[str, ...] = () def __eq__(self, o: object) -> bool: - if type(self) is type(o): - # check that all fields and attributes match - self_field_vals = list(map(lambda name: getattr(self, name, None), self._fields)) - self_attr_vals = list(map(lambda name: getattr(self, name, None), self._attributes)) - o_field_vals = list(map(lambda name: getattr(o, name, None), o._fields)) - o_attr_vals = list(map(lambda name: getattr(o, name, None), o._attributes)) - - return self_field_vals == o_field_vals and self_attr_vals == o_attr_vals - return False + if not isinstance(o, type(self)): + return False + # check that all fields and attributes match + self_field_vals = list(map(lambda name: getattr(self, name, None), self._fields)) + self_attr_vals = list(map(lambda name: getattr(self, name, None), self._attributes)) + o_field_vals = list(map(lambda name: getattr(o, name, None), o._fields)) + o_attr_vals = list(map(lambda name: getattr(o, name, None), o._attributes)) + return self_field_vals == o_field_vals and self_attr_vals == o_attr_vals class Program_Node(FNode): + def __init__(self, + main_program: 'Main_Program_Node', + function_definitions: List, + subroutine_definitions: List, + modules: List, + 
module_declarations: Dict, + placeholders: Optional[List] = None, + placeholders_offsets: Optional[List] = None, + structures: Optional['Structures'] = None, + **kwargs): + super().__init__(**kwargs) + self.main_program = main_program + self.function_definitions = function_definitions + self.subroutine_definitions = subroutine_definitions + self.modules = modules + self.module_declarations = module_declarations + self.structures = structures + self.placeholders = placeholders + self.placeholders_offsets = placeholders_offsets + _attributes = () _fields = ( - "main_program", - "function_definitions", - "subroutine_definitions", - "modules", + 'main_program', + 'function_definitions', + 'subroutine_definitions', + 'modules', ) class BinOp_Node(FNode): - _attributes = ( - 'op', - 'type', - ) - _fields = ( - 'lval', - 'rval', - ) + def __init__(self, op: str, lval: FNode, rval: FNode, type: str = 'VOID', **kwargs): + super().__init__(**kwargs) + assert rval + self.op = op + self.lval = lval + self.rval = rval + self.type = type + + _attributes = ('op', 'type') + _fields = ('lval', 'rval') class UnOp_Node(FNode): @@ -67,25 +83,102 @@ class UnOp_Node(FNode): 'postfix', 'type', ) - _fields = ('lval', ) + _fields = ('lval',) + + +class Exit_Node(FNode): + _attributes = () + _fields = () class Main_Program_Node(FNode): - _attributes = ("name", ) + _attributes = ("name",) _fields = ("execution_part", "specification_part") class Module_Node(FNode): - _attributes = ('name', ) + def __init__(self, + name: 'Name_Node', + specification_part: 'Specification_Part_Node', + subroutine_definitions: List['Subroutine_Subprogram_Node'], + function_definitions: List['Function_Subprogram_Node'], + interface_blocks: Dict, + **kwargs): + super().__init__(**kwargs) + self.name = name + self.specification_part = specification_part + self.subroutine_definitions = subroutine_definitions + self.function_definitions = function_definitions + self.interface_blocks = interface_blocks + + _attributes 
= ('name',) _fields = ( 'specification_part', 'subroutine_definitions', 'function_definitions', + 'interface_blocks' + ) + + +class Module_Subprogram_Part_Node(FNode): + def __init__(self, + subroutine_definitions: List['Subroutine_Subprogram_Node'], + function_definitions: List['Function_Subprogram_Node'], + **kwargs): + super().__init__(**kwargs) + self.subroutine_definitions = subroutine_definitions + self.function_definitions = function_definitions + + _attributes = () + _fields = ( + 'subroutine_definitions', + 'function_definitions', + ) + + +class Internal_Subprogram_Part_Node(FNode): + def __init__(self, + subroutine_definitions: List['Subroutine_Subprogram_Node'], + function_definitions: List['Function_Subprogram_Node'], + **kwargs): + super().__init__(**kwargs) + self.subroutine_definitions = subroutine_definitions + self.function_definitions = function_definitions + + _attributes = () + _fields = ( + 'subroutine_definitions', + 'function_definitions', + ) + + +class Actual_Arg_Spec_Node(FNode): + _fields = ( + 'arg_name', + 'arg', ) class Function_Subprogram_Node(FNode): - _attributes = ('name', 'type', 'ret_name') + def __init__(self, + name: 'Name_Node', + args: List, + ret: 'Name_Node', + specification_part: 'Specification_Part_Node', + execution_part: 'Execution_Part_Node', + type: str, + elemental: bool, + **kwargs): + super().__init__(**kwargs) + self.name = name + self.type = type + self.ret = ret + self.args = args + self.specification_part = specification_part + self.execution_part = execution_part + self.elemental = elemental + + _attributes = ('name', 'type', 'ret') _fields = ( 'args', 'specification_part', @@ -94,68 +187,159 @@ class Function_Subprogram_Node(FNode): class Subroutine_Subprogram_Node(FNode): - _attributes = ('name', 'type') + def __init__(self, + name: 'Name_Node', + args: List, + specification_part: 'Specification_Part_Node', + execution_part: 'Execution_Part_Node', + mandatory_args_count: int = -1, + optional_args_count: int 
= -1, + type: Any = None, + elemental: bool = False, + **kwargs): + super().__init__(**kwargs) + self.name = name + self.type = type + self.args = args + self.mandatory_args_count = mandatory_args_count + self.optional_args_count = optional_args_count + self.elemental = elemental + self.specification_part = specification_part + self.execution_part = execution_part + + _attributes = ('name', 'type', 'elemental') _fields = ( 'args', + 'mandatory_args_count', + 'optional_args_count', 'specification_part', 'execution_part', ) -class Module_Stmt_Node(FNode): - _attributes = ('name', ) +class Interface_Block_Node(FNode): + _attributes = ('name',) + _fields = ( + 'subroutines', + ) + + +class Interface_Stmt_Node(FNode): + _attributes = () _fields = () +class Procedure_Name_List_Node(FNode): + _attributes = () + _fields = ('subroutines',) + + +class Procedure_Statement_Node(FNode): + _attributes = () + _fields = ('namelists',) + + +class Module_Stmt_Node(FNode): + _attributes = () + _fields = ('functions',) + + class Program_Stmt_Node(FNode): - _attributes = ('name', ) + _attributes = ('name',) _fields = () class Subroutine_Stmt_Node(FNode): - _attributes = ('name', ) - _fields = ('args', ) + _attributes = ('name',) + _fields = ('args',) class Function_Stmt_Node(FNode): - _attributes = ('name', ) - _fields = ('args', 'return') + def __init__(self, name: 'Name_Node', args: List[FNode], ret: Optional['Suffix_Node'], elemental: bool, type: str, + **kwargs): + super().__init__(**kwargs) + self.name = name + self.args = args + self.ret = ret + self.elemental = elemental + self.type = type + + _attributes = ('name', 'elemental', 'type') + _fields = ('args', 'ret',) + + +class Prefix_Node(FNode): + def __init__(self, type: str, elemental: bool, recursive: bool, pure: bool, **kwargs): + super().__init__(**kwargs) + self.type = type + self.elemental = elemental + self.recursive = recursive + self.pure = pure + + _attributes = ('elemental', 'recursive', 'pure',) + _fields = () class 
Name_Node(FNode): - _attributes = ('name', 'type') + def __init__(self, name: str, type: str = 'VOID', **kwargs): + super().__init__(**kwargs) + self.name = name + self.type = type + + _attributes = ('name', 'type',) _fields = () class Name_Range_Node(FNode): - _attributes = ('name', 'type', 'arrname', 'pos') + _attributes = ('name', 'type', 'arrname', 'pos',) _fields = () +class Where_Construct_Node(FNode): + _attributes = () + _fields = ('main_body', 'main_cond', 'else_body', 'elifs_body', 'elifs_cond',) + + class Type_Name_Node(FNode): - _attributes = ('name', 'type') + _attributes = ('name', 'type',) _fields = () +class Generic_Binding_Node(FNode): + _attributes = () + _fields = ('name', 'binding',) + + class Specification_Part_Node(FNode): - _fields = ('specifications', 'symbols', 'typedecls') + _fields = ('specifications', 'symbols', 'interface_blocks', 'typedecls', 'enums',) + + +class Stop_Stmt_Node(FNode): + _attributes = ('code',) + + +class Error_Stmt_Node(FNode): + _fields = ('error',) class Execution_Part_Node(FNode): - _fields = ('execution', ) + _fields = ('execution',) class Statement_Node(FNode): - _attributes = ('col_offset', ) + _attributes = ('col_offset',) _fields = () class Array_Subscript_Node(FNode): - _attributes = ( - 'name', - 'type', - ) - _fields = ('indices', ) + def __init__(self, name: Name_Node, type: str, indices: List[FNode], **kwargs): + super().__init__(**kwargs) + self.name = name + self.type = type + self.indices = indices + + _attributes = ('type',) + _fields = ('name', 'indices',) class Type_Decl_Node(Statement_Node): @@ -168,22 +352,27 @@ class Type_Decl_Node(Statement_Node): class Allocate_Shape_Spec_Node(FNode): _attributes = () - _fields = ('sizes', ) + _fields = ('sizes',) class Allocate_Shape_Spec_List(FNode): _attributes = () - _fields = ('shape_list', ) + _fields = ('shape_list',) class Allocation_Node(FNode): - _attributes = ('name', ) - _fields = ('shape', ) + _attributes = ('name',) + _fields = ('shape',) + + 
+class Continue_Node(FNode): + _attributes = () + _fields = () class Allocate_Stmt_Node(FNode): _attributes = () - _fields = ('allocation_list', ) + _fields = ('allocation_list',) class Symbol_Decl_Node(Statement_Node): @@ -207,38 +396,58 @@ class Symbol_Array_Decl_Node(Statement_Node): ) _fields = ( 'sizes', - 'offsets' + 'offsets', 'typeref', 'init', ) class Var_Decl_Node(Statement_Node): - _attributes = ( - 'name', - 'type', - 'alloc', - 'kind', - ) - _fields = ( - 'sizes', - 'offsets', - 'typeref', - 'init', - ) + def __init__(self, name: str, type: str, + alloc: Optional[bool] = None, optional: Optional[bool] = None, + sizes: Optional[List] = None, offsets: Optional[List] = None, + init: Optional[FNode] = None, actual_offsets: Optional[List] = None, + typeref: Optional[Any] = None, kind: Optional[Any] = None, + **kwargs): + super().__init__(**kwargs) + self.name = name + self.type = type + self.alloc = alloc + self.kind = kind + self.optional = optional + self.sizes = sizes + self.offsets = offsets + self.actual_offsets = actual_offsets + self.typeref = typeref + self.init = init + + _attributes = ('name', 'type', 'alloc', 'kind', 'optional') + _fields = ('sizes', 'offsets', 'actual_offsets', 'typeref', 'init') class Arg_List_Node(FNode): - _fields = ('args', ) + _fields = ('args',) class Component_Spec_List_Node(FNode): - _fields = ('args', ) + _fields = ('args',) + + +class Allocate_Object_List_Node(FNode): + _fields = ('list',) + + +class Deallocate_Stmt_Node(FNode): + _fields = ('list',) class Decl_Stmt_Node(Statement_Node): + def __init__(self, vardecl: List[Var_Decl_Node], **kwargs): + super().__init__(**kwargs) + self.vardecl = vardecl + _attributes = () - _fields = ('vardecl', ) + _fields = ('vardecl',) class VarType: @@ -250,55 +459,103 @@ class Void(VarType): class Literal(FNode): - _attributes = ('value', ) + def __init__(self, value: str, type: str, **kwargs): + super().__init__(**kwargs) + self.value = value + self.type = type + + _attributes = 
('value', 'type') _fields = () class Int_Literal_Node(Literal): - _attributes = () - _fields = () + def __init__(self, value: str, type='INTEGER', **kwargs): + super().__init__(value, type, **kwargs) class Real_Literal_Node(Literal): - _attributes = () - _fields = () + def __init__(self, value: str, type='REAL', **kwargs): + super().__init__(value, type, **kwargs) -class Bool_Literal_Node(Literal): - _attributes = () - _fields = () +class Double_Literal_Node(Literal): + def __init__(self, value: str, type='DOUBLE', **kwargs): + super().__init__(value, type, **kwargs) -class String_Literal_Node(Literal): - _attributes = () - _fields = () +class Bool_Literal_Node(Literal): + def __init__(self, value: str, type='LOGICAL', **kwargs): + super().__init__(value, type, **kwargs) class Char_Literal_Node(Literal): + def __init__(self, value: str, type='CHAR', **kwargs): + super().__init__(value, type, **kwargs) + + +class Suffix_Node(FNode): + def __init__(self, name: 'Name_Node', **kwargs): + super().__init__(**kwargs) + self.name = name + _attributes = () - _fields = () + _fields = ('name',) class Call_Expr_Node(FNode): + def __init__(self, name: 'Name_Node', args: List[FNode], subroutine: bool, type: str, **kwargs): + super().__init__(**kwargs) + self.name = name + self.args = args + self.subroutine = subroutine + self.type = type + _attributes = ('type', 'subroutine') - _fields = ( - 'name', - 'args', - ) + _fields = ('name', 'args') + + +class Derived_Type_Stmt_Node(FNode): + _attributes = ('name',) + _fields = ('args',) + + +class Derived_Type_Def_Node(FNode): + _attributes = ('name',) + _fields = ('component_part', 'procedure_part') + + +class Component_Part_Node(FNode): + _attributes = () + _fields = ('component_def_stmts',) + + +class Data_Component_Def_Stmt_Node(FNode): + _attributes = () + _fields = ('vars',) + + +class Data_Ref_Node(FNode): + _attributes = () + _fields = ('parent_ref', 'part_ref') class Array_Constructor_Node(FNode): _attributes = () - _fields = 
('value_list', ) + _fields = ('value_list',) class Ac_Value_List_Node(FNode): _attributes = () - _fields = ('value_list', ) + _fields = ('value_list',) class Section_Subscript_List_Node(FNode): - _fields = ('list') + _fields = ('list',) + + +class Pointer_Assignment_Stmt_Node(FNode): + _attributes = () + _fields = ('name_pointer', 'name_target') class For_Stmt_Node(FNode): @@ -330,14 +587,135 @@ class If_Stmt_Node(FNode): ) +class Defer_Shape_Node(FNode): + _attributes = () + _fields = () + + +class Component_Initialization_Node(FNode): + _attributes = () + _fields = ('init',) + + +class Case_Cond_Node(FNode): + _fields = ('cond', 'op') + _attributes = () + + class Else_Separator_Node(FNode): _attributes = () _fields = () +class Procedure_Separator_Node(FNode): + _attributes = () + _fields = ('parent_ref', 'part_ref') + + +class Pointer_Object_List_Node(FNode): + _attributes = () + _fields = ('list',) + + +class Read_Stmt_Node(FNode): + _attributes = () + _fields = ('args',) + + +class Close_Stmt_Node(FNode): + _attributes = () + _fields = ('args',) + + +class Open_Stmt_Node(FNode): + _attributes = () + _fields = ('args',) + + +class Associate_Stmt_Node(FNode): + _attributes = () + _fields = ('args',) + + +class Associate_Construct_Node(FNode): + _attributes = () + _fields = ('associate', 'body') + + +class Association_List_Node(FNode): + _attributes = () + _fields = ('list',) + + +class Association_Node(FNode): + _attributes = () + _fields = ('name', 'expr') + + +class Connect_Spec_Node(FNode): + _attributes = ('type',) + _fields = ('args',) + + +class Close_Spec_Node(FNode): + _attributes = ('type',) + _fields = ('args',) + + +class Close_Spec_List_Node(FNode): + _attributes = () + _fields = ('list',) + + +class IO_Control_Spec_Node(FNode): + _attributes = ('type',) + _fields = ('args',) + + +class IO_Control_Spec_List_Node(FNode): + _attributes = () + _fields = ('list',) + + +class Connect_Spec_List_Node(FNode): + _attributes = () + _fields = ('list',) + + 
+class Nullify_Stmt_Node(FNode): + _attributes = () + _fields = ('list',) + + +class Namelist_Stmt_Node(FNode): + _attributes = () + _fields = ('list', 'name') + + +class Namelist_Group_Object_List_Node(FNode): + _attributes = () + _fields = ('list',) + + +class Bound_Procedures_Node(FNode): + _attributes = () + _fields = ('procedures',) + + +class Specific_Binding_Node(FNode): + _attributes = () + _fields = ('name', 'args') + + class Parenthesis_Expr_Node(FNode): + def __init__(self, expr: FNode, **kwargs): + super().__init__(**kwargs) + assert hasattr(expr, 'type') + self.expr = expr + self.type = expr.type + _attributes = () - _fields = ('expr', ) + _fields = ('expr', 'type') class Nonlabel_Do_Stmt_Node(FNode): @@ -349,6 +727,28 @@ class Nonlabel_Do_Stmt_Node(FNode): ) +class While_True_Control(FNode): + _attributes = () + _fields = ( + 'name', + ) + + +class While_Control(FNode): + _attributes = () + _fields = ( + 'cond', + ) + + +class While_Stmt_Node(FNode): + _attributes = ('name') + _fields = ( + 'body', + 'cond', + ) + + class Loop_Control_Node(FNode): _attributes = () _fields = ( @@ -360,32 +760,38 @@ class Loop_Control_Node(FNode): class Else_If_Stmt_Node(FNode): _attributes = () - _fields = ('cond', ) + _fields = ('cond',) class Only_List_Node(FNode): _attributes = () - _fields = ('names', ) + _fields = ('names', 'renames',) + + +class Rename_Node(FNode): + _attributes = () + _fields = ('oldname', 'newname',) class ParDecl_Node(FNode): - _attributes = ('type', ) - _fields = ('range', ) + _attributes = ('type',) + _fields = ('range',) class Structure_Constructor_Node(FNode): - _attributes = ('type', ) + _attributes = ('type',) _fields = ('name', 'args') class Use_Stmt_Node(FNode): - _attributes = ('name', ) - _fields = ('list', ) + _attributes = ('name', 'list_all') + _fields = ('list',) class Write_Stmt_Node(FNode): _attributes = () - _fields = ('args', ) + _fields = ('args',) + class Break_Node(FNode): _attributes = () diff --git 
a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index 57508d6d90..db5769a977 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -1,8 +1,94 @@ # Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. -from dace.frontend.fortran import ast_components, ast_internal_classes -from typing import Dict, List, Optional, Tuple, Set import copy +from typing import Dict, List, Optional, Tuple, Set, Union + +import sympy as sp + +from dace import symbolic as sym +from dace.frontend.fortran import ast_internal_classes, ast_utils + + +class Structure: + + def __init__(self, name: str): + self.vars: Dict[str, Union[ast_internal_classes.Symbol_Decl_Node, ast_internal_classes.Var_Decl_Node]] = {} + self.name = name + + +class Structures: + + def __init__(self, definitions: List[ast_internal_classes.Derived_Type_Def_Node]): + self.structures: Dict[str, Structure] = {} + self.parse(definitions) + + def parse(self, definitions: List[ast_internal_classes.Derived_Type_Def_Node]): + + for structure in definitions: + + struct = Structure(name=structure.name.name) + if structure.component_part is not None: + if structure.component_part.component_def_stmts is not None: + for statement in structure.component_part.component_def_stmts: + if isinstance(statement, ast_internal_classes.Data_Component_Def_Stmt_Node): + for var in statement.vars.vardecl: + struct.vars[var.name] = var + + self.structures[structure.name.name] = struct + + def is_struct(self, type_name: str): + return type_name in self.structures + + def get_definition(self, type_name: str): + return self.structures[type_name] + + def find_definition(self, scope_vars, node: ast_internal_classes.Data_Ref_Node, + variable_name: Optional[ast_internal_classes.Name_Node] = None): + + # we assume starting from the top (left-most) data_ref_node + # for struct1 % struct2 % struct3 % var + # we find definition of struct1, then we iterate until 
we find the var + + # find the top structure + top_ref = node + while isinstance(top_ref.parent_ref, ast_internal_classes.Data_Ref_Node): + top_ref = top_ref.parent_ref + + struct_type = scope_vars.get_var(node.parent, ast_utils.get_name(top_ref.parent_ref)).type + struct_def = self.structures[struct_type] + + # cur_node = node + cur_node = top_ref + + while True: + cur_node = cur_node.part_ref + + if isinstance(cur_node, ast_internal_classes.Array_Subscript_Node): + struct_def = self.structures[struct_type] + cur_var = struct_def.vars[cur_node.name.name] + node = cur_node + break + + elif isinstance(cur_node, ast_internal_classes.Name_Node): + struct_def = self.structures[struct_type] + cur_var = struct_def.vars[cur_node.name] + break + + if isinstance(cur_node.parent_ref.name, ast_internal_classes.Name_Node): + + if variable_name is not None and cur_node.parent_ref.name.name == variable_name.name: + return struct_def, struct_def.vars[cur_node.parent_ref.name.name] + + struct_type = struct_def.vars[cur_node.parent_ref.name.name].type + else: + + if variable_name is not None and cur_node.parent_ref.name == variable_name.name: + return struct_def, struct_def.vars[cur_node.parent_ref.name] + + struct_type = struct_def.vars[cur_node.parent_ref.name].type + struct_def = self.structures[struct_type] + + return struct_def, cur_var def iter_fields(node: ast_internal_classes.FNode): @@ -10,8 +96,6 @@ def iter_fields(node: ast_internal_classes.FNode): Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields`` that is present on *node*. """ - if not hasattr(node, "_fields"): - a = 1 for field in node._fields: try: yield field, getattr(node, field) @@ -19,6 +103,18 @@ def iter_fields(node: ast_internal_classes.FNode): pass +def iter_attributes(node: ast_internal_classes.FNode): + """ + Yield a tuple of ``(fieldname, value)`` for each field in ``node._attributes`` + that is present on *node*. 
+ """ + for field in node._attributes: + try: + yield field, getattr(node, field) + except AttributeError: + pass + + def iter_child_nodes(node: ast_internal_classes.FNode): """ Yield all direct child nodes of *node*, that is, all fields that are nodes @@ -26,7 +122,7 @@ def iter_child_nodes(node: ast_internal_classes.FNode): """ for name, field in iter_fields(node): - #print("NASME:",name) + # print("NASME:",name) if isinstance(field, ast_internal_classes.FNode): yield field elif isinstance(field, list): @@ -42,6 +138,7 @@ class NodeVisitor(object): XXX is the class name you want to visit with these methods. """ + def visit(self, node: ast_internal_classes.FNode): method = 'visit_' + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) @@ -65,6 +162,7 @@ class NodeTransformer(NodeVisitor): The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace old nodes. """ + def as_list(self, x): if isinstance(x, list): return x @@ -95,19 +193,130 @@ def generic_visit(self, node: ast_internal_classes.FNode): return node +class Flatten_Classes(NodeTransformer): + + def __init__(self, classes: List[ast_internal_classes.Derived_Type_Def_Node]): + self.classes = classes + self.current_class = None + + def visit_Derived_Type_Def_Node(self, node: ast_internal_classes.Derived_Type_Def_Node): + self.current_class = node + return_node = self.generic_visit(node) + # self.current_class=None + return return_node + + def visit_Module_Node(self, node: ast_internal_classes.Module_Node): + self.current_class = None + return self.generic_visit(node) + + def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node): + new_node = self.generic_visit(node) + print("Subroutine: ", node.name.name) + if self.current_class is not None: + for i in self.classes: + if i.is_class is True: + if i.name.name == self.current_class.name.name: + for j in i.procedure_part.procedures: + if j.name.name == 
node.name.name: + return ast_internal_classes.Subroutine_Subprogram_Node( + name=ast_internal_classes.Name_Node(name=i.name.name + "_" + node.name.name, + type=node.type), + args=new_node.args, + specification_part=new_node.specification_part, + execution_part=new_node.execution_part, + mandatory_args_count=new_node.mandatory_args_count, + optional_args_count=new_node.optional_args_count, + elemental=new_node.elemental, + line_number=new_node.line_number) + elif hasattr(j, "args") and j.args[2] is not None: + if j.args[2].name == node.name.name: + return ast_internal_classes.Subroutine_Subprogram_Node( + name=ast_internal_classes.Name_Node(name=i.name.name + "_" + j.name.name, + type=node.type), + args=new_node.args, + specification_part=new_node.specification_part, + execution_part=new_node.execution_part, + mandatory_args_count=new_node.mandatory_args_count, + optional_args_count=new_node.optional_args_count, + elemental=new_node.elemental, + line_number=new_node.line_number) + return new_node + + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): + if self.current_class is not None: + for i in self.classes: + if i.is_class is True: + if i.name.name == self.current_class.name.name: + for j in i.procedure_part.procedures: + if j.name.name == node.name.name: + return ast_internal_classes.Call_Expr_Node( + name=ast_internal_classes.Name_Node(name=i.name.name + "_" + node.name.name, + type=node.type, args=node.args, + line_number=node.line_number), args=node.args, + type=node.type, subroutine=node.subroutine, line_number=node.line_number,parent=node.parent) + return self.generic_visit(node) + + class FindFunctionAndSubroutines(NodeVisitor): """ Finds all function and subroutine names in the AST :return: List of names """ + def __init__(self): - self.nodes: List[ast_internal_classes.Name_Node] = [] + self.names: List[ast_internal_classes.Name_Node] = [] + self.module_based_names: Dict[str, List[ast_internal_classes.Name_Node]] = {} + 
self.nodes: Dict[str, ast_internal_classes.FNode] = {} + self.iblocks: Dict[str, List[str]] = {} + self.current_module = "_dace_default" + self.module_based_names[self.current_module] = [] def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node): - self.nodes.append(node.name) + ret = node.name + ret.elemental = node.elemental + self.names.append(ret) + self.nodes[ret.name] = node + self.module_based_names[self.current_module].append(ret) def visit_Function_Subprogram_Node(self, node: ast_internal_classes.Function_Subprogram_Node): - self.nodes.append(node.name) + ret = node.name + ret.elemental = node.elemental + self.names.append(ret) + self.nodes[ret.name] = node + self.module_based_names[self.current_module].append(ret) + + def visit_Module_Node(self, node: ast_internal_classes.Module_Node): + self.iblocks.update(node.interface_blocks) + self.current_module = node.name.name + self.module_based_names[self.current_module] = [] + self.generic_visit(node) + + @staticmethod + def from_node(node: ast_internal_classes.FNode) -> 'FindFunctionAndSubroutines': + v = FindFunctionAndSubroutines() + v.visit(node) + return v + + +class FindNames(NodeVisitor): + def __init__(self): + self.names: List[str] = [] + + def visit_Name_Node(self, node: ast_internal_classes.Name_Node): + self.names.append(node.name) + + def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_Node): + self.names.append(node.name.name) + for i in node.indices: + self.visit(i) + + +class FindDefinedNames(NodeVisitor): + def __init__(self): + self.names: List[str] = [] + + def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): + self.names.append(node.name) class FindInputs(NodeVisitor): @@ -115,7 +324,9 @@ class FindInputs(NodeVisitor): Finds all inputs (reads) in the AST node and its children :return: List of names """ + def __init__(self): + self.nodes: List[ast_internal_classes.Name_Node] = [] def visit_Name_Node(self, 
node: ast_internal_classes.Name_Node): @@ -126,6 +337,30 @@ def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_ for i in node.indices: self.visit(i) + def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node): + if isinstance(node.parent_ref, ast_internal_classes.Name_Node): + self.nodes.append(node.parent_ref) + elif isinstance(node.parent_ref, ast_internal_classes.Array_Subscript_Node): + self.nodes.append(node.parent_ref.name) + if isinstance(node.parent_ref, ast_internal_classes.Array_Subscript_Node): + for i in node.parent_ref.indices: + self.visit(i) + if isinstance(node.part_ref, ast_internal_classes.Array_Subscript_Node): + for i in node.part_ref.indices: + self.visit(i) + elif isinstance(node.part_ref, ast_internal_classes.Data_Ref_Node): + self.visit_Blunt_Data_Ref_Node(node.part_ref) + + def visit_Blunt_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node): + if isinstance(node.parent_ref, ast_internal_classes.Array_Subscript_Node): + for i in node.parent_ref.indices: + self.visit(i) + if isinstance(node.part_ref, ast_internal_classes.Array_Subscript_Node): + for i in node.part_ref.indices: + self.visit(i) + elif isinstance(node.part_ref, ast_internal_classes.Data_Ref_Node): + self.visit_Blunt_Data_Ref_Node(node.part_ref) + def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): if node.op == "=": if isinstance(node.lval, ast_internal_classes.Name_Node): @@ -133,10 +368,48 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): elif isinstance(node.lval, ast_internal_classes.Array_Subscript_Node): for i in node.lval.indices: self.visit(i) + elif isinstance(node.lval, ast_internal_classes.Data_Ref_Node): + # if isinstance(node.lval.parent_ref, ast_internal_classes.Name_Node): + # self.nodes.append(node.lval.parent_ref) + if isinstance(node.lval.parent_ref, ast_internal_classes.Array_Subscript_Node): + # self.nodes.append(node.lval.parent_ref.name) + for i in 
node.lval.parent_ref.indices: + self.visit(i) + if isinstance(node.lval.part_ref, ast_internal_classes.Data_Ref_Node): + self.visit_Blunt_Data_Ref_Node(node.lval.part_ref) + elif isinstance(node.lval.part_ref, ast_internal_classes.Array_Subscript_Node): + for i in node.lval.part_ref.indices: + self.visit(i) else: - self.visit(node.lval) - self.visit(node.rval) + if isinstance(node.lval, ast_internal_classes.Data_Ref_Node): + if isinstance(node.lval.parent_ref, ast_internal_classes.Name_Node): + self.nodes.append(node.lval.parent_ref) + elif isinstance(node.lval.parent_ref, ast_internal_classes.Array_Subscript_Node): + self.nodes.append(node.lval.parent_ref.name) + for i in node.lval.parent_ref.indices: + self.visit(i) + if isinstance(node.lval.part_ref, ast_internal_classes.Data_Ref_Node): + self.visit_Blunt_Data_Ref_Node(node.lval.part_ref) + elif isinstance(node.lval.part_ref, ast_internal_classes.Array_Subscript_Node): + for i in node.lval.part_ref.indices: + self.visit(i) + else: + self.visit(node.lval) + if isinstance(node.rval, ast_internal_classes.Data_Ref_Node): + if isinstance(node.rval.parent_ref, ast_internal_classes.Name_Node): + self.nodes.append(node.rval.parent_ref) + elif isinstance(node.rval.parent_ref, ast_internal_classes.Array_Subscript_Node): + self.nodes.append(node.rval.parent_ref.name) + for i in node.rval.parent_ref.indices: + self.visit(i) + if isinstance(node.rval.part_ref, ast_internal_classes.Data_Ref_Node): + self.visit_Blunt_Data_Ref_Node(node.rval.part_ref) + elif isinstance(node.rval.part_ref, ast_internal_classes.Array_Subscript_Node): + for i in node.rval.part_ref.indices: + self.visit(i) + else: + self.visit(node.rval) class FindOutputs(NodeVisitor): @@ -144,15 +417,47 @@ class FindOutputs(NodeVisitor): Finds all outputs (writes) in the AST node and its children :return: List of names """ - def __init__(self): + + def __init__(self, thourough=False): + self.thourough = thourough self.nodes: List[ast_internal_classes.Name_Node] = 
[] + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): + for i in node.args: + if isinstance(i, ast_internal_classes.Name_Node): + if self.thourough: + self.nodes.append(i) + elif isinstance(i, ast_internal_classes.Array_Subscript_Node): + if self.thourough: + self.nodes.append(i.name) + for j in i.indices: + self.visit(j) + elif isinstance(i, ast_internal_classes.Data_Ref_Node): + if isinstance(i.parent_ref, ast_internal_classes.Name_Node): + if self.thourough: + self.nodes.append(i.parent_ref) + elif isinstance(i.parent_ref, ast_internal_classes.Array_Subscript_Node): + if self.thourough: + self.nodes.append(i.parent_ref.name) + for j in i.parent_ref.indices: + self.visit(j) + self.visit(i.part_ref) + self.visit(i) + def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): if node.op == "=": if isinstance(node.lval, ast_internal_classes.Name_Node): self.nodes.append(node.lval) elif isinstance(node.lval, ast_internal_classes.Array_Subscript_Node): self.nodes.append(node.lval.name) + elif isinstance(node.lval, ast_internal_classes.Data_Ref_Node): + if isinstance(node.lval.parent_ref, ast_internal_classes.Name_Node): + self.nodes.append(node.lval.parent_ref) + elif isinstance(node.lval.parent_ref, ast_internal_classes.Array_Subscript_Node): + self.nodes.append(node.lval.parent_ref.name) + for i in node.lval.parent_ref.indices: + self.visit(i) + self.visit(node.rval) @@ -161,6 +466,7 @@ class FindFunctionCalls(NodeVisitor): Finds all function calls in the AST node and its children :return: List of names """ + def __init__(self): self.nodes: List[ast_internal_classes.Name_Node] = [] @@ -170,13 +476,189 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): self.visit(i) -class CallToArray(NodeTransformer): +class StructLister(NodeVisitor): """ Fortran does not differentiate between arrays and functions. We need to go over and convert all function calls to arrays. 
So, we create a closure of all math and defined functions and create array expressions for the others. """ + + def __init__(self): + + self.structs = [] + self.names = [] + + def visit_Derived_Type_Def_Node(self, node: ast_internal_classes.Derived_Type_Def_Node): + self.names.append(node.name.name) + if node.procedure_part is not None: + if len(node.procedure_part.procedures) > 0: + node.is_class = True + self.structs.append(node) + return + node.is_class = False + self.structs.append(node) + + +class StructDependencyLister(NodeVisitor): + def __init__(self, names=None): + self.names = names + self.structs_used = [] + self.is_pointer = [] + self.pointer_names = [] + + def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): + if node.type in self.names: + self.structs_used.append(node.type) + self.is_pointer.append(node.alloc) + self.pointer_names.append(node.name) + + +class StructMemberLister(NodeVisitor): + def __init__(self): + self.members = [] + self.is_pointer = [] + self.pointer_names = [] + + def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): + self.members.append(node.type) + self.is_pointer.append(node.alloc) + self.pointer_names.append(node.name) + + +class FindStructDefs(NodeVisitor): + def __init__(self, name=None): + self.name = name + self.structs = [] + + def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): + if node.type == self.name: + self.structs.append(node.name) + + +class FindStructUses(NodeVisitor): + def __init__(self, names=None, target=None): + self.names = names + self.target = target + self.nodes = [] + + def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node): + + if isinstance(node.parent_ref, ast_internal_classes.Name_Node): + parent_name = node.parent_ref.name + elif isinstance(node.parent_ref, ast_internal_classes.Array_Subscript_Node): + parent_name = node.parent_ref.name.name + elif isinstance(node.parent_ref, ast_internal_classes.Data_Ref_Node): + raise 
NotImplementedError("Data ref node not implemented for not name or array") + self.visit(node.parent_ref) + parent_name = None + else: + + raise NotImplementedError("Data ref node not implemented for not name or array") + if isinstance(node.part_ref, ast_internal_classes.Name_Node): + part_name = node.part_ref.name + elif isinstance(node.part_ref, ast_internal_classes.Array_Subscript_Node): + part_name = node.part_ref.name.name + elif isinstance(node.part_ref, ast_internal_classes.Data_Ref_Node): + self.visit(node.part_ref) + if isinstance(node.part_ref.parent_ref, ast_internal_classes.Name_Node): + part_name = node.part_ref.parent_ref.name + elif isinstance(node.part_ref.parent_ref, ast_internal_classes.Array_Subscript_Node): + part_name = node.part_ref.parent_ref.name.name + + else: + raise NotImplementedError("Data ref node not implemented for not name or array") + if part_name == self.target and parent_name in self.names: + self.nodes.append(node) + + +class StructPointerChecker(NodeVisitor): + def __init__(self, parent_struct, pointed_struct, pointer_name, structs_lister, struct_dep_graph, analysis): + self.parent_struct = [parent_struct] + self.pointed_struct = [pointed_struct] + self.pointer_name = [pointer_name] + self.nodes = [] + self.connection = [] + self.structs_lister = structs_lister + self.struct_dep_graph = struct_dep_graph + if analysis == "full": + start_idx = 0 + end_idx = 1 + while start_idx != end_idx: + for i in struct_dep_graph.in_edges(self.parent_struct[start_idx]): + found = False + for parent, child in zip(self.parent_struct, self.pointed_struct): + if i[0] == parent and i[1] == child: + found = True + break + if not found: + self.parent_struct.append(i[0]) + self.pointed_struct.append(i[1]) + self.pointer_name.append(struct_dep_graph.get_edge_data(i[0], i[1])["point_name"]) + end_idx += 1 + start_idx += 1 + + def visit_Main_Program_Node(self, node: ast_internal_classes.Main_Program_Node): + for parent, pointer in zip(self.parent_struct, 
self.pointer_name): + finder = FindStructDefs(parent) + finder.visit(node.specification_part) + struct_names = finder.structs + use_finder = FindStructUses(struct_names, pointer) + use_finder.visit(node.execution_part) + self.nodes += use_finder.nodes + self.connection.append([parent, pointer, struct_names, use_finder.nodes]) + + def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node): + for parent, pointer in zip(self.parent_struct, self.pointer_name): + + finder = FindStructDefs(parent) + if node.specification_part is not None: + finder.visit(node.specification_part) + struct_names = finder.structs + use_finder = FindStructUses(struct_names, pointer) + if node.execution_part is not None: + use_finder.visit(node.execution_part) + self.nodes += use_finder.nodes + self.connection.append([parent, pointer, struct_names, use_finder.nodes]) + + +class StructPointerEliminator(NodeTransformer): + def __init__(self, parent_struct, pointed_struct, pointer_name): + self.parent_struct = parent_struct + self.pointed_struct = pointed_struct + self.pointer_name = pointer_name + + def visit_Derived_Type_Def_Node(self, node: ast_internal_classes.Derived_Type_Def_Node): + if node.name.name == self.parent_struct: + newnode = ast_internal_classes.Derived_Type_Def_Node(name=node.name, parent=node.parent) + component_part = ast_internal_classes.Component_Part_Node(component_def_stmts=[], parent=node.parent) + for i in node.component_part.component_def_stmts: + + vardecl = [] + for k in i.vars.vardecl: + if k.name == self.pointer_name and k.alloc == True and k.type == self.pointed_struct: + # print("Eliminating pointer "+self.pointer_name+" of type "+ k.type +" in struct "+self.parent_struct) + continue + else: + vardecl.append(k) + if vardecl != []: + component_part.component_def_stmts.append(ast_internal_classes.Data_Component_Def_Stmt_Node( + vars=ast_internal_classes.Decl_Stmt_Node(vardecl=vardecl, parent=node.parent), + parent=node.parent)) 
+ newnode.component_part = component_part + return newnode + else: + return node + + +class StructConstructorToFunctionCall(NodeTransformer): + """ + Fortran does not differentiate between structure constructors and functions without arguments. + We need to go over and convert all structure constructors that are in fact functions and transform them. + So, we create a closure of all math and defined functions and + transform if necessary. + """ + def __init__(self, funcs=None): if funcs is None: funcs = [] @@ -188,87 +670,188 @@ def __init__(self, funcs=None): "__dace_epsilon", *FortranIntrinsics.function_names() ] + def visit_Structure_Constructor_Node(self, node: ast_internal_classes.Structure_Constructor_Node): + if isinstance(node.name, str): + return node + if node.name is None: + raise ValueError("Structure name is None") + return ast_internal_classes.Char_Literal_Node(value="Error!", type="CHARACTER") + found = False + for i in self.funcs: + if i.name == node.name.name: + found = True + break + if node.name.name in self.excepted_funcs or found: + processed_args = [] + for i in node.args: + arg = StructConstructorToFunctionCall(self.funcs).visit(i) + processed_args.append(arg) + node.args = processed_args + return ast_internal_classes.Call_Expr_Node( + name=ast_internal_classes.Name_Node(name=node.name.name, type="VOID", line_number=node.line_number), + args=node.args, line_number=node.line_number, type="VOID",parent=node.parent) + + else: + return node + + +class CallToArray(NodeTransformer): + """ + Fortran does not differentiate between arrays and functions. + We need to go over and convert all function calls to arrays. + So, we create a closure of all math and defined functions and + create array expressions for the others. 
+ """ + + def __init__(self, funcs: FindFunctionAndSubroutines, dict=None): + self.funcs = funcs + self.rename_dict = dict + + from dace.frontend.fortran.intrinsics import FortranIntrinsics + self.excepted_funcs = [ + "malloc", "pow", "cbrt", "__dace_sign", "__dace_allocated", "tanh", "atan2", + "__dace_epsilon", "__dace_exit", "surrtpk", "surrtab", "surrtrf", "abor1", + *FortranIntrinsics.function_names() + ] + # + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): if isinstance(node.name, str): return node - if node.name.name in self.excepted_funcs or node.name in self.funcs: + assert node.name is not None, f"not a valid call expression, got: {node} / {type(node)}" + name = node.name.name + + found_in_names = name in [i.name for i in self.funcs.names] + found_in_renames = False + if self.rename_dict is not None: + for k, v in self.rename_dict.items(): + for original_name, replacement_names in v.items(): + if isinstance(replacement_names, str): + if name == replacement_names: + found_in_renames = True + module = k + original_one = original_name + node.name.name = original_name + print(f"Found {name} in {module} with original name {original_one}") + break + elif isinstance(replacement_names, list): + for repl in replacement_names: + if name == repl: + found_in_renames = True + module = k + original_one = original_name + node.name.name = original_name + print(f"Found in list {name} in {module} with original name {original_one}") + break + else: + raise ValueError(f"Invalid type {type(replacement_names)} for {replacement_names}") + + # TODO Deconproc is a special case, we need to handle it differently - this is just s quick workaround + if name.startswith( + "__dace_") or name in self.excepted_funcs or found_in_renames or found_in_names or name in self.funcs.iblocks: processed_args = [] for i in node.args: - arg = CallToArray(self.funcs).visit(i) + arg = CallToArray(self.funcs, self.rename_dict).visit(i) processed_args.append(arg) node.args = 
processed_args return node - indices = [CallToArray(self.funcs).visit(i) for i in node.args] - return ast_internal_classes.Array_Subscript_Node(name=node.name, indices=indices) + indices = [CallToArray(self.funcs, self.rename_dict).visit(i) for i in node.args] + # Array subscript cannot be empty. + assert indices + return ast_internal_classes.Array_Subscript_Node(name=node.name, type=node.type, indices=indices, + line_number=node.line_number) -class CallExtractorNodeLister(NodeVisitor): +class ArgumentExtractorNodeLister(NodeVisitor): """ - Finds all function calls in the AST node and its children that have to be extracted into independent expressions + Finds all arguments in function calls in the AST node and its children that have to be extracted into independent expressions """ + def __init__(self): self.nodes: List[ast_internal_classes.Call_Expr_Node] = [] def visit_For_Stmt_Node(self, node: ast_internal_classes.For_Stmt_Node): return + + def visit_If_Then_Stmt_Node(self, node: ast_internal_classes.If_Stmt_Node): + return def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): stop = False - if hasattr(node, "subroutine"): - if node.subroutine is True: - stop = True + #if hasattr(node, "subroutine"): + # if node.subroutine is True: + # stop = True from dace.frontend.fortran.intrinsics import FortranIntrinsics if not stop and node.name.name not in [ - "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() + "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() ]: - self.nodes.append(node) + for i in node.args: + if isinstance(i, (ast_internal_classes.Name_Node, ast_internal_classes.Literal, + ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node, + ast_internal_classes.Actual_Arg_Spec_Node)): + continue + else: + self.nodes.append(i) return self.generic_visit(node) def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): 
return -class CallExtractor(NodeTransformer): +class ArgumentExtractor(NodeTransformer): """ - Uses the CallExtractorNodeLister to find all function calls + Uses the ArgumentExtractorNodeLister to find all function calls in the AST node and its children that have to be extracted into independent expressions It then creates a new temporary variable for each of them and replaces the call with the variable. """ - def __init__(self, count=0): + + def __init__(self, program, count=0): self.count = count + self.program = program + + ParentScopeAssigner().visit(program) + self.scope_vars = ScopeVarsDeclarations(program) + self.scope_vars.visit(program) def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): from dace.frontend.fortran.intrinsics import FortranIntrinsics - if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()]: + if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon", + *FortranIntrinsics.call_extraction_exemptions()]: return self.generic_visit(node) - if hasattr(node, "subroutine"): - if node.subroutine is True: - return self.generic_visit(node) + #if node.subroutine: + # return self.generic_visit(node) if not hasattr(self, "count"): self.count = 0 - else: - self.count = self.count + 1 tmp = self.count - + result = ast_internal_classes.Call_Expr_Node(type=node.type, subroutine=node.subroutine, + name=node.name, args=[], line_number=node.line_number, parent=node.parent) for i, arg in enumerate(node.args): # Ensure we allow to extract function calls from arguments - node.args[i] = self.visit(arg) - - return ast_internal_classes.Name_Node(name="tmp_call_" + str(tmp - 1)) + if isinstance(arg, (ast_internal_classes.Name_Node, ast_internal_classes.Literal, + ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node, + ast_internal_classes.Actual_Arg_Spec_Node)): + result.args.append(arg) + else: + 
result.args.append(ast_internal_classes.Name_Node(name="tmp_arg_" + str(tmp), type='VOID')) + tmp = tmp + 1 + self.count = tmp + return result def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): newbody = [] for child in node.execution: - lister = CallExtractorNodeLister() + lister = ArgumentExtractorNodeLister() lister.visit(child) res = lister.nodes for i in res: if i == child: res.pop(res.index(i)) + if res is not None: + # Variables are counted from 0...end, starting from main node, to all calls nested # in main node arguments. # However, we need to define nested ones first. @@ -276,115 +859,471 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No temp = self.count + len(res) - 1 for i in reversed(range(0, len(res))): - newbody.append( + if isinstance(res[i], ast_internal_classes.Data_Ref_Node): + struct_def, cur_var = self.program.structures.find_definition(self.scope_vars, res[i]) + + var_type = cur_var.type + else: + var_type = res[i].type + + node.parent.specification_part.specifications.append( ast_internal_classes.Decl_Stmt_Node(vardecl=[ ast_internal_classes.Var_Decl_Node( - name="tmp_call_" + str(temp), - type=res[i].type, - sizes=None + name="tmp_arg_" + str(temp), + type=var_type, + sizes=None, + init=None ) - ])) + ]) + ) newbody.append( ast_internal_classes.BinOp_Node(op="=", - lval=ast_internal_classes.Name_Node(name="tmp_call_" + - str(temp), + lval=ast_internal_classes.Name_Node(name="tmp_arg_" + + str(temp), type=res[i].type), rval=res[i], - line_number=child.line_number)) + line_number=child.line_number,parent=child.parent)) temp = temp - 1 - if isinstance(child, ast_internal_classes.Call_Expr_Node): - new_args = [] - if hasattr(child, "args"): - for i in child.args: - new_args.append(self.visit(i)) - new_child = ast_internal_classes.Call_Expr_Node(type=child.type, - name=child.name, - args=new_args, - line_number=child.line_number) - newbody.append(new_child) - else: - 
newbody.append(self.visit(child)) + + newbody.append(self.visit(child)) return ast_internal_classes.Execution_Part_Node(execution=newbody) -class ParentScopeAssigner(NodeVisitor): - """ - For each node, it assigns its parent scope - program, subroutine, function. - If the parent node is one of the "parent" types, we assign it as the parent. - Otherwise, we look for the parent of my parent to cover nested AST nodes within - a single scope. - """ - def __init__(self): - pass +class FunctionCallTransformer(NodeTransformer): + def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): + if isinstance(node.rval, ast_internal_classes.Call_Expr_Node): + if hasattr(node.rval, "subroutine"): + if node.rval.subroutine is True: + return self.generic_visit(node) + if node.rval.name.name.find("__dace_") != -1: + return self.generic_visit(node) + if node.op != "=": + return self.generic_visit(node) + args = node.rval.args + lval = node.lval + args.append(lval) + return (ast_internal_classes.Call_Expr_Node(type=node.rval.type, + name=ast_internal_classes.Name_Node( + name=node.rval.name.name + "_srt", type=node.rval.type), + args=args, + subroutine=True, + line_number=node.line_number, parent=node.parent)) - def visit(self, node: ast_internal_classes.FNode, parent_node: Optional[ast_internal_classes.FNode] = None): + else: + return self.generic_visit(node) - parent_node_types = [ - ast_internal_classes.Subroutine_Subprogram_Node, - ast_internal_classes.Function_Subprogram_Node, - ast_internal_classes.Main_Program_Node, - ast_internal_classes.Module_Node - ] - if parent_node is not None and type(parent_node) in parent_node_types: - node.parent = parent_node - elif parent_node is not None: - node.parent = parent_node.parent +class NameReplacer(NodeTransformer): + """ + Replaces all occurences of a name with another name + """ - # Copied from `generic_visit` to recursively parse all leafs - for field, value in iter_fields(node): - if isinstance(value, list): - for item in 
value: - if isinstance(item, ast_internal_classes.FNode): - self.visit(item, node) - elif isinstance(value, ast_internal_classes.FNode): - self.visit(value, node) + def __init__(self, old_name: str, new_name: str): + self.old_name = old_name + self.new_name = new_name - return node + def visit_Name_Node(self, node: ast_internal_classes.Name_Node): + if node.name == self.old_name: + return ast_internal_classes.Name_Node(name=self.new_name, type=node.type) + else: + return self.generic_visit(node) -class ScopeVarsDeclarations(NodeVisitor): - """ - Creates a mapping (scope name, variable name) -> variable declaration. - The visitor is used to access information on variable dimension, sizes, and offsets. +class FunctionToSubroutineDefiner(NodeTransformer): + """ + Transforms all function definitions into subroutine definitions """ - def __init__(self): - - self.scope_vars: Dict[Tuple[str, str], ast_internal_classes.FNode] = {} + def visit_Function_Subprogram_Node(self, node: ast_internal_classes.Function_Subprogram_Node): + assert node.ret + ret = node.ret + + found = False + if node.specification_part is not None: + for j in node.specification_part.specifications: + + for k in j.vardecl: + if node.ret != None: + if k.name == ret.name: + j.vardecl[j.vardecl.index(k)].name = node.name.name + "__ret" + found = True + if k.name == node.name.name: + j.vardecl[j.vardecl.index(k)].name = node.name.name + "__ret" + found = True + break - def get_var(self, scope: ast_internal_classes.FNode, variable_name: str) -> ast_internal_classes.FNode: - return self.scope_vars[(self._scope_name(scope), variable_name)] + if not found: - def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): + var = ast_internal_classes.Var_Decl_Node( + name=node.name.name + "__ret", + type='VOID' + ) + stmt_node = ast_internal_classes.Decl_Stmt_Node(vardecl=[var], line_number=node.line_number) - parent_name = self._scope_name(node.parent) - var_name = node.name - 
self.scope_vars[(parent_name, var_name)] = node + if node.specification_part is not None: + node.specification_part.specifications.append(stmt_node) + else: + node.specification_part = ast_internal_classes.Specification_Part_Node( + specifications=[stmt_node], + symbols=None, + interface_blocks=None, + uses=None, + typedecls=None, + enums=None + ) + + # We should always be able to tell a functions return _variable_ (i.e., not type, which we also should be able + # to tell). + assert node.ret + execution_part = NameReplacer(ret.name, node.name.name + "__ret").visit(node.execution_part) + args = node.args + args.append(ast_internal_classes.Name_Node(name=node.name.name + "__ret", type=node.type)) + return ast_internal_classes.Subroutine_Subprogram_Node( + name=ast_internal_classes.Name_Node(name=node.name.name + "_srt", type=node.type), + args=args, + specification_part=node.specification_part, + execution_part=execution_part, + subroutine=True, + line_number=node.line_number, + elemental=node.elemental) - def _scope_name(self, scope: ast_internal_classes.FNode) -> str: - if isinstance(scope, ast_internal_classes.Main_Program_Node): - return scope.name.name.name - else: - return scope.name.name -class IndexExtractorNodeLister(NodeVisitor): +class CallExtractorNodeLister(NodeVisitor): + """ + Finds all function calls in the AST node and its children that have to be extracted into independent expressions + """ + + def __init__(self,root=None): + self.root = root + self.nodes: List[ast_internal_classes.Call_Expr_Node] = [] + + + def visit_For_Stmt_Node(self, node: ast_internal_classes.For_Stmt_Node): + self.generic_visit(node.init) + self.generic_visit(node.cond) + return + + def visit_If_Stmt_Node(self, node: ast_internal_classes.If_Stmt_Node): + self.generic_visit(node.cond) + return + + def visit_While_Stmt_Node(self, node: ast_internal_classes.While_Stmt_Node): + self.generic_visit(node.cond) + return + + def visit_Call_Expr_Node(self, node: 
ast_internal_classes.Call_Expr_Node): + stop = False + if self.root==node: + return self.generic_visit(node) + if isinstance(self.root, ast_internal_classes.BinOp_Node): + if node == self.root.rval and isinstance(self.root.lval, ast_internal_classes.Name_Node): + return self.generic_visit(node) + if hasattr(node, "subroutine"): + if node.subroutine is True: + stop = True + + from dace.frontend.fortran.intrinsics import FortranIntrinsics + if not stop and node.name.name not in [ + "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() + ]: + self.nodes.append(node) + #return self.generic_visit(node) + + def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): + return + + +class CallExtractor(NodeTransformer): + """ + Uses the CallExtractorNodeLister to find all function calls + in the AST node and its children that have to be extracted into independent expressions + It then creates a new temporary variable for each of them and replaces the call with the variable. 
+ """ + + def __init__(self, count=0): + self.count = count + + + + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): + + from dace.frontend.fortran.intrinsics import FortranIntrinsics + if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon", + *FortranIntrinsics.call_extraction_exemptions()]: + return self.generic_visit(node) + if hasattr(node, "subroutine"): + if node.subroutine is True: + return self.generic_visit(node) + if not hasattr(self, "count"): + self.count = 0 + else: + self.count = self.count + 1 + tmp = self.count + + #for i, arg in enumerate(node.args): + # # Ensure we allow to extract function calls from arguments + # node.args[i] = self.visit(arg) + + return ast_internal_classes.Name_Node(name="tmp_call_" + str(tmp - 1)) + + # def visit_Specification_Part_Node(self, node: ast_internal_classes.Specification_Part_Node): + # newspec = [] + + # for i in node.specifications: + # if not isinstance(i, ast_internal_classes.Decl_Stmt_Node): + # newspec.append(self.visit(i)) + # else: + # newdecl = [] + # for var in i.vardecl: + # lister = CallExtractorNodeLister() + # lister.visit(var) + # res = lister.nodes + # for j in res: + # if j == var: + # res.pop(res.index(j)) + # if len(res) > 0: + # temp = self.count + len(res) - 1 + # for ii in reversed(range(0, len(res))): + # newdecl.append( + # ast_internal_classes.Var_Decl_Node( + # name="tmp_call_" + str(temp), + # type=res[ii].type, + # sizes=None, + # line_number=var.line_number, + # init=res[ii], + # ) + # ) + # newdecl.append( + # ast_internal_classes.Var_Decl_Node( + # name="tmp_call_" + str(temp), + # type=res[ii].type, + # sizes=None, + # line_number=var.line_number, + # init=res[ii], + # ) + # ) + # temp = temp - 1 + # newdecl.append(self.visit(var)) + # newspec.append(ast_internal_classes.Decl_Stmt_Node(vardecl=newdecl)) + # return ast_internal_classes.Specification_Part_Node(specifications=newspec, symbols=node.symbols, + # typedecls=node.typedecls, uses=node.uses, 
enums=node.enums, + # interface_blocks=node.interface_blocks) + + def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): + + oldbody = node.execution + changes_made=True + while changes_made: + changes_made=False + newbody = [] + for child in oldbody: + lister = CallExtractorNodeLister(child) + lister.visit(child) + res = lister.nodes + + if len(res)> 0: + changes_made=True + # Variables are counted from 0...end, starting from main node, to all calls nested + # in main node arguments. + # However, we need to define nested ones first. + # We go in reverse order, counting from end-1 to 0. + temp = self.count + len(res) - 1 + for i in reversed(range(0, len(res))): + newbody.append( + ast_internal_classes.Decl_Stmt_Node(vardecl=[ + ast_internal_classes.Var_Decl_Node( + name="tmp_call_" + str(temp), + type=res[i].type, + sizes=None, + init=None + ) + ])) + newbody.append( + ast_internal_classes.BinOp_Node(op="=", + lval=ast_internal_classes.Name_Node( + name="tmp_call_" + str(temp), type=res[i].type), + rval=res[i], line_number=child.line_number,parent=child.parent)) + temp = temp - 1 + if isinstance(child, ast_internal_classes.Call_Expr_Node): + new_args = [] + for i in child.args: + new_args.append(self.visit(i)) + new_child = ast_internal_classes.Call_Expr_Node(type=child.type, subroutine=child.subroutine, + name=child.name, args=new_args, + line_number=child.line_number, parent=child.parent) + newbody.append(new_child) + elif isinstance(child, ast_internal_classes.BinOp_Node): + if isinstance(child.lval,ast_internal_classes.Name_Node) and isinstance (child.rval, ast_internal_classes.Call_Expr_Node): + new_args = [] + for i in child.rval.args: + new_args.append(self.visit(i)) + new_child = ast_internal_classes.Call_Expr_Node(type=child.rval.type, subroutine=child.rval.subroutine, + name=child.rval.name, args=new_args, + line_number=child.rval.line_number, parent=child.rval.parent) + 
newbody.append(ast_internal_classes.BinOp_Node(op=child.op, + lval=child.lval, + rval=new_child, line_number=child.line_number,parent=child.parent)) + else: + newbody.append(self.visit(child)) + else: + newbody.append(self.visit(child)) + oldbody = newbody + + return ast_internal_classes.Execution_Part_Node(execution=newbody) + + +class ParentScopeAssigner(NodeVisitor): + """ + For each node, it assigns its parent scope - program, subroutine, function. + + If the parent node is one of the "parent" types, we assign it as the parent. + Otherwise, we look for the parent of my parent to cover nested AST nodes within + a single scope. + """ + + def __init__(self): + pass + + def visit(self, node: ast_internal_classes.FNode, parent_node: Optional[ast_internal_classes.FNode] = None): + + parent_node_types = [ + ast_internal_classes.Subroutine_Subprogram_Node, + ast_internal_classes.Function_Subprogram_Node, + ast_internal_classes.Main_Program_Node, + ast_internal_classes.Module_Node + ] + + if parent_node is not None and type(parent_node) in parent_node_types: + node.parent = parent_node + elif parent_node is not None: + node.parent = parent_node.parent + + # Copied from `generic_visit` to recursively parse all leafs + for field, value in iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast_internal_classes.FNode): + self.visit(item, node) + elif isinstance(value, ast_internal_classes.FNode): + self.visit(value, node) + + return node + + +class ModuleVarsDeclarations(NodeVisitor): + """ + Creates a mapping (scope name, variable name) -> variable declaration. + + The visitor is used to access information on variable dimension, sizes, and offsets. 
+ """ + + def __init__(self): # , module_name: str): + + self.scope_vars: Dict[Tuple[str, str], ast_internal_classes.FNode] = {} + + def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): + var_name = node.name + self.scope_vars[var_name] = node + + def visit_Symbol_Decl_Node(self, node: ast_internal_classes.Symbol_Decl_Node): + var_name = node.name + self.scope_vars[var_name] = node + + +class ScopeVarsDeclarations(NodeVisitor): + """ + Creates a mapping (scope name, variable name) -> variable declaration. + + The visitor is used to access information on variable dimension, sizes, and offsets. + """ + + def __init__(self, ast): + + self.scope_vars: Dict[Tuple[str, str], ast_internal_classes.FNode] = {} + if hasattr(ast, "module_declarations"): + self.module_declarations = ast.module_declarations + else: + self.module_declarations = {} + + def get_var(self, scope: Optional[Union[ast_internal_classes.FNode, str]], + variable_name: str) -> ast_internal_classes.FNode: + + if scope is not None and self.contains_var(scope, variable_name): + return self.scope_vars[(self._scope_name(scope), variable_name)] + elif variable_name in self.module_declarations: + return self.module_declarations[variable_name] + else: + raise RuntimeError( + f"Couldn't find the declaration of variable {variable_name} in function {self._scope_name(scope)}!") + + def contains_var(self, scope: ast_internal_classes.FNode, variable_name: str) -> bool: + return (self._scope_name(scope), variable_name) in self.scope_vars + + def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): + + parent_name = self._scope_name(node.parent) + var_name = node.name + self.scope_vars[(parent_name, var_name)] = node + + def visit_Symbol_Decl_Node(self, node: ast_internal_classes.Symbol_Decl_Node): + + parent_name = self._scope_name(node.parent) + var_name = node.name + self.scope_vars[(parent_name, var_name)] = node + + def _scope_name(self, scope: ast_internal_classes.FNode) -> str: + 
if isinstance(scope, ast_internal_classes.Main_Program_Node): + return scope.name.name.name + elif isinstance(scope, str): + return scope + else: + return scope.name.name + + +class IndexExtractorNodeLister(NodeVisitor): """ Finds all array subscript expressions in the AST node and its children that have to be extracted into independent expressions """ + def __init__(self): self.nodes: List[ast_internal_classes.Array_Subscript_Node] = [] + self.current_parent: Optional[ast_internal_classes.Data_Ref_Node] = None def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): from dace.frontend.fortran.intrinsics import FortranIntrinsics if node.name.name in ["pow", "atan2", "tanh", *FortranIntrinsics.retained_function_names()]: return self.generic_visit(node) else: + for arg in node.args: + self.visit(arg) return def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_Node): - self.nodes.append(node) + + old_current_parent = self.current_parent + self.current_parent = None + for i in node.indices: + self.visit(i) + self.current_parent = old_current_parent + + self.nodes.append((node, self.current_parent)) + + # disable structure parent node for array indices + + def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node): + + set_node = False + if self.current_parent is None: + self.current_parent = node + set_node = True + + self.visit(node.parent_ref) + self.visit(node.part_ref) + + if set_node: + set_node = False + self.current_parent = None def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): return @@ -400,35 +1339,52 @@ class IndexExtractor(NodeTransformer): - ParentScopeAssigner to ensure that each node knows its scope assigner. - ScopeVarsDeclarations to aggregate all variable declarations for each function. 
""" + def __init__(self, ast: ast_internal_classes.FNode, normalize_offsets: bool = False, count=0): self.count = count self.normalize_offsets = normalize_offsets + self.program = ast + self.replacements = {} if normalize_offsets: ParentScopeAssigner().visit(ast) - self.scope_vars = ScopeVarsDeclarations() + self.scope_vars = ScopeVarsDeclarations(ast) self.scope_vars.visit(ast) + self.structures = ast.structures def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): from dace.frontend.fortran.intrinsics import FortranIntrinsics if node.name.name in ["pow", "atan2", "tanh", *FortranIntrinsics.retained_function_names()]: return self.generic_visit(node) else: + + new_args = [] + for arg in node.args: + new_args.append(self.visit(arg)) + node.args = new_args return node def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_Node): - - tmp = self.count new_indices = [] + for i in node.indices: + new_indices.append(self.visit(i)) + + tmp = self.count + newer_indices = [] + for i in new_indices: if isinstance(i, ast_internal_classes.ParDecl_Node): - new_indices.append(i) + newer_indices.append(i) else: - new_indices.append(ast_internal_classes.Name_Node(name="tmp_index_" + str(tmp))) + + newer_indices.append(ast_internal_classes.Name_Node(name="tmp_index_" + str(tmp))) + self.replacements["tmp_index_" + str(tmp)] = (i, node.name.name) tmp = tmp + 1 self.count = tmp - return ast_internal_classes.Array_Subscript_Node(name=node.name, indices=new_indices) + + return ast_internal_classes.Array_Subscript_Node(name=node.name, type=node.type, indices=newer_indices, + line_number=node.line_number) def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): newbody = [] @@ -439,10 +1395,11 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No res = lister.nodes temp = self.count - + tmp_child = self.visit(child) if res is not None: - for j in res: + for j, parent_node in res: for 
idx, i in enumerate(j.indices): + if isinstance(i, ast_internal_classes.ParDecl_Node): continue else: @@ -453,29 +1410,46 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No ast_internal_classes.Var_Decl_Node(name=tmp_name, type="INTEGER", sizes=None, + init=None, line_number=child.line_number) ], - line_number=child.line_number)) + line_number=child.line_number)) if self.normalize_offsets: # Find the offset of a variable to which we are assigning var_name = "" if isinstance(j, ast_internal_classes.Name_Node): var_name = j.name + variable = self.scope_vars.get_var(child.parent, var_name) + elif parent_node is not None: + struct, variable = self.structures.find_definition( + self.scope_vars, parent_node, j.name + ) + var_name = j.name.name else: var_name = j.name.name - variable = self.scope_vars.get_var(child.parent, var_name) + variable = self.scope_vars.get_var(child.parent, var_name) offset = variable.offsets[idx] + # it can be a symbol - Name_Node - or a value + + + if not isinstance(offset, ast_internal_classes.Name_Node) and not isinstance(offset,ast_internal_classes.BinOp_Node): + #check if offset is a number + try: + offset = int(offset) + except: + raise ValueError(f"Offset {offset} is not a number") + offset = ast_internal_classes.Int_Literal_Node(value=str(offset)) newbody.append( ast_internal_classes.BinOp_Node( op="=", lval=ast_internal_classes.Name_Node(name=tmp_name), rval=ast_internal_classes.BinOp_Node( op="-", - lval=i, - rval=ast_internal_classes.Int_Literal_Node(value=str(offset)), - line_number=child.line_number), + lval=self.replacements[tmp_name][0], + rval=offset, + line_number=child.line_number,parent=child.parent), line_number=child.line_number)) else: newbody.append( @@ -484,11 +1458,11 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No lval=ast_internal_classes.Name_Node(name=tmp_name), rval=ast_internal_classes.BinOp_Node( op="-", - lval=i, + 
lval=self.replacements[tmp_name][0], rval=ast_internal_classes.Int_Literal_Node(value="1"), - line_number=child.line_number), - line_number=child.line_number)) - newbody.append(self.visit(child)) + line_number=child.line_number,parent=child.parent), + line_number=child.line_number,parent=child.parent)) + newbody.append(tmp_child) return ast_internal_classes.Execution_Part_Node(execution=newbody) @@ -496,6 +1470,7 @@ class SignToIf(NodeTransformer): """ Transforms all sign expressions into if statements """ + def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): if isinstance(node.rval, ast_internal_classes.Call_Expr_Node) and node.rval.name.name == "__dace_sign": args = node.rval.args @@ -503,7 +1478,7 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): cond = ast_internal_classes.BinOp_Node(op=">=", rval=ast_internal_classes.Real_Literal_Node(value="0.0"), lval=args[1], - line_number=node.line_number) + line_number=node.line_number,parent=node.parent) body_if = ast_internal_classes.Execution_Part_Node(execution=[ ast_internal_classes.BinOp_Node(lval=copy.deepcopy(lval), op="=", @@ -511,26 +1486,30 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): name=ast_internal_classes.Name_Node(name="abs"), type="DOUBLE", args=[copy.deepcopy(args[0])], - line_number=node.line_number), - line_number=node.line_number) + line_number=node.line_number,parent=node.parent, + subroutine=False,), + + line_number=node.line_number,parent=node.parent) ]) body_else = ast_internal_classes.Execution_Part_Node(execution=[ ast_internal_classes.BinOp_Node(lval=copy.deepcopy(lval), op="=", rval=ast_internal_classes.UnOp_Node( op="-", + type="VOID", lval=ast_internal_classes.Call_Expr_Node( name=ast_internal_classes.Name_Node(name="abs"), - type="DOUBLE", args=[copy.deepcopy(args[0])], - line_number=node.line_number), - line_number=node.line_number), - line_number=node.line_number) + type="DOUBLE", + subroutine=False, + 
line_number=node.line_number,parent=node.parent), + line_number=node.line_number,parent=node.parent), + line_number=node.line_number,parent=node.parent) ]) return (ast_internal_classes.If_Stmt_Node(cond=cond, body=body_if, body_else=body_else, - line_number=node.line_number)) + line_number=node.line_number,parent=node.parent)) else: return self.generic_visit(node) @@ -541,6 +1520,7 @@ class RenameArguments(NodeTransformer): Renames all arguments of a function to the names of the arguments of the function call Used when eliminating function statements """ + def __init__(self, node_args: list, call_args: list): self.node_args = node_args self.call_args = call_args @@ -556,6 +1536,7 @@ class ReplaceFunctionStatement(NodeTransformer): """ Replaces a function statement with its content, similar to propagating a macro """ + def __init__(self, statement, replacement): self.name = statement.name self.content = replacement @@ -571,6 +1552,7 @@ class ReplaceFunctionStatementPass(NodeTransformer): """ Replaces a function statement with its content, similar to propagating a macro """ + def __init__(self, statefunc: list): self.funcs = statefunc @@ -591,72 +1573,410 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): return self.generic_visit(node) -def functionStatementEliminator(node=ast_internal_classes.Program_Node): +def optionalArgsHandleFunction(func): + func.optional_args = [] + if func.specification_part is None: + return 0 + for spec in func.specification_part.specifications: + for var in spec.vardecl: + if hasattr(var, "optional") and var.optional: + func.optional_args.append((var.name, var.type)) + + vardecls = [] + new_args = [] + for i in func.args: + new_args.append(i) + for arg in func.args: + + found = False + for opt_arg in func.optional_args: + if opt_arg[0] == arg.name: + found = True + break + + if found: + + name = f'__f2dace_OPTIONAL_{arg.name}' + already_there = False + for i in func.args: + if hasattr(i, "name") and i.name == 
name:
+                    already_there = True
+                    break
+            if not already_there:
+                var = ast_internal_classes.Var_Decl_Node(name=name,
+                                                         type='LOGICAL',
+                                                         alloc=False,
+                                                         sizes=None,
+                                                         offsets=None,
+                                                         kind=None,
+                                                         optional=False,
+                                                         init=None,
+                                                         line_number=func.line_number)
+                new_args.append(ast_internal_classes.Name_Node(name=name))
+                vardecls.append(var)
+
+    if len(new_args) > len(func.args):
+        func.args.clear()
+        func.args = new_args
+
+    if len(vardecls) > 0:
+        specifiers = []
+        for i in func.specification_part.specifications:
+            specifiers.append(i)
+        specifiers.append(
+            ast_internal_classes.Decl_Stmt_Node(
+                vardecl=vardecls,
+                line_number=func.line_number
+            )
+        )
+        func.specification_part.specifications.clear()
+        func.specification_part.specifications = specifiers
+
+    return len(new_args)
+
+
+class OptionalArgsTransformer(NodeTransformer):
+    def __init__(self, funcs_with_opt_args):
+        self.funcs_with_opt_args = funcs_with_opt_args
+
+    def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
+
+        if node.name.name not in self.funcs_with_opt_args:
+            return node
+
+        # Basic assumption for positional arguments
+        # Optional arguments follow the mandatory ones
+        # We use that to determine which optional arguments are missing
+        func_decl = self.funcs_with_opt_args[node.name.name]
+        optional_args = len(func_decl.optional_args)
+        if optional_args == 0:
+            return node
+
+        should_be_args = len(func_decl.args)
+        mandatory_args = should_be_args - optional_args * 2
+
+        present_args = len(node.args)
+
+        # Remove the deduplicated variable entries acting as flags for optional args
+        missing_args_count = should_be_args - present_args
+        present_optional_args = present_args - mandatory_args
+        new_args = [None] * should_be_args
+        print("Func len args: ", len(func_decl.args))
+        print("Func: ", func_decl.name.name, "Mandatory: ", mandatory_args, "Optional: ", optional_args, "Present: ",
+              present_args, "Missing: ", missing_args_count, "Present Optional: ", present_optional_args)
+        
print("List: ", node.name.name, len(new_args), mandatory_args) + + if missing_args_count == 0: + return node + + for i in range(mandatory_args): + new_args[i] = node.args[i] + for i in range(mandatory_args, len(node.args)): + if len(node.args) > i: + current_arg = node.args[i] + if not isinstance(current_arg, ast_internal_classes.Actual_Arg_Spec_Node): + new_args[i] = current_arg + else: + name = current_arg.arg_name + index = 0 + for j in func_decl.optional_args: + if j[0] == name.name: + break + index = index + 1 + new_args[mandatory_args + index] = current_arg.arg + + for i in range(mandatory_args, mandatory_args + optional_args): + relative_position = i - mandatory_args + if new_args[i] is None: + dtype = func_decl.optional_args[relative_position][1] + if dtype == 'INTEGER': + new_args[i] = ast_internal_classes.Int_Literal_Node(value='0') + elif dtype == 'LOGICAL': + new_args[i] = ast_internal_classes.Bool_Literal_Node(value='0') + elif dtype == 'DOUBLE': + new_args[i] = ast_internal_classes.Real_Literal_Node(value='0') + elif dtype == 'CHAR': + new_args[i] = ast_internal_classes.Char_Literal_Node(value='0') + else: + raise NotImplementedError() + new_args[i + optional_args] = ast_internal_classes.Bool_Literal_Node(value='0') + else: + new_args[i + optional_args] = ast_internal_classes.Bool_Literal_Node(value='1') + + node.args = new_args + return node + + +def optionalArgsExpander(node=ast_internal_classes.Program_Node): """ + Adds to each optional arg a logical value specifying its status. 
Eliminates function statements from the AST :param node: The AST to be transformed :return: The transformed AST :note Should only be used on the program node """ - main_program = localFunctionStatementEliminator(node.main_program) - function_definitions = [localFunctionStatementEliminator(i) for i in node.function_definitions] - subroutine_definitions = [localFunctionStatementEliminator(i) for i in node.subroutine_definitions] - modules = [] - for i in node.modules: - module_function_definitions = [localFunctionStatementEliminator(j) for j in i.function_definitions] - module_subroutine_definitions = [localFunctionStatementEliminator(j) for j in i.subroutine_definitions] - modules.append( - ast_internal_classes.Module_Node( - name=i.name, - specification_part=i.specification_part, - subroutine_definitions=module_subroutine_definitions, - function_definitions=module_function_definitions, - )) - return ast_internal_classes.Program_Node(main_program=main_program, - function_definitions=function_definitions, - subroutine_definitions=subroutine_definitions, - modules=modules) + modified_functions = {} -def localFunctionStatementEliminator(node: ast_internal_classes.FNode): - """ - Eliminates function statements from the AST - :param node: The AST to be transformed - :return: The transformed AST - """ - spec = node.specification_part.specifications - exec = node.execution_part.execution - new_exec = exec.copy() - to_change = [] - for i in exec: - if isinstance(i, ast_internal_classes.BinOp_Node): - if i.op == "=": - if isinstance(i.lval, ast_internal_classes.Call_Expr_Node) or isinstance( - i.lval, ast_internal_classes.Structure_Constructor_Node): - function_statement_name = i.lval.name - is_actually_function_statement = False - # In Fortran, function statement are defined as scalar values, - # but called as arrays, so by identifiying that it is called as - # a call_expr or structure_constructor, we also need to match - # the specification part and see that it is scalar 
rather than an array. - found = False - for j in spec: - if found: - break - for k in j.vardecl: - if k.name == function_statement_name.name: - if k.sizes is None: - is_actually_function_statement = True - function_statement_type = k.type - j.vardecl.remove(k) - found = True - break - if is_actually_function_statement: - to_change.append([i.lval, i.rval]) - new_exec.remove(i) + for func in node.subroutine_definitions: + if optionalArgsHandleFunction(func): + modified_functions[func.name.name] = func + for mod in node.modules: + for func in mod.subroutine_definitions: + if optionalArgsHandleFunction(func): + modified_functions[func.name.name] = func - else: - #There are no function statements after the first one that isn't a function statement + node = OptionalArgsTransformer(modified_functions).visit(node) + + return node + +class AllocatableFunctionLister(NodeVisitor): + + def __init__(self): + self.functions = {} + + def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node): + + for i in node.specification_part.specifications: + + vars = [] + if isinstance(i, ast_internal_classes.Decl_Stmt_Node): + + for var_decl in i.vardecl: + if var_decl.alloc: + + # we are only interestd in adding flag if it's an arg + found = False + for arg in node.args: + assert isinstance(arg, ast_internal_classes.Name_Node) + + if var_decl.name == arg.name: + found = True + break + + if found: + vars.append(var_decl.name) + + if len(vars) > 0: + self.functions[node.name.name] = vars + +class AllocatableReplacerVisitor(NodeVisitor): + + def __init__(self, functions_with_alloc): + self.allocate_var_names = [] + self.deallocate_var_names = [] + self.call_nodes = [] + self.functions_with_alloc = functions_with_alloc + + def visit_Allocate_Stmt_Node(self, node: ast_internal_classes.Allocate_Stmt_Node): + + for var in node.allocation_list: + self.allocate_var_names.append(var.name.name) + + def visit_Deallocate_Stmt_Node(self, node: 
ast_internal_classes.Deallocate_Stmt_Node): + + for var in node.list: + self.deallocate_var_names.append(var.name) + + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): + + for node.name.name in self.functions_with_alloc: + self.call_nodes.append(node) + +class AllocatableReplacerTransformer(NodeTransformer): + + def __init__(self, functions_with_alloc: Dict[str, List[str]]): + self.functions_with_alloc = functions_with_alloc + + def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): + + newbody = [] + + for child in node.execution: + + lister = AllocatableReplacerVisitor(self.functions_with_alloc) + lister.visit(child) + + for alloc_node in lister.allocate_var_names: + + name = f'__f2dace_ALLOCATED_{alloc_node}' + newbody.append( + ast_internal_classes.BinOp_Node( + op="=", + lval=ast_internal_classes.Name_Node(name=name), + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=child.line_number, + parent=child.parent + ) + ) + + for dealloc_node in lister.deallocate_var_names: + + name = f'__f2dace_ALLOCATED_{dealloc_node}' + newbody.append( + ast_internal_classes.BinOp_Node( + op="=", + lval=ast_internal_classes.Name_Node(name=name), + rval=ast_internal_classes.Int_Literal_Node(value="0"), + line_number=child.line_number, + parent=child.parent + ) + ) + + for call_node in lister.call_nodes: + + alloc_nodes = self.functions_with_alloc[call_node.name.name] + + for alloc_name in alloc_nodes: + name = f'__f2dace_ALLOCATED_{alloc_name}' + call_node.args.append( + ast_internal_classes.Name_Node(name=name) + ) + + newbody.append(child) + + return ast_internal_classes.Execution_Part_Node(execution=newbody) + + + def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node): + + node.execution_part = self.visit(node.execution_part) + + args = node.args.copy() + newspec = [] + for i in node.specification_part.specifications: + + if not isinstance(i, 
ast_internal_classes.Decl_Stmt_Node): + newspec.append(self.visit(i)) + else: + + newdecls = [] + for var_decl in i.vardecl: + + if var_decl.alloc: + + name = f'__f2dace_ALLOCATED_{var_decl.name}' + init = ast_internal_classes.Int_Literal_Node(value="0") + + # if it's an arg, then we don't initialize + if node.name.name in self.functions_with_alloc and var_decl.name in self.functions_with_alloc[node.name.name]: + init = None + args.append( + ast_internal_classes.Name_Node(name=name) + ) + + var = ast_internal_classes.Var_Decl_Node( + name=name, + type='LOGICAL', + alloc=False, + sizes=None, + offsets=None, + kind=None, + optional=False, + init=init, + line_number=var_decl.line_number + ) + newdecls.append(var) + + if len(newdecls) > 0: + newspec.append(ast_internal_classes.Decl_Stmt_Node(vardecl=newdecls)) + + if len(newspec) > 0: + node.specification_part.specifications.append(*newspec) + + return ast_internal_classes.Subroutine_Subprogram_Node( + name=node.name, + args=args, + specification_part=node.specification_part, + execution_part=node.execution_part + ) + +def allocatableReplacer(node=ast_internal_classes.Program_Node): + + visitor = AllocatableFunctionLister() + visitor.visit(node) + + return AllocatableReplacerTransformer(visitor.functions).visit(node) + +def functionStatementEliminator(node=ast_internal_classes.Program_Node): + """ + Eliminates function statements from the AST + :param node: The AST to be transformed + :return: The transformed AST + :note Should only be used on the program node + """ + main_program = localFunctionStatementEliminator(node.main_program) + function_definitions = [localFunctionStatementEliminator(i) for i in node.function_definitions] + subroutine_definitions = [localFunctionStatementEliminator(i) for i in node.subroutine_definitions] + modules = [] + for i in node.modules: + module_function_definitions = [localFunctionStatementEliminator(j) for j in i.function_definitions] + module_subroutine_definitions = 
[localFunctionStatementEliminator(j) for j in i.subroutine_definitions] + modules.append( + ast_internal_classes.Module_Node( + name=i.name, + specification_part=i.specification_part, + subroutine_definitions=module_subroutine_definitions, + function_definitions=module_function_definitions, + interface_blocks=i.interface_blocks, + )) + node.main_program = main_program + node.function_definitions = function_definitions + node.subroutine_definitions = subroutine_definitions + node.modules = modules + return node + + +def localFunctionStatementEliminator(node: ast_internal_classes.FNode): + """ + Eliminates function statements from the AST + :param node: The AST to be transformed + :return: The transformed AST + """ + if node is None: + return None + if hasattr(node, "specification_part") and node.specification_part is not None: + spec = node.specification_part.specifications + else: + spec = [] + if hasattr(node, "execution_part"): + if node.execution_part is not None: + exec = node.execution_part.execution + else: + exec = [] + else: + exec = [] + new_exec = exec.copy() + to_change = [] + for i in exec: + if isinstance(i, ast_internal_classes.BinOp_Node): + if i.op == "=": + if isinstance(i.lval, ast_internal_classes.Call_Expr_Node) or isinstance( + i.lval, ast_internal_classes.Structure_Constructor_Node): + function_statement_name = i.lval.name + is_actually_function_statement = False + # In Fortran, function statement are defined as scalar values, + # but called as arrays, so by identifiying that it is called as + # a call_expr or structure_constructor, we also need to match + # the specification part and see that it is scalar rather than an array. 
+ found = False + for j in spec: + if found: + break + for k in j.vardecl: + if k.name == function_statement_name.name: + if k.sizes is None: + is_actually_function_statement = True + function_statement_type = k.type + j.vardecl.remove(k) + found = True + break + if is_actually_function_statement: + to_change.append([i.lval, i.rval]) + new_exec.remove(i) + + else: + # There are no function statements after the first one that isn't a function statement break still_changing = True while still_changing: @@ -670,7 +1990,7 @@ def localFunctionStatementEliminator(node: ast_internal_classes.FNode): if k.name == j[0].name: calls_to_replace = FindFunctionCalls() calls_to_replace.visit(j[1]) - #must check if it is recursive and contains other function statements + # must check if it is recursive and contains other function statements it_is_simple = True for l in calls_to_replace.nodes: for m in to_change: @@ -682,8 +2002,18 @@ def localFunctionStatementEliminator(node: ast_internal_classes.FNode): final_exec = [] for i in new_exec: final_exec.append(ReplaceFunctionStatementPass(to_change).visit(i)) - node.execution_part.execution = final_exec - node.specification_part.specifications = spec + if hasattr(node, "execution_part"): + if node.execution_part is not None: + node.execution_part.execution = final_exec + else: + node.execution_part = ast_internal_classes.Execution_Part_Node(execution=final_exec) + else: + node.execution_part = ast_internal_classes.Execution_Part_Node(execution=final_exec) + # node.execution_part.execution = final_exec + if hasattr(node, "specification_part"): + if node.specification_part is not None: + node.specification_part.specifications = spec + # node.specification_part.specifications = spec return node @@ -691,6 +2021,7 @@ class ArrayLoopNodeLister(NodeVisitor): """ Finds all array operations that have to be transformed to loops in the AST """ + def __init__(self): self.nodes: List[ast_internal_classes.FNode] = [] self.range_nodes: 
List[ast_internal_classes.FNode] = [] @@ -698,22 +2029,11 @@ def __init__(self): def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): rval_pardecls = [i for i in mywalk(node.rval) if isinstance(i, ast_internal_classes.ParDecl_Node)] lval_pardecls = [i for i in mywalk(node.lval) if isinstance(i, ast_internal_classes.ParDecl_Node)] - if len(lval_pardecls) > 0: - if len(rval_pardecls) == 1: - self.range_nodes.append(node) - self.nodes.append(node) - return - elif len(rval_pardecls) > 1: - for i in rval_pardecls: - if i != rval_pardecls[0]: - raise NotImplementedError("Only supporting one range in right expression") - - self.range_nodes.append(node) - self.nodes.append(node) - return - else: - self.nodes.append(node) - return + if not lval_pardecls: + return + if rval_pardecls: + self.range_nodes.append(node) + self.nodes.append(node) def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): return @@ -721,12 +2041,15 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, ranges: list, - rangepos: list, rangeslen: list, count: int, newbody: list, scope_vars: ScopeVarsDeclarations, - declaration=True): + structures: Structures, + declaration=True, + main_iterator_ranges: Optional[list] = None, + allow_scalars = False + ): """ Helper function for the transformation of array operations and sums to loops :param node: The AST to be transformed @@ -735,51 +2058,132 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, :param rangepos: The positions of the ranges :param count: The current count of the loop :param newbody: The new basic block that will contain the loop - :param declaration: Whether the declaration of the loop variable is needed - :param is_sum_to_loop: Whether the transformation is for a sum to loop + :param main_iterator_ranges: When parsing right-hand side of equation, use access to main loop 
range :return: Ranges, rangepos, newbody """ + rangepos = [] currentindex = 0 indices = [] + name_chain = [] + if isinstance(node, ast_internal_classes.Data_Ref_Node): + + # we assume starting from the top (left-most) data_ref_node + # for struct1 % struct2 % struct3 % var + # we find definition of struct1, then we iterate until we find the var + + struct_type = scope_vars.get_var(node.parent, node.parent_ref.name).type + struct_def = structures.structures[struct_type] + cur_node = node + name_chain = [cur_node.parent_ref] + while True: + cur_node = cur_node.part_ref + if isinstance(cur_node, ast_internal_classes.Data_Ref_Node): + name_chain.append(cur_node.parent_ref) + + if isinstance(cur_node, ast_internal_classes.Array_Subscript_Node): + struct_def = structures.structures[struct_type] + offsets = struct_def.vars[cur_node.name.name].offsets + node = cur_node + break - offsets = scope_vars.get_var(node.parent, node.name.name).offsets + elif isinstance(cur_node, ast_internal_classes.Name_Node): + struct_def = structures.structures[struct_type] + + var_def = struct_def.vars[cur_node.name] + offsets = var_def.offsets + + # FIXME: is this always a desired behavior? 
+ + # if we are passed a name node in the context of parDeclRange, + # then we assume it should be a total range across the entire array + array_sizes = var_def.sizes + assert array_sizes is not None + + dims = len(array_sizes) + node = ast_internal_classes.Array_Subscript_Node( + name=cur_node, parent=node.parent, type=var_def.type, + indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * dims + ) + + break + + struct_type = struct_def.vars[cur_node.parent_ref.name].type + struct_def = structures.structures[struct_type] + + else: + offsets = scope_vars.get_var(node.parent, node.name.name).offsets for idx, i in enumerate(node.indices): + if isinstance(i, ast_internal_classes.ParDecl_Node): if i.type == "ALL": - lower_boundary = None if offsets[idx] != 1: - lower_boundary = ast_internal_classes.Int_Literal_Node(value=str(offsets[idx])) + # support symbols and integer literals + if isinstance(offsets[idx], ast_internal_classes.Name_Node) or isinstance(offset,ast_internal_classes.BinOp_Node): + lower_boundary = offsets[idx] + else: + #check if offset is a number + try: + offset_value = int(offsets[idx]) + except: + raise ValueError(f"Offset {offsets[idx]} is not a number") + lower_boundary = ast_internal_classes.Int_Literal_Node(value=str(offset_value)) else: lower_boundary = ast_internal_classes.Int_Literal_Node(value="1") + first = True + if len(name_chain) >= 1: + for i in name_chain: + if first: + first = False + array_name = i.name + else: + array_name = array_name + "_" + i.name + array_name = array_name + "_" + node.name.name + else: + array_name = node.name.name upper_boundary = ast_internal_classes.Name_Range_Node(name="f2dace_MAX", - type="INTEGER", - arrname=node.name, - pos=currentindex) + type="INTEGER", + arrname=ast_internal_classes.Name_Node( + name=array_name, type="VOID", + line_number=node.line_number), + pos=idx) """ When there's an offset, we add MAX_RANGE + offset. But since the generated loop has `<=` condition, we need to subtract 1. 
""" if offsets[idx] != 1: + + # support symbols and integer literals + if isinstance(offsets[idx], ast_internal_classes.Name_Node) or isinstance(offset,ast_internal_classes.BinOp_Node): + offset = offsets[idx] + else: + try: + offset_value = int(offsets[idx]) + except: + raise ValueError(f"Offset {offsets[idx]} is not a number") + offset = ast_internal_classes.Int_Literal_Node(value=str(offset_value)) + upper_boundary = ast_internal_classes.BinOp_Node( lval=upper_boundary, op="+", - rval=ast_internal_classes.Int_Literal_Node(value=str(offsets[idx])) + rval=offset ) upper_boundary = ast_internal_classes.BinOp_Node( lval=upper_boundary, op="-", rval=ast_internal_classes.Int_Literal_Node(value="1") ) + ranges.append([lower_boundary, upper_boundary]) rangeslen.append(-1) else: ranges.append([i.range[0], i.range[1]]) + lower_boundary = i.range[0] start = 0 if isinstance(i.range[0], ast_internal_classes.Int_Literal_Node): @@ -793,18 +2197,77 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, else: end = i.range[1] - rangeslen.append(end - start + 1) + if isinstance(end, int) and isinstance(start, int): + rangeslen.append(end - start + 1) + else: + add = ast_internal_classes.BinOp_Node( + lval=start, + op="+", + rval=ast_internal_classes.Int_Literal_Node(value="1") + ) + substr = ast_internal_classes.BinOp_Node( + lval=end, + op="-", + rval=add + ) + rangeslen.append(substr) + rangepos.append(currentindex) if declaration: newbody.append( ast_internal_classes.Decl_Stmt_Node(vardecl=[ ast_internal_classes.Symbol_Decl_Node( - name="tmp_parfor_" + str(count + len(rangepos) - 1), type="INTEGER", sizes=None, init=None) + name="tmp_parfor_" + str(count + len(rangepos) - 1), type="INTEGER", sizes=None, init=None,parent=node.parent, line_number=node.line_number) ])) - indices.append(ast_internal_classes.Name_Node(name="tmp_parfor_" + str(count + len(rangepos) - 1))) + + """ + To account for ranges with different starting offsets inside the same loop, + 
we need to adapt array accesses. + The main loop iterator is already initialized with the lower boundary of the dominating array. + + Thus, if the offset is the same, the index is just "tmp_parfor". + Otherwise, it is "tmp_parfor - tmp_parfor_lower_boundary + our_lower_boundary" + """ + + if declaration: + """ + For LHS, we don't need to adjust - we dictate the loop iterator. + """ + + indices.append( + ast_internal_classes.Name_Node(name="tmp_parfor_" + str(count + len(rangepos) - 1)) + ) + else: + + """ + For RHS, we adjust starting array position by taking consideration the initial value + of the loop iterator. + + Offset is handled by always subtracting the lower boundary. + """ + current_lower_boundary = main_iterator_ranges[currentindex][0] + + indices.append( + ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(count + len(rangepos) - 1)), + op="+", + rval=ast_internal_classes.BinOp_Node( + lval=lower_boundary, + op="-", + rval=current_lower_boundary,parent=node.parent + ),parent=node.parent + ) + ) + currentindex += 1 + + elif allow_scalars: + + ranges.append([i, i]) + rangeslen.append(1) + indices.append(i) + currentindex += 1 else: indices.append(i) - currentindex += 1 node.indices = indices @@ -813,11 +2276,13 @@ class ArrayToLoop(NodeTransformer): """ Transforms the AST by removing array expressions and replacing them with loops """ + def __init__(self, ast): self.count = 0 + self.ast = ast ParentScopeAssigner().visit(ast) - self.scope_vars = ScopeVarsDeclarations() + self.scope_vars = ScopeVarsDeclarations(ast) self.scope_vars.visit(ast) def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): @@ -827,38 +2292,56 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No lister.visit(child) res = lister.nodes res_range = lister.range_nodes - if res is not None and len(res) > 0: + #Transpose breaks Array to loop transformation, and fixing it is not trivial 
- and will likely not involve array to loop at all. + calls=[i for i in mywalk(child) if isinstance(i, ast_internal_classes.Call_Expr_Node)] + skip_because_of_transpose = False + for i in calls: + if "__dace_transpose" in i.name.name.lower(): + skip_because_of_transpose = True + if skip_because_of_transpose: + newbody.append(child) + continue + try: + if res is not None and len(res) > 0: + current = child.lval - val = child.rval ranges = [] - rangepos = [] - par_Decl_Range_Finder(current, ranges, rangepos, [], self.count, newbody, self.scope_vars, True) - - if res_range is not None and len(res_range) > 0: - rvals = [i for i in mywalk(val) if isinstance(i, ast_internal_classes.Array_Subscript_Node)] - for i in rvals: - rangeposrval = [] - rangesrval = [] - - par_Decl_Range_Finder(i, rangesrval, rangeposrval, [], self.count, newbody, self.scope_vars, False) - - for i, j in zip(ranges, rangesrval): - if i != j: - if isinstance(i, list) and isinstance(j, list) and len(i) == len(j): - for k, l in zip(i, j): - if k != l: - if isinstance(k, ast_internal_classes.Name_Range_Node) and isinstance( - l, ast_internal_classes.Name_Range_Node): - if k.name != l.name: - raise NotImplementedError("Ranges must be the same") - else: + par_Decl_Range_Finder(current, ranges, [], self.count, newbody, self.scope_vars, + self.ast.structures, True) + + # if res_range is not None and len(res_range) > 0: + + # catch cases where an array is used as name, without range expression + visitor = ReplaceImplicitParDecls(self.scope_vars) + child.rval = visitor.visit(child.rval) + + rvals = [i for i in mywalk(child.rval) if isinstance(i, ast_internal_classes.Array_Subscript_Node)] + for i in rvals: + rangesrval = [] + + par_Decl_Range_Finder(i, rangesrval, [], self.count, newbody, self.scope_vars, + self.ast.structures, False, ranges) + for i, j in zip(ranges, rangesrval): + if i != j: + if isinstance(i, list) and isinstance(j, list) and len(i) == len(j): + for k, l in zip(i, j): + if k != l: + if 
isinstance(k, ast_internal_classes.Name_Range_Node) and isinstance( + l, ast_internal_classes.Name_Range_Node): + if k.name != l.name: raise NotImplementedError("Ranges must be the same") - else: - raise NotImplementedError("Ranges must be identical") + else: + # this is not actually illegal. + # raise NotImplementedError("Ranges must be the same") + continue + else: + raise NotImplementedError("Ranges must be identical") range_index = 0 - body = ast_internal_classes.BinOp_Node(lval=current, op="=", rval=val, line_number=child.line_number) + body = ast_internal_classes.BinOp_Node(lval=current, op="=", rval=child.rval, + line_number=child.line_number,parent=child.parent) + for i in ranges: initrange = i[0] finalrange = i[1] @@ -866,34 +2349,37 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), op="=", rval=initrange, - line_number=child.line_number) + line_number=child.line_number,parent=child.parent) cond = ast_internal_classes.BinOp_Node( lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), op="<=", rval=finalrange, - line_number=child.line_number) + line_number=child.line_number,parent=child.parent) iter = ast_internal_classes.BinOp_Node( lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), op="=", rval=ast_internal_classes.BinOp_Node( lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), op="+", - rval=ast_internal_classes.Int_Literal_Node(value="1")), - line_number=child.line_number) + rval=ast_internal_classes.Int_Literal_Node(value="1"),parent=child.parent), + line_number=child.line_number,parent=child.parent) current_for = ast_internal_classes.Map_Stmt_Node( init=init, cond=cond, iter=iter, body=ast_internal_classes.Execution_Part_Node(execution=[body]), - line_number=child.line_number) + 
line_number=child.line_number,parent=child.parent) body = current_for range_index += 1 newbody.append(body) self.count = self.count + range_index - else: + else: newbody.append(self.visit(child)) + except Exception as e: + print("Error in ArrayToLoop, exception caught at line: "+str(child.line_number)) + newbody.append(child) return ast_internal_classes.Execution_Part_Node(execution=newbody) @@ -910,6 +2396,7 @@ def mywalk(node): todo.extend(iter_child_nodes(node)) yield node + class RenameVar(NodeTransformer): def __init__(self, oldname: str, newname: str): self.oldname = oldname @@ -919,10 +2406,76 @@ def visit_Name_Node(self, node: ast_internal_classes.Name_Node): return ast_internal_classes.Name_Node(name=self.newname) if node.name == self.oldname else node +class PartialRenameVar(NodeTransformer): + def __init__(self, oldname: str, newname: str): + self.oldname = oldname + self.newname = newname + + def visit_Name_Node(self, node: ast_internal_classes.Name_Node): + if hasattr(node, "type"): + return ast_internal_classes.Name_Node(name=node.name.replace(self.oldname, self.newname), + parent=node.parent, + type=node.type) if self.oldname in node.name else node + else: + type = "VOID" + return ast_internal_classes.Name_Node(name=node.name.replace(self.oldname, self.newname), + parent=node.parent, + type="VOID") if self.oldname in node.name else node + + def visit_Symbol_Decl_Node(self, node: ast_internal_classes.Symbol_Decl_Node): + return ast_internal_classes.Symbol_Decl_Node(name=node.name.replace(self.oldname, self.newname), type=node.type, + sizes=node.sizes, init=node.init, line_number=node.line_number, + kind=node.kind, alloc=node.alloc, offsets=node.offsets) + + +class IfConditionExtractor(NodeTransformer): + """ + Ensures that each loop iterator is unique by extracting the actual iterator and assigning it to a uniquely named local variable + """ + + def __init__(self): + self.count = 0 + + def visit_Execution_Part_Node(self, node: 
ast_internal_classes.Execution_Part_Node): + newbody = [] + for child in node.execution: + + if isinstance(child, ast_internal_classes.If_Stmt_Node): + old_cond = child.cond + newbody.append( + ast_internal_classes.Decl_Stmt_Node(vardecl=[ + ast_internal_classes.Var_Decl_Node( + name="_if_cond_" + str(self.count), type="INTEGER", sizes=None, init=None) + ])) + newbody.append(ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="_if_cond_" + str(self.count)), + op="=", + rval=old_cond, + line_number=child.line_number, + parent=child.parent)) + newcond = ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name="_if_cond_" + str(self.count)), + op="==", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=child.line_number,parent=old_cond.parent) + newifbody = self.visit(child.body) + newelsebody = self.visit(child.body_else) + + newif = ast_internal_classes.If_Stmt_Node(cond=newcond, body=newifbody, body_else=newelsebody, + line_number=child.line_number, parent=child.parent) + self.count += 1 + + newbody.append(newif) + + else: + newbody.append(self.visit(child)) + return ast_internal_classes.Execution_Part_Node(execution=newbody) + + class ForDeclarer(NodeTransformer): """ Ensures that each loop iterator is unique by extracting the actual iterator and assigning it to a uniquely named local variable """ + def __init__(self): self.count = 0 @@ -941,8 +2494,15 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No final_assign = ast_internal_classes.BinOp_Node(lval=child.init.lval, op="=", rval=child.cond.rval, - line_number=child.line_number) - newfor = RenameVar(child.init.lval.name, "_for_it_" + str(self.count)).visit(child) + line_number=child.line_number,parent=child.parent) + newfbody = RenameVar(child.init.lval.name, "_for_it_" + str(self.count)).visit(child.body) + newcond = RenameVar(child.cond.lval.name, "_for_it_" + str(self.count)).visit(child.cond) + newiter = 
RenameVar(child.iter.lval.name, "_for_it_" + str(self.count)).visit(child.iter) + newinit = child.init + newinit.lval = RenameVar(child.init.lval.name, "_for_it_" + str(self.count)).visit(child.init.lval) + + newfor = ast_internal_classes.For_Stmt_Node(init=newinit, cond=newcond, iter=newiter, body=newfbody, + line_number=child.line_number, parent=child.parent) self.count += 1 newfor = self.visit(newfor) newbody.append(newfor) @@ -950,3 +2510,928 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No else: newbody.append(self.visit(child)) return ast_internal_classes.Execution_Part_Node(execution=newbody) + + +class ElementalFunctionExpander(NodeTransformer): + "Makes elemental functions into normal functions by creating a loop around thme if they are called with arrays" + + def __init__(self, func_list: list): + self.func_list = func_list + self.count = 0 + + def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): + newbody = [] + for child in node.execution: + if isinstance(child, ast_internal_classes.Call_Expr_Node): + arrays = False + for i in self.func_list: + if child.name.name == i.name or child.name.name == i.name + "_srt": + if hasattr(i, "elemental"): + if i.elemental is True: + if len(child.args) > 0: + for j in child.args: + # THIS Needs a proper check + if j.name == "z": + arrays = True + + if not arrays: + newbody.append(self.visit(child)) + else: + newbody.append( + ast_internal_classes.Decl_Stmt_Node(vardecl=[ + ast_internal_classes.Symbol_Decl_Node( + name="_for_elem_it_" + str(self.count), type="INTEGER", sizes=None, init=None) + ])) + newargs = [] + # The range must be determined! 
It's currently hard set to 10 + shape = ["10"] + for i in child.args: + if isinstance(i, ast_internal_classes.Name_Node): + newargs.append(ast_internal_classes.Array_Subscript_Node(name=i, indices=[ + ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count))], + line_number=child.line_number, + type=i.type)) + if i.name.startswith("tmp_call_"): + for j in newbody: + if isinstance(j, ast_internal_classes.Decl_Stmt_Node): + if j.vardecl[0].name == i.name: + newbody[newbody.index(j)].vardecl[0].sizes = shape + break + else: + raise NotImplementedError("Only name nodes are supported") + + newbody.append(ast_internal_classes.For_Stmt_Node( + init=ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), + op="=", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=child.line_number,parent=child.parent), + cond=ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), + op="<=", + rval=ast_internal_classes.Int_Literal_Node(value=shape[0]), + line_number=child.line_number,parent=child.parent), + body=ast_internal_classes.Execution_Part_Node(execution=[ + ast_internal_classes.Call_Expr_Node(type=child.type, + name=child.name, + args=newargs, + line_number=child.line_number,parent=child.parent) + ]), line_number=child.line_number, + iter=ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), + op="=", + rval=ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), + op="+", + rval=ast_internal_classes.Int_Literal_Node(value="1"),parent=child.parent), + line_number=child.line_number,parent=child.parent) + )) + self.count += 1 + + + else: + newbody.append(self.visit(child)) + return ast_internal_classes.Execution_Part_Node(execution=newbody) + + +class TypeInference(NodeTransformer): + """ + """ + + def __init__(self, ast, 
assert_voids=True, assign_scopes=True, scope_vars = None): + self.assert_voids = assert_voids + + self.ast = ast + if assign_scopes: + ParentScopeAssigner().visit(ast) + if scope_vars is None: + self.scope_vars = ScopeVarsDeclarations(ast) + self.scope_vars.visit(ast) + else: + self.scope_vars = scope_vars + self.structures = ast.structures + + def visit_Name_Node(self, node: ast_internal_classes.Name_Node): + + if not hasattr(node, 'type') or node.type == 'VOID' or not hasattr(node, 'dims'): + try: + var_def = self.scope_vars.get_var(node.parent, node.name) + if var_def.type != 'VOID': + node.type = var_def.type + node.dims = len(var_def.sizes) if hasattr(var_def, 'sizes') and var_def.sizes is not None else 1 + except Exception as e: + print(f"Ignore type inference for {node.name}") + print(e) + + return node + + def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_Node): + + var_def = self.scope_vars.get_var(node.parent, node.name.name) + if var_def.type != 'VOID': + node.type = var_def.type + node.dims = len(var_def.sizes) if var_def.sizes is not None else 1 + return node + + def visit_Parenthesis_Expr_Node(self, node: ast_internal_classes.Parenthesis_Expr_Node): + + node.expr = self.visit(node.expr) + if node.expr.type != 'VOID': + node.type = node.expr.type + if hasattr(node.expr, 'dims'): + node.dims = node.expr.dims + return node + + def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): + + """ + Simple implementation of type promotion in binary ops. 
+ """ + + node.lval = self.visit(node.lval) + node.rval = self.visit(node.rval) + + type_hierarchy = [ + 'VOID', + 'LOGICAL', + 'CHAR', + 'INTEGER', + 'REAL', + 'DOUBLE' + ] + + idx_left = type_hierarchy.index(self._get_type(node.lval)) + idx_right = type_hierarchy.index(self._get_type(node.rval)) + idx_void = type_hierarchy.index('VOID') + + # if self.assert_voids: + # assert idx_left != idx_void or idx_right != idx_void + # #assert self._get_dims(node.lval) == self._get_dims(node.rval) + + node.type = type_hierarchy[max(idx_left, idx_right)] + if hasattr(node.lval, "dims"): + node.dims = self._get_dims(node.lval) + elif hasattr(node.lval, "dims"): + node.dims = self._get_dims(node.rval) + + if node.op == '=' and idx_left == idx_void and idx_right != idx_void: + lval_definition = self.scope_vars.get_var(node.parent, node.lval.name) + lval_definition.type = node.type + lval_definition.dims = node.dims + node.lval.type = node.type + node.lval.dims = node.dims + + if node.type == 'VOID': + print("Couldn't infer the type for binop!") + + return node + + def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node): + + if node.type != 'VOID': + return node + + struct, variable = self.structures.find_definition( + self.scope_vars, node + ) + if variable.type != 'VOID': + node.type = variable.type + node.dims = len(variable.sizes) if variable.sizes is not None else 1 + return node + + def visit_Actual_Arg_Spec_Node(self, node: ast_internal_classes.Actual_Arg_Spec_Node): + + if node.type != 'VOID': + return node + + node.arg = self.visit(node.arg) + + func_arg_name_type = self._get_type(node.arg) + if func_arg_name_type == 'VOID': + + func_arg = self.scope_vars.get_var(node.parent, node.arg.name) + node.type = func_arg.type + node.arg.type = func_arg.type + dims = len(func_arg.sizes) if func_arg.sizes is not None else 1 + node.dims = dims + node.arg.dims = dims + + else: + node.type = func_arg_name_type + node.dims = self._get_dims(node.arg) + + return node + + 
def visit_UnOp_Node(self, node: ast_internal_classes.UnOp_Node): + node.lval = self.visit(node.lval) + if node.lval.type != 'VOID': + node.type = node.lval.type + return node + + def _get_type(self, node): + + if isinstance(node, ast_internal_classes.Int_Literal_Node): + return 'INTEGER' + elif isinstance(node, ast_internal_classes.Real_Literal_Node): + return 'REAL' + elif isinstance(node, ast_internal_classes.Bool_Literal_Node): + return 'LOGICAL' + else: + return node.type + + def _get_dims(self, node): + + if isinstance(node, ast_internal_classes.Int_Literal_Node): + return 1 + elif isinstance(node, ast_internal_classes.Real_Literal_Node): + return 1 + elif isinstance(node, ast_internal_classes.Bool_Literal_Node): + return 1 + else: + return node.dims + + +class ReplaceInterfaceBlocks(NodeTransformer): + """ + """ + + def __init__(self, program, funcs: FindFunctionAndSubroutines): + self.funcs = funcs + + ParentScopeAssigner().visit(program) + self.scope_vars = ScopeVarsDeclarations(program) + self.scope_vars.visit(program) + + def _get_dims(self, node): + + if hasattr(node, "dims"): + return node.dims + + if isinstance(node, ast_internal_classes.Var_Decl_Node): + return len(node.sizes) if node.sizes is not None else 1 + + raise RuntimeError() + + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): + + # is_func = node.name.name in self.excepted_funcs or node.name in self.funcs.names + # is_interface_func = not node.name in self.funcs.names and node.name.name in self.funcs.iblocks + is_interface_func = node.name.name in self.funcs.iblocks + + if is_interface_func: + + available_names = [] + print("Invoke", node.name.name, available_names) + for name in self.funcs.iblocks[node.name.name]: + + # non_optional_args = len(self.funcs.nodes[name].args) - self.funcs.nodes[name].optional_args_count + non_optional_args = self.funcs.nodes[name].mandatory_args_count + print("Check", name, non_optional_args, self.funcs.nodes[name].optional_args_count) 
+ + success = True + for call_arg, func_arg in zip(node.args[0:non_optional_args], + self.funcs.nodes[name].args[0:non_optional_args]): + print("Mandatory arg", call_arg, type(call_arg)) + if call_arg.type != func_arg.type or self._get_dims(call_arg) != self._get_dims(func_arg): + print(f"Ignore function {name}, wrong param type {call_arg.type} instead of {func_arg.type}") + success = False + break + else: + print(self._get_dims(call_arg), self._get_dims(func_arg), type(call_arg), call_arg.type, + func_arg.name, type(func_arg), func_arg.type) + + optional_args = self.funcs.nodes[name].args[non_optional_args:] + pos = non_optional_args + for idx, call_arg in enumerate(node.args[non_optional_args:]): + + print("Optional arg", call_arg, type(call_arg)) + if isinstance(call_arg, ast_internal_classes.Actual_Arg_Spec_Node): + func_arg_name = call_arg.arg_name + try: + func_arg = self.scope_vars.get_var(name, func_arg_name.name) + except: + # this keyword parameter is not available in this function + success = False + break + print('btw', func_arg, type(func_arg), func_arg.type) + else: + func_arg = optional_args[idx] + + # if call_arg.type != func_arg.type: + if call_arg.type != func_arg.type or self._get_dims(call_arg) != self._get_dims(func_arg): + print(f"Ignore function {name}, wrong param type {call_arg.type} instead of {func_arg.type}") + success = False + break + else: + print(self._get_dims(call_arg), self._get_dims(func_arg), type(call_arg), call_arg.type, + func_arg.name, type(func_arg), func_arg.type) + + if success: + available_names.append(name) + + if len(available_names) == 0: + raise RuntimeError("No matching function calls!") + + if len(available_names) != 1: + print(node.name.name, available_names) + raise RuntimeError("Too many matching function calls!") + + print(f"Selected {available_names[0]} as invocation for {node.name}") + node.name = ast_internal_classes.Name_Node(name=available_names[0]) + + return node + + +class 
class PointerRemoval(NodeTransformer):
    """
    Resolve pointer assignments (``ptr => target``): every later use of the
    pointer is replaced with the structure reference it points to, the
    pointer assignment statements are dropped, and the pointer declarations
    (plus the symbols used only for their sizes/offsets) are removed.
    """

    def __init__(self):
        # Maps pointer name -> the Data_Ref_Node it was assigned to.
        self.nodes = {}

    def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_Node):
        if node.name.name in self.nodes:
            original_ref_node = self.nodes[node.name.name]

            # Rebuild the Data_Ref chain node by node, then graft this
            # subscript onto its innermost member.
            cur_ref_node = original_ref_node
            new_ref_node = ast_internal_classes.Data_Ref_Node(
                parent_ref=cur_ref_node.parent_ref,
                part_ref=None
            )
            newer_ref_node = new_ref_node

            while isinstance(cur_ref_node.part_ref, ast_internal_classes.Data_Ref_Node):
                cur_ref_node = cur_ref_node.part_ref
                newest_ref_node = ast_internal_classes.Data_Ref_Node(
                    parent_ref=cur_ref_node.parent_ref,
                    part_ref=None
                )
                newer_ref_node.part_ref = newest_ref_node
                newer_ref_node = newest_ref_node

            node.name = cur_ref_node.part_ref
            newer_ref_node.part_ref = node
            return new_ref_node
        else:
            return self.generic_visit(node)

    def visit_Name_Node(self, node: ast_internal_classes.Name_Node):
        if node.name in self.nodes:
            return self.nodes[node.name]
        return node

    def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
        newbody = []

        for child in node.execution:
            # Pointer assignments are recorded here and removed from the body.
            if isinstance(child, ast_internal_classes.Pointer_Assignment_Stmt_Node):
                self.nodes[child.name_pointer.name] = child.name_target
            else:
                newbody.append(self.visit(child))

        return ast_internal_classes.Execution_Part_Node(execution=newbody)

    def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node):
        # Visit the execution part before the specification part so that all
        # pointers are known when their declarations are pruned.
        if node.execution_part is not None:
            execution_part = self.visit(node.execution_part)
        else:
            execution_part = node.execution_part

        if node.specification_part is not None:
            specification_part = self.visit(node.specification_part)
        else:
            specification_part = node.specification_part

        return ast_internal_classes.Subroutine_Subprogram_Node(
            name=node.name,
            args=node.args,
            specification_part=specification_part,
            execution_part=execution_part,
            line_number=node.line_number
        )

    def visit_Specification_Part_Node(self, node: ast_internal_classes.Specification_Part_Node):
        newspec = []
        symbols_to_remove = set()

        for i in node.specifications:

            if not isinstance(i, ast_internal_classes.Decl_Stmt_Node):
                newspec.append(self.visit(i))
            else:
                newdecls = []
                for var_decl in i.vardecl:

                    if var_decl.name in self.nodes:
                        # Declaration of a removed pointer: drop it and mark
                        # its size/offset symbols for removal.
                        if var_decl.sizes is not None:
                            for symbol in var_decl.sizes:
                                symbols_to_remove.add(symbol.name)
                        if var_decl.offsets is not None:
                            for symbol in var_decl.offsets:
                                symbols_to_remove.add(symbol.name)
                    else:
                        newdecls.append(var_decl)
                if len(newdecls) > 0:
                    newspec.append(ast_internal_classes.Decl_Stmt_Node(vardecl=newdecls))

        if node.symbols is not None:
            new_symbols = []
            for symbol in node.symbols:
                if symbol.name not in symbols_to_remove:
                    new_symbols.append(symbol)
        else:
            new_symbols = None

        return ast_internal_classes.Specification_Part_Node(
            specifications=newspec,
            symbols=new_symbols,
            typedecls=node.typedecls,
            uses=node.uses,
            enums=node.enums
        )


class ArgumentPruner(NodeVisitor):
    """
    Remove unused formal arguments and unused local declarations from
    functions/subroutines, and drop the matching actual arguments at every
    call site.
    """

    def __init__(self, funcs):
        self.funcs = funcs

        # Function name -> indices of the arguments that were pruned.
        self.parsed_funcs: Dict[str, List[int]] = {}

        self.used_names = set()
        self.declaration_names = set()

        self.used_in_all_functions: Set[str] = set()

    def visit_Name_Node(self, node: ast_internal_classes.Name_Node):
        self.used_names.add(node.name)

    def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node):
        self.declaration_names.add(node.name)

        # Visit also sizes and offsets.
        self.generic_visit(node)

    def generic_visit(self, node: ast_internal_classes.FNode):
        """Called if no explicit visitor function exists for a node."""
        for field, value in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast_internal_classes.FNode):
                        self.visit(item)
            elif isinstance(value, ast_internal_classes.FNode):
                self.visit(value)

        for field, value in iter_attributes(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast_internal_classes.FNode):
                        self.visit(item)
            elif isinstance(value, ast_internal_classes.FNode):
                self.visit(value)

    def _visit_function(self, node: ast_internal_classes.FNode):
        """Collect used names inside one function and prune its arguments/declarations."""
        old_used_names = self.used_names
        self.used_names = set()
        self.declaration_names = set()

        self.visit(node.specification_part)

        self.visit(node.execution_part)

        new_args = []
        removed_args = []
        for idx, arg in enumerate(node.args):

            if not isinstance(arg, ast_internal_classes.Name_Node):
                raise NotImplementedError()

            if arg.name not in self.used_names:
                removed_args.append(idx)
            else:
                new_args.append(arg)
        self.parsed_funcs[node.name.name] = removed_args

        declarations_to_remove = set()
        for x in self.declaration_names:
            if x not in self.used_names:
                declarations_to_remove.add(x)

        for decl_stmt_node in node.specification_part.specifications:

            newdecl = []
            for decl in decl_stmt_node.vardecl:

                if not isinstance(decl, ast_internal_classes.Var_Decl_Node):
                    raise NotImplementedError()

                if decl.name not in declarations_to_remove:
                    newdecl.append(decl)
            decl_stmt_node.vardecl = newdecl

        self.used_in_all_functions.update(self.used_names)
        self.used_names = old_used_names

    def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node):

        if node.name.name not in self.parsed_funcs:
            self._visit_function(node)

        to_remove = self.parsed_funcs[node.name.name]
        for idx in reversed(to_remove):
            del node.args[idx]

    def visit_Function_Subprogram_Node(self, node: ast_internal_classes.Function_Subprogram_Node):

        if node.name.name not in self.parsed_funcs:
            self._visit_function(node)

        to_remove = self.parsed_funcs[node.name.name]
        for idx in reversed(to_remove):
            del node.args[idx]

    def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):

        if node.name.name not in self.parsed_funcs:

            if node.name.name in self.funcs:
                self._visit_function(self.funcs[node.name.name])
            else:
                # Unknown callee: its actual arguments still count as used.
                for arg in node.args:
                    self.visit(arg)

                return

        to_remove = self.parsed_funcs[node.name.name]
        for idx in reversed(to_remove):
            del node.args[idx]

        # Now add the surviving actual arguments to the list of used names.
        for arg in node.args:
            self.visit(arg)


class PropagateEnums(NodeTransformer):
    """Replace references to enum members with their integer literal values."""

    def __init__(self):
        self.parsed_enums = {}

    def _parse_enums(self, enums):
        """Record each enumerator's value, honouring explicit ``= n`` initializers."""
        for j in enums:
            running_count = 0
            for k in j:
                if isinstance(k, list):
                    for l in k:
                        if isinstance(l, ast_internal_classes.Name_Node):
                            self.parsed_enums[l.name] = running_count
                            running_count += 1
                        elif isinstance(l, list):
                            # Explicit initializer: [name, '=', literal].
                            self.parsed_enums[l[0].name] = l[2].value
                            running_count = int(l[2].value) + 1
                        else:
                            raise ValueError("Unknown enum type")
                else:
                    raise ValueError("Unknown enum type")

    def visit_Specification_Part_Node(self, node: ast_internal_classes.Specification_Part_Node):
        self._parse_enums(node.enums)
        return self.generic_visit(node)

    def visit_Name_Node(self, node: ast_internal_classes.Name_Node):

        if self.parsed_enums.get(node.name) is not None:
            node.type = 'INTEGER'
            return ast_internal_classes.Int_Literal_Node(value=str(self.parsed_enums[node.name]))

        return node


class IfEvaluator(NodeTransformer):
    """Fold IF statements whose condition evaluates to a compile-time constant."""

    def __init__(self):
        self.replacements = 0

    def visit_If_Stmt_Node(self, node):
        try:
            text = ast_utils.TaskletWriter({}, {}).write_code(node.cond)
        # BUGFIX: was a bare `except:` with a dead `text = None` assignment.
        except Exception:
            return self.generic_visit(node)
        try:
            evaluated = sym.evaluate(sym.pystr_to_symbolic(text), {})
        except Exception:
            return self.generic_visit(node)

        if evaluated == sp.true:
            print("Expr: " + text + " eval to True replace")
            self.replacements += 1
            return node.body
        elif evaluated == sp.false:
            print("Expr: " + text + " eval to False replace")
            self.replacements += 1
            return node.body_else

        return self.generic_visit(node)


class AssignmentLister(NodeTransformer):
    """Collect simple ``name = value`` assignments, optionally patching RHS values."""

    # BUGFIX: the default was the mutable `correction=[]`, shared between all
    # instances constructed without an argument.
    def __init__(self, correction=None):
        self.simple_assignments = []
        self.correction = correction if correction is not None else []

    def reset(self):
        self.simple_assignments = []

    def visit_BinOp_Node(self, node):
        if node.op == "=":
            if isinstance(node.lval, ast_internal_classes.Name_Node):
                for i in self.correction:
                    if node.lval.name == i[0]:
                        node.rval.value = i[1]
                self.simple_assignments.append((node.lval, node.rval))
        return node


class AssignmentPropagator(NodeTransformer):
    """Substitute previously recorded simple assignments into expressions."""

    def __init__(self, simple_assignments):
        self.simple_assignments = simple_assignments
        self.replacements = 0

    def visit_If_Stmt_Node(self, node):
        test = self.generic_visit(node)
        return ast_internal_classes.If_Stmt_Node(line_number=node.line_number, cond=test.cond, body=test.body,
                                                 body_else=test.body_else)

    def generic_visit(self, node: ast_internal_classes.FNode):
        for field, old_value in iter_fields(node):
            if isinstance(old_value, list):
                new_values = []
                for value in old_value:
                    if isinstance(value, ast_internal_classes.FNode):
                        value = self.visit(value)
                        if value is None:
                            continue
                        elif not isinstance(value, ast_internal_classes.FNode):
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                old_value[:] = new_values
            elif isinstance(old_value, ast_internal_classes.FNode):
                done = False
                # Never substitute into the left-hand side of an assignment.
                if isinstance(node, ast_internal_classes.BinOp_Node):
                    if node.op == "=":
                        if old_value == node.lval:
                            new_node = self.visit(old_value)
                            done = True
                if not done:
                    for i in self.simple_assignments:
                        if old_value == i[0]:
                            old_value = i[1]
                            self.replacements += 1
                            break
                        elif (isinstance(old_value, ast_internal_classes.Name_Node)
                              and isinstance(i[0], ast_internal_classes.Name_Node)):
                            if old_value.name == i[0].name:
                                old_value = i[1]
                                self.replacements += 1
                                break
                        elif (isinstance(old_value, ast_internal_classes.Data_Ref_Node)
                              and isinstance(i[0], ast_internal_classes.Data_Ref_Node)):
                            # Only single-level struct references (name % name)
                            # are compared here.
                            if (isinstance(old_value.part_ref, ast_internal_classes.Name_Node)
                                    and isinstance(i[0].part_ref, ast_internal_classes.Name_Node)
                                    and isinstance(old_value.parent_ref, ast_internal_classes.Name_Node)
                                    and isinstance(i[0].parent_ref, ast_internal_classes.Name_Node)):
                                if (old_value.part_ref.name == i[0].part_ref.name
                                        and old_value.parent_ref.name == i[0].parent_ref.name):
                                    old_value = i[1]
                                    self.replacements += 1
                                    break

                    new_node = self.visit(old_value)

                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node


class getCalls(NodeVisitor):
    """Collect the names of all called functions, recursing into arguments."""

    def __init__(self):
        self.calls = []

    def visit_Call_Expr_Node(self, node):
        self.calls.append(node.name.name)
        for arg in node.args:
            self.visit(arg)
        return


class FindUnusedFunctions(NodeVisitor):
    """Record, per subroutine, the list of function names it calls."""

    def __init__(self, root, parse_order):
        self.root = root
        self.parse_order = parse_order
        self.used_names = {}

    def visit_Subroutine_Subprogram_Node(self, node):
        getacall = getCalls()
        getacall.visit(node.execution_part)
        used_calls = getacall.calls
        self.used_names[node.name.name] = used_calls
        return


class ReplaceImplicitParDecls(NodeTransformer):
    """Rewrite bare references to array variables as full-range subscripts."""

    def __init__(self, scope_vars):
        self.scope_vars = scope_vars

    def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_Node):
        # Already subscripted; leave untouched.
        return node

def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node): + return node + + def visit_Name_Node(self, node: ast_internal_classes.Name_Node): + + var = self.scope_vars.get_var(node.parent, node.name) + if var.sizes is not None: + + indices = [ast_internal_classes.ParDecl_Node(type='ALL')] * len(var.sizes) + return ast_internal_classes.Array_Subscript_Node( + name=node, + type=var.type, + parent=node.parent, + indices=indices, + line_number=node.line_number + ) + else: + return node + +class ReplaceStructArgsLibraryNodesVisitor(NodeVisitor): + """ + Finds all intrinsic operations that have to be transformed to loops in the AST + """ + + def __init__(self): + self.nodes: List[ast_internal_classes.FNode] = [] + + self.FUNCS_TO_REPLACE = [ + "transpose", + "matmul" + ] + + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): + + name = node.name.name.split('__dace_') + if len(name) == 2 and name[1].lower() in self.FUNCS_TO_REPLACE: + self.nodes.append(node) + + def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): + return + +class ReplaceStructArgsLibraryNodes(NodeTransformer): + + def __init__(self, ast): + + self.ast = ast + ParentScopeAssigner().visit(ast) + self.scope_vars = ScopeVarsDeclarations(ast) + self.scope_vars.visit(ast) + self.structures = ast.structures + + self.counter = 0 + + FUNCS_TO_REPLACE = [ + "transpose", + "matmul" + ] + + # FIXME: copy-paste from intrinsics + def _parse_struct_ref(self, node: ast_internal_classes.Data_Ref_Node) -> ast_internal_classes.FNode: + + # we assume starting from the top (left-most) data_ref_node + # for struct1 % struct2 % struct3 % var + # we find definition of struct1, then we iterate until we find the var + + struct_type = self.scope_vars.get_var(node.parent, node.parent_ref.name).type + struct_def = self.ast.structures.structures[struct_type] + cur_node = node + + while True: + cur_node = cur_node.part_ref + + if isinstance(cur_node, 
ast_internal_classes.Array_Subscript_Node): + struct_def = self.ast.structures.structures[struct_type] + return struct_def.vars[cur_node.name.name] + + elif isinstance(cur_node, ast_internal_classes.Name_Node): + struct_def = self.ast.structures.structures[struct_type] + return struct_def.vars[cur_node.name] + + elif isinstance(cur_node, ast_internal_classes.Data_Ref_Node): + struct_type = struct_def.vars[cur_node.parent_ref.name].type + struct_def = self.ast.structures.structures[struct_type] + + else: + raise NotImplementedError() + + def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): + + newbody = [] + + for child in node.execution: + + lister = ReplaceStructArgsLibraryNodesVisitor() + lister.visit(child) + res = lister.nodes + + if res is None or len(res) == 0: + newbody.append(self.visit(child)) + continue + + for call_node in res: + + args = [] + for arg in call_node.args: + + if isinstance(arg, ast_internal_classes.Data_Ref_Node): + + var = self._parse_struct_ref(arg) + tmp_var_name = f"tmp_libnode_{self.counter}" + + node.parent.specification_part.specifications.append( + ast_internal_classes.Decl_Stmt_Node(vardecl=[ + ast_internal_classes.Var_Decl_Node( + name=tmp_var_name, + type=var.type, + sizes=var.sizes, + offsets=var.offsets, + init=None + ) + ]) + ) + + dest_node = ast_internal_classes.Array_Subscript_Node( + name=ast_internal_classes.Name_Node(name=tmp_var_name), + parent=call_node.parent, type=var.type, + indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * len(var.sizes) + ) + + if isinstance(arg.part_ref, ast_internal_classes.Name_Node): + + arg.part_ref = ast_internal_classes.Array_Subscript_Node( + name=arg.part_ref, + parent=call_node.parent, type=arg.part_ref.type, + indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * len(var.sizes) + ) + + newbody.append( + ast_internal_classes.BinOp_Node( + op="=", + lval=dest_node, + rval=arg, + line_number=child.line_number, + parent=child.parent + ) + ) + + 
self.counter += 1 + + args.append(ast_internal_classes.Name_Node(name=tmp_var_name, type=var.type)) + + else: + args.append(arg) + + call_node.args = args + + newbody.append(child) + + return ast_internal_classes.Execution_Part_Node(execution=newbody) + diff --git a/dace/frontend/fortran/ast_utils.py b/dace/frontend/fortran/ast_utils.py index b52bd31df7..5de91f71bb 100644 --- a/dace/frontend/fortran/ast_utils.py +++ b/dace/frontend/fortran/ast_utils.py @@ -1,33 +1,36 @@ # Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +from itertools import chain +from typing import List, Set, Iterator, Type, TypeVar, Dict, Tuple, Iterable, Union, Optional + +import networkx as nx +from fparser.two.Fortran2003 import Module_Stmt, Name, Interface_Block, Subroutine_Stmt, Specification_Part, Module, \ + Derived_Type_Def, Function_Stmt, Interface_Stmt, Function_Body, Type_Name, Rename, Entity_Decl, Kind_Selector, \ + Intrinsic_Type_Spec, Use_Stmt, Declaration_Type_Spec +from fparser.two.Fortran2008 import Type_Declaration_Stmt, Procedure_Stmt +from fparser.two.utils import Base +from numpy import finfo as finf +from numpy import float64 as fl -from fparser.api import parse -import os -import sys -from fparser.common.readfortran import FortranStringReader, FortranFileReader - -#dace imports -from dace import subsets -from dace.data import Scalar -from dace.sdfg import SDFG, SDFGState, InterstateEdge +from dace import DebugInfo as di +from dace import Language as lang from dace import Memlet -from dace.sdfg.nodes import Tasklet +from dace import data as dat from dace import dtypes +# dace imports +from dace import subsets from dace import symbolic as sym -from dace import DebugInfo as di -from dace import Language as lang -from dace.properties import CodeBlock -from numpy import finfo as finf -from numpy import float64 as fl - from dace.frontend.fortran import ast_internal_classes -from typing import List, Set +from dace.sdfg import SDFG, SDFGState, InterstateEdge 
+from dace.sdfg.nodes import Tasklet fortrantypes2dacetypes = { "DOUBLE": dtypes.float64, "REAL": dtypes.float32, "INTEGER": dtypes.int32, - "BOOL": dtypes.int32, #This is a hack to allow fortran to pass through external C - + "INTEGER8": dtypes.int64, + "CHAR": dtypes.int8, + "LOGICAL": dtypes.int32, # This is a hack to allow fortran to pass through external C + "Unknown": dtypes.float64, # TMP hack unti lwe have a proper type inference } @@ -43,19 +46,40 @@ def add_tasklet(substate: SDFGState, name: str, vars_in: Set[str], vars_out: Set def add_memlet_read(substate: SDFGState, var_name: str, tasklet: Tasklet, dest_conn: str, memlet_range: str): - src = substate.add_access(var_name) + found = False + if isinstance(substate.parent.arrays[var_name], dat.View): + for i in substate.data_nodes(): + if i.data == var_name and len(substate.out_edges(i)) == 0: + src = i + found = True + break + if not found: + src = substate.add_read(var_name) + + # src = substate.add_access(var_name) if memlet_range != "": substate.add_memlet_path(src, tasklet, dst_conn=dest_conn, memlet=Memlet(expr=var_name, subset=memlet_range)) else: substate.add_memlet_path(src, tasklet, dst_conn=dest_conn, memlet=Memlet(expr=var_name)) + return src def add_memlet_write(substate: SDFGState, var_name: str, tasklet: Tasklet, source_conn: str, memlet_range: str): - dst = substate.add_write(var_name) + found = False + if isinstance(substate.parent.arrays[var_name], dat.View): + for i in substate.data_nodes(): + if i.data == var_name and len(substate.in_edges(i)) == 0: + dst = i + found = True + break + if not found: + dst = substate.add_write(var_name) + # dst = substate.add_write(var_name) if memlet_range != "": substate.add_memlet_path(tasklet, dst, src_conn=source_conn, memlet=Memlet(expr=var_name, subset=memlet_range)) else: substate.add_memlet_path(tasklet, dst, src_conn=source_conn, memlet=Memlet(expr=var_name)) + return dst def add_simple_state_to_sdfg(state: SDFGState, top_sdfg: SDFG, state_name: 
str): @@ -74,10 +98,25 @@ def finish_add_state_to_sdfg(state: SDFGState, top_sdfg: SDFG, substate: SDFGSta def get_name(node: ast_internal_classes.FNode): - if isinstance(node, ast_internal_classes.Name_Node): - return node.name - elif isinstance(node, ast_internal_classes.Array_Subscript_Node): - return node.name.name + if isinstance(node, ast_internal_classes.Actual_Arg_Spec_Node): + actual_node = node.arg + else: + actual_node = node + if isinstance(actual_node, ast_internal_classes.Name_Node): + return actual_node.name + elif isinstance(actual_node, ast_internal_classes.Array_Subscript_Node): + return actual_node.name.name + elif isinstance(actual_node, ast_internal_classes.Data_Ref_Node): + view_name = actual_node.parent_ref.name + while isinstance(actual_node.part_ref, ast_internal_classes.Data_Ref_Node): + if isinstance(actual_node.part_ref.parent_ref, ast_internal_classes.Name_Node): + view_name = view_name + "_" + actual_node.part_ref.parent_ref.name + elif isinstance(actual_node.part_ref.parent_ref, ast_internal_classes.Array_Subscript_Node): + view_name = view_name + "_" + actual_node.part_ref.parent_ref.name.name + actual_node = actual_node.part_ref + view_name = view_name + "_" + get_name(actual_node.part_ref) + return view_name + else: raise NameError("Name not found") @@ -93,38 +132,65 @@ class TaskletWriter: :param name_mapping: mapping of names in the code to names in the sdfg :return: python code for a tasklet, as a string """ + def __init__(self, outputs: List[str], outputs_changes: List[str], sdfg: SDFG = None, name_mapping=None, input: List[str] = None, - input_changes: List[str] = None): + input_changes: List[str] = None, + placeholders={}, + placeholders_offsets={}, + rename_dict=None + ): self.outputs = outputs self.outputs_changes = outputs_changes self.sdfg = sdfg + self.placeholders = placeholders + self.placeholders_offsets = placeholders_offsets self.mapping = name_mapping self.input = input self.input_changes = input_changes + 
self.rename_dict = rename_dict + self.depth = 0 self.ast_elements = { ast_internal_classes.BinOp_Node: self.binop2string, + ast_internal_classes.Actual_Arg_Spec_Node: self.actualarg2string, ast_internal_classes.Name_Node: self.name2string, ast_internal_classes.Name_Range_Node: self.name2string, ast_internal_classes.Int_Literal_Node: self.intlit2string, ast_internal_classes.Real_Literal_Node: self.floatlit2string, + ast_internal_classes.Double_Literal_Node: self.doublelit2string, ast_internal_classes.Bool_Literal_Node: self.boollit2string, + ast_internal_classes.Char_Literal_Node: self.charlit2string, ast_internal_classes.UnOp_Node: self.unop2string, ast_internal_classes.Array_Subscript_Node: self.arraysub2string, ast_internal_classes.Parenthesis_Expr_Node: self.parenthesis2string, ast_internal_classes.Call_Expr_Node: self.call2string, ast_internal_classes.ParDecl_Node: self.pardecl2string, + ast_internal_classes.Data_Ref_Node: self.dataref2string, + ast_internal_classes.Array_Constructor_Node: self.arrayconstructor2string, } def pardecl2string(self, node: ast_internal_classes.ParDecl_Node): - #At this point in the process, the should not be any ParDecl nodes left in the AST - they should have been replaced by the appropriate ranges + # At this point in the process, the should not be any ParDecl nodes left in the AST - they should have been replaced by the appropriate ranges + return '0' + #raise NameError("Error in code generation") return f"ERROR{node.type}" + def actualarg2string(self, node: ast_internal_classes.Actual_Arg_Spec_Node): + return self.write_code(node.arg) + + def arrayconstructor2string(self, node: ast_internal_classes.Array_Constructor_Node): + str_to_return = "[ " + for i in node.value_list: + str_to_return += self.write_code(i) + ", " + str_to_return = str_to_return[:-2] + str_to_return += " ]" + return str_to_return + def write_code(self, node: ast_internal_classes.FNode): """ :param node: node to write code for @@ -136,16 +202,32 @@ def 
write_code(self, node: ast_internal_classes.FNode): :note If it not, an error is raised """ + self.depth += 1 if node.__class__ in self.ast_elements: text = self.ast_elements[node.__class__](node) if text is None: raise NameError("Error in code generation") - + if "ERRORALL" in text and self.depth == 1: + print(text) + #raise NameError("Error in code generation") + self.depth -= 1 return text + elif isinstance(node, int): + self.depth -= 1 + return str(node) elif isinstance(node, str): + self.depth -= 1 return node + elif isinstance(node, sym.symbol): + string_name = str(node) + string_to_return = self.write_code(ast_internal_classes.Name_Node(name=string_name)) + self.depth -= 1 + return string_to_return else: - raise NameError("Error in code generation" + node.__class__.__name__) + raise NameError("Error in code generation: " + node.__class__.__name__) + + def dataref2string(self, node: ast_internal_classes.Data_Ref_Node): + return self.write_code(node.parent_ref) + "." + self.write_code(node.part_ref) def arraysub2string(self, node: ast_internal_classes.Array_Subscript_Node): str_to_return = self.write_code(node.name) + "[" + self.write_code(node.indices[0]) @@ -155,16 +237,47 @@ def arraysub2string(self, node: ast_internal_classes.Array_Subscript_Node): return str_to_return def name2string(self, node): + if isinstance(node, str): return node return_value = node.name name = node.name - for i in self.sdfg.arrays: - sdfg_name = self.mapping.get(self.sdfg).get(name) - if sdfg_name == i: - name = i - break + if hasattr(node, "isStructMember"): + if node.isStructMember: + return node.name + + if self.rename_dict is not None and str(name) in self.rename_dict: + return self.write_code(self.rename_dict[str(name)]) + if self.placeholders.get(name) is not None: + location = self.placeholders.get(name) + sdfg_name = self.mapping.get(self.sdfg).get(location[0]) + if sdfg_name is None: + return name + else: + if self.sdfg.arrays[sdfg_name].shape is None or ( + 
len(self.sdfg.arrays[sdfg_name].shape) == 1 and self.sdfg.arrays[sdfg_name].shape[0] == 1): + return "1" + size = self.sdfg.arrays[sdfg_name].shape[location[1]] + return self.write_code(str(size)) + + if self.placeholders_offsets.get(name) is not None: + location = self.placeholders_offsets.get(name) + sdfg_name = self.mapping.get(self.sdfg).get(location[0]) + if sdfg_name is None: + return name + else: + if self.sdfg.arrays[sdfg_name].shape is None or ( + len(self.sdfg.arrays[sdfg_name].shape) == 1 and self.sdfg.arrays[sdfg_name].shape[0] == 1): + return "0" + offset = self.sdfg.arrays[sdfg_name].offset[location[1]] + return self.write_code(str(offset)) + if self.sdfg is not None: + for i in self.sdfg.arrays: + sdfg_name = self.mapping.get(self.sdfg).get(name) + if sdfg_name == i: + name = i + break if len(self.outputs) > 0: if name == self.outputs[0]: @@ -214,6 +327,13 @@ def floatlit2string(self, node: ast_internal_classes.Real_Literal_Node): lit = lit.replace('d', 'e') return f"{float(lit)}" + def doublelit2string(self, node: ast_internal_classes.Double_Literal_Node): + + return "".join(map(str, node.value)) + + def charlit2string(self, node: ast_internal_classes.Char_Literal_Node): + return "".join(map(str, node.value)) + def boollit2string(self, node: ast_internal_classes.Bool_Literal_Node): return str(node.value) @@ -232,7 +352,7 @@ def call2string(self, node: ast_internal_classes.Call_Expr_Node): if node.name.name == "__dace_epsilon": return str(finf(fl).eps) if node.name.name == "pow": - return " ( " + self.write_code(node.args[0]) + " ** " + self.write_code(node.args[1]) + " ) " + return "( " + self.write_code(node.args[0]) + " ** " + self.write_code(node.args[1]) + " )" return_str = self.write_code(node.name) + "(" + self.write_code(node.args[0]) for i in node.args[1:]: return_str += ", " + self.write_code(i) @@ -262,7 +382,7 @@ def binop2string(self, node: ast_internal_classes.BinOp_Node): op = "<" if op == ".GT.": op = ">" - #TODO Add list of missing 
operators + # TODO Add list of missing operators left = self.write_code(node.lval) right = self.write_code(node.rval) @@ -272,7 +392,7 @@ def binop2string(self, node: ast_internal_classes.BinOp_Node): return left + op + right -def generate_memlet(op, top_sdfg, state): +def generate_memlet(op, top_sdfg, state, offset_normalization=False): if state.name_mapping.get(top_sdfg).get(get_name(op)) is not None: shape = top_sdfg.arrays[state.name_mapping[top_sdfg][get_name(op)]].shape elif state.name_mapping.get(state.globalsdfg).get(get_name(op)) is not None: @@ -281,18 +401,35 @@ def generate_memlet(op, top_sdfg, state): raise NameError("Variable name not found: ", get_name(op)) indices = [] if isinstance(op, ast_internal_classes.Array_Subscript_Node): - for i in op.indices: - tw = TaskletWriter([], [], top_sdfg, state.name_mapping) - text = tw.write_code(i) - #This might need to be replaced with the name in the context of the top/current sdfg - indices.append(sym.pystr_to_symbolic(text)) + for idx, i in enumerate(op.indices): + if isinstance(i, ast_internal_classes.ParDecl_Node): + if i.type == 'ALL': + indices.append(None) + else: + tw = TaskletWriter([], [], top_sdfg, state.name_mapping, placeholders=state.placeholders, + placeholders_offsets=state.placeholders_offsets) + text_start = tw.write_code(i.range[0]) + text_end = tw.write_code(i.range[1]) + symb_start = sym.pystr_to_symbolic(text_start) + symb_end = sym.pystr_to_symbolic(text_end) + indices.append([symb_start, symb_end]) + else: + tw = TaskletWriter([], [], top_sdfg, state.name_mapping, placeholders=state.placeholders, + placeholders_offsets=state.placeholders_offsets) + text = tw.write_code(i) + # This might need to be replaced with the name in the context of the top/current sdfg + indices.append([sym.pystr_to_symbolic(text), sym.pystr_to_symbolic(text)]) memlet = '0' if len(shape) == 1: if shape[0] == 1: return memlet all_indices = indices + [None] * (len(shape) - len(indices)) - subset = subsets.Range([(i, 
i, 1) if i is not None else (1, s, 1) for i, s in zip(all_indices, shape)]) + if offset_normalization: + subset = subsets.Range( + [(i[0], i[1], 1) if i is not None else (0, s - 1, 1) for i, s in zip(all_indices, shape)]) + else: + subset = subsets.Range([(i[0], i[1], 1) if i is not None else (1, s, 1) for i, s in zip(all_indices, shape)]) return subset @@ -301,25 +438,36 @@ class ProcessedWriter(TaskletWriter): This class is derived from the TaskletWriter class and is used to write the code of a tasklet that's on an interstate edge rather than a computational tasklet. :note The only differences are in that the names for the sdfg mapping are used, and that the indices are considered to be one-bases rather than zero-based. """ - def __init__(self, sdfg: SDFG, mapping): + + def __init__(self, sdfg: SDFG, mapping, placeholders, placeholders_offsets, rename_dict): self.sdfg = sdfg + self.depth = 0 self.mapping = mapping + self.placeholders = placeholders + self.placeholders_offsets = placeholders_offsets + self.rename_dict = rename_dict self.ast_elements = { ast_internal_classes.BinOp_Node: self.binop2string, + ast_internal_classes.Actual_Arg_Spec_Node: self.actualarg2string, ast_internal_classes.Name_Node: self.name2string, ast_internal_classes.Name_Range_Node: self.namerange2string, ast_internal_classes.Int_Literal_Node: self.intlit2string, ast_internal_classes.Real_Literal_Node: self.floatlit2string, + ast_internal_classes.Double_Literal_Node: self.doublelit2string, ast_internal_classes.Bool_Literal_Node: self.boollit2string, + ast_internal_classes.Char_Literal_Node: self.charlit2string, ast_internal_classes.UnOp_Node: self.unop2string, ast_internal_classes.Array_Subscript_Node: self.arraysub2string, ast_internal_classes.Parenthesis_Expr_Node: self.parenthesis2string, ast_internal_classes.Call_Expr_Node: self.call2string, ast_internal_classes.ParDecl_Node: self.pardecl2string, + ast_internal_classes.Data_Ref_Node: self.dataref2string, } def name2string(self, node: 
ast_internal_classes.Name_Node): name = node.name + if name in self.rename_dict: + return str(self.rename_dict[name]) for i in self.sdfg.arrays: sdfg_name = self.mapping.get(self.sdfg).get(name) if sdfg_name == i: @@ -328,9 +476,9 @@ def name2string(self, node: ast_internal_classes.Name_Node): return name def arraysub2string(self, node: ast_internal_classes.Array_Subscript_Node): - str_to_return = self.write_code(node.name) + "[(" + self.write_code(node.indices[0]) + "+1)" + str_to_return = self.write_code(node.name) + "[(" + self.write_code(node.indices[0]) + ")" for i in node.indices[1:]: - str_to_return += ",( " + self.write_code(i) + "+1)" + str_to_return += ",( " + self.write_code(i) + ")" str_to_return += "]" return str_to_return @@ -384,3 +532,230 @@ def get(self, k): def __setitem__(self, k, v) -> None: assert isinstance(k, ast_internal_classes.Module_Node) return super().__setitem__(k, v) + + +class FunctionSubroutineLister: + def __init__(self): + self.list_of_functions = [] + self.names_in_functions = {} + self.list_of_subroutines = [] + self.names_in_subroutines = {} + self.list_of_types = [] + self.names_in_types = {} + + self.list_of_module_vars = [] + self.interface_blocks: Dict[str, List[Name]] = {} + + def get_functions_and_subroutines(self, node: Base): + for i in node.children: + if isinstance(i, Subroutine_Stmt): + subr_name = singular(children_of_type(i, Name)).string + self.names_in_subroutines[subr_name] = list_descendent_names(node) + self.names_in_subroutines[subr_name] += list_descendent_typenames(node) + self.list_of_subroutines.append(subr_name) + elif isinstance(i, Type_Declaration_Stmt): + if isinstance(node, Specification_Part) and isinstance(node.parent, Module): + self.list_of_module_vars.append(i) + elif isinstance(i, Derived_Type_Def): + name = i.children[0].children[1].string + self.names_in_types[name] = list_descendent_names(i) + self.names_in_types[name] += list_descendent_typenames(i) + self.list_of_types.append(name) + + + 
elif isinstance(i, Function_Stmt): + fn_name = singular(children_of_type(i, Name)).string + self.names_in_functions[fn_name] = list_descendent_names(node) + self.names_in_functions[fn_name] += list_descendent_typenames(node) + self.list_of_functions.append(fn_name) + elif isinstance(i, Interface_Block): + name = None + functions = [] + for j in i.children: + if isinstance(j, Interface_Stmt): + list_of_names = list_descendent_names(j) + if len(list_of_names) == 1: + name = list_of_names[0] + elif isinstance(j, Function_Body): + fn_stmt = singular(children_of_type(j, Function_Stmt)) + fn_name = singular(children_of_type(fn_stmt, Name)) + if fn_name not in functions: + functions.append(fn_name) + elif isinstance(j, Procedure_Stmt): + for k in j.children: + if k.__class__.__name__ == "Procedure_Name_List": + for n in children_of_type(k, Name): + if n not in functions: + functions.append(n) + if len(functions) > 0: + if name is None: + # Anonymous interface can show up multiple times. + name = '' + if name not in self.interface_blocks: + self.interface_blocks[name] = [] + self.interface_blocks[name].extend(functions) + else: + assert name not in self.interface_blocks + self.interface_blocks[name] = functions + elif isinstance(i, Base): + self.get_functions_and_subroutines(i) + + +def list_descendent_typenames(node: Base) -> List[str]: + def _list_descendent_typenames(_node: Base, _list_of_names: List[str]) -> List[str]: + for c in _node.children: + if isinstance(c, Type_Name): + if c.string not in _list_of_names: + _list_of_names.append(c.string) + elif isinstance(c, Base): + _list_descendent_typenames(c, _list_of_names) + return _list_of_names + + return _list_descendent_typenames(node, []) + + +def list_descendent_names(node: Base) -> List[str]: + def _list_descendent_names(_node: Base, _list_of_names: List[str]) -> List[str]: + for c in _node.children: + if isinstance(c, Name): + if c.string not in _list_of_names: + _list_of_names.append(c.string) + elif 
isinstance(c, Base): + _list_descendent_names(c, _list_of_names) + return _list_of_names + + return _list_descendent_names(node, []) + + +def get_defined_modules(node: Base) -> List[str]: + def _get_defined_modules(_node: Base, _defined_modules: List[str]) -> List[str]: + for m in _node.children: + if isinstance(m, Module_Stmt): + _defined_modules.extend(c.string for c in m.children if isinstance(c, Name)) + elif isinstance(m, Base): + _get_defined_modules(m, _defined_modules) + return _defined_modules + + return _get_defined_modules(node, []) + + +class UseAllPruneList: + def __init__(self, module: str, identifiers: List[str]): + """ + Keeps a list of referenced identifiers to intersect with the identifiers available in the module. + WARN: The list of referenced identifiers is taken from the scope of the invocation of "use", but may not be + entirely reliable. The parser should be able to function without this pruning (i.e., by really importing all). + """ + self.module = module + self.identifiers = identifiers + + +def get_used_modules(node: Base) -> Tuple[List[str], Dict[str, List[Union[UseAllPruneList, Base]]]]: + used_modules: List[str] = [] + objects_in_use: Dict[str, List[Union[UseAllPruneList, Base]]] = {} + + def _get_used_modules(_node: Base): + for m in _node.children: + if not isinstance(m, Base): + continue + if not isinstance(m, Use_Stmt): + # Subtree may have `use` statements. + _get_used_modules(m) + continue + nature, _, mod_name, _, olist = m.children + if nature is not None: + # TODO: Explain why intrinsic nodes are avoided. + if nature.string.lower() == "intrinsic": + continue + + mod_name = mod_name.string + used_modules.append(mod_name) + olist = atmost_one(children_of_type(m, 'Only_List')) + if not olist: + # TODO: Have better/clearer semantics. + if mod_name not in objects_in_use: + objects_in_use[mod_name] = [] + # A list of identifiers referred in the context of `_node`. If it's a specification part, then the + # context is its parent. 
If it's a module or a program, then `_node` itself is the context.
+                refs = list_descendent_names(_node.parent if isinstance(_node, Specification_Part) else _node)
+                # Add a special symbol to indicate that everything needs to be imported.
+                objects_in_use[mod_name].append(UseAllPruneList(mod_name, refs))
+            else:
+                assert all(isinstance(c, (Name, Rename)) for c in olist.children)
+                used = [c if isinstance(c, Name) else c.children[2] for c in olist.children]
+                if not used:
+                    continue
+                # Merge all the used items in one giant list.
+                if mod_name not in objects_in_use:
+                    objects_in_use[mod_name] = []
+                extend_with_new_items_from(objects_in_use[mod_name], used)
+                assert len(set([str(o) for o in objects_in_use[mod_name]])) == len(objects_in_use[mod_name])
+
+    _get_used_modules(node)
+    return used_modules, objects_in_use
+
+
+def parse_module_declarations(program):
+    module_level_variables = {}
+
+    for module in program.modules:
+
+        module_name = module.name.name
+        from dace.frontend.fortran.ast_transforms import ModuleVarsDeclarations
+
+        visitor = ModuleVarsDeclarations()  # module_name)
+        if module.specification_part is not None:
+            visitor.visit(module.specification_part)
+            module_level_variables = {**module_level_variables, **visitor.scope_vars}
+
+    return module_level_variables
+
+
+T = TypeVar('T')
+
+
+def singular(items: Iterator[T]) -> T:
+    """
+    Asserts that any given iterator or generator `items` has exactly 1 item and returns that.
+    """
+    it = atmost_one(items)
+    assert it is not None, f"`items` must not be empty."
+    return it
+
+
+def atmost_one(items: Iterator[T]) -> Optional[T]:
+    """
+    Asserts that any given iterator or generator `items` has at most 1 item and returns it (or None if empty).
+    """
+    # We might get one item.
+    try:
+        it = next(items)
+    except StopIteration:
+        # No items found.
+        return None
+    # But not another one.
+    try:
+        nit = next(items)
+    except StopIteration:
+        # I.e., we must have exhausted the iterator.
+ return it + raise ValueError(f"`items` must have at most 1 item, got: {it}, {nit}, ...") + + +def children_of_type(node: Base, typ: Union[str, Type[T], Tuple[Type, ...]]) -> Iterator[T]: + """ + Returns a generator over the children of `node` that are of type `typ`. + """ + if isinstance(typ, str): + return (c for c in node.children if type(c).__name__ == typ) + else: + return (c for c in node.children if isinstance(c, typ)) + + +def extend_with_new_items_from(lst: List[T], items: Iterable[T]): + """ + Extends the list `lst` with new items from `items` (i.e., if it does not exist there already). + """ + for it in items: + if it not in lst: + lst.append(it) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 6b14f63edd..87e5881175 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -1,44 +1,279 @@ # Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. -from venv import create +import copy +from dataclasses import dataclass +import os import warnings +from copy import deepcopy as dpcp +from itertools import chain +from pathlib import Path +from typing import List, Optional, Set, Dict, Tuple, Union -from dace.data import Scalar +import networkx as nx +from fparser.common.readfortran import FortranFileReader as ffr, FortranStringReader, FortranFileReader +from fparser.common.readfortran import FortranStringReader as fsr +from fparser.two.Fortran2003 import Program, Name, Subroutine_Subprogram, Module_Stmt +from fparser.two.parser import ParserFactory as pf, ParserFactory +from fparser.two.symbol_table import SymbolTable +from fparser.two.utils import Base, walk import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes import dace.frontend.fortran.ast_transforms as ast_transforms import dace.frontend.fortran.ast_utils as ast_utils -import 
dace.frontend.fortran.ast_internal_classes as ast_internal_classes -from typing import List, Optional, Tuple, Set -from dace import dtypes from dace import Language as lang +from dace import SDFG, InterstateEdge, Memlet, pointer, nodes, SDFGState from dace import data as dat -from dace import SDFG, InterstateEdge, Memlet, pointer, nodes +from dace import dtypes +from dace import subsets as subs from dace import symbolic as sym -from dace.sdfg.state import ControlFlowRegion, LoopRegion -from copy import deepcopy as dpcp - +from dace.data import Scalar, Structure +from dace.frontend.fortran.ast_desugaring import SPEC, ENTRY_POINT_OBJECT_TYPES, find_name_of_stmt, find_name_of_node, \ + identifier_specs, append_children, correct_for_function_calls, remove_access_statements, sort_modules, \ + deconstruct_enums, deconstruct_interface_calls, deconstruct_procedure_calls, prune_unused_objects, \ + deconstruct_associations, assign_globally_unique_subprogram_names, assign_globally_unique_variable_names, \ + consolidate_uses, prune_branches, const_eval_nodes, lower_identifier_names, \ + remove_access_statements, ident_spec, NAMED_STMTS_OF_INTEREST_TYPES +from dace.frontend.fortran.ast_internal_classes import FNode, Main_Program_Node +from dace.frontend.fortran.ast_utils import UseAllPruneList, children_of_type +from dace.frontend.fortran.intrinsics import IntrinsicSDFGTransformation, NeedsTypeInferenceException from dace.properties import CodeBlock -from fparser.two.parser import ParserFactory as pf -from fparser.common.readfortran import FortranStringReader as fsr -from fparser.common.readfortran import FortranFileReader as ffr -from fparser.two.symbol_table import SymbolTable + +global_struct_instance_counter = 0 + + +def find_access_in_destinations(substate, substate_destinations, name): + wv = None + already_there = False + for i in substate_destinations: + if i.data == name: + wv = i + already_there = True + break + if not already_there: + wv = substate.add_write(name) + 
return wv, already_there + + +def find_access_in_sources(substate, substate_sources, name): + re = None + already_there = False + for i in substate_sources: + if i.data == name: + re = i + already_there = True + break + if not already_there: + re = substate.add_read(name) + return re, already_there + + +def add_views_recursive(sdfg, name, datatype_to_add, struct_views, name_mapping, registered_types, chain, + actual_offsets_per_sdfg, names_of_object_in_parent_sdfg, actual_offsets_per_parent_sdfg): + if not isinstance(datatype_to_add, dat.Structure): + # print("Not adding: ", str(datatype_to_add)) + if isinstance(datatype_to_add, dat.ContainerArray): + datatype_to_add = datatype_to_add.stype + for i in datatype_to_add.members: + current_dtype = datatype_to_add.members[i].dtype + for other_type in registered_types: + if current_dtype.dtype == registered_types[other_type].dtype: + other_type_obj = registered_types[other_type] + add_views_recursive(sdfg, name, datatype_to_add.members[i], struct_views, name_mapping, + registered_types, chain + [i], actual_offsets_per_sdfg, + names_of_object_in_parent_sdfg, actual_offsets_per_parent_sdfg) + # for j in other_type_obj.members: + # sdfg.add_view(name_mapping[name] + "_" + i +"_"+ j,other_type_obj.members[j].shape,other_type_obj.members[j].dtype) + # name_mapping[name + "_" + i +"_"+ j] = name_mapping[name] + "_" + i +"_"+ j + # struct_views[name_mapping[name] + "_" + i+"_"+ j]=[name_mapping[name],i,j] + if len(chain) > 0: + join_chain = "_" + "_".join(chain) + else: + join_chain = "" + current_member = datatype_to_add.members[i] + + if str(datatype_to_add.members[i].dtype.base_type) in registered_types: + + view_to_member = dat.View.view(datatype_to_add.members[i]) + if sdfg.arrays.get(name_mapping[name] + join_chain + "_" + i) is None: + sdfg.arrays[name_mapping[name] + join_chain + "_" + i] = view_to_member + else: + if sdfg.arrays.get(name_mapping[name] + join_chain + "_" + i) is None: + sdfg.add_view(name_mapping[name] 
+ join_chain + "_" + i, datatype_to_add.members[i].shape, + datatype_to_add.members[i].dtype, strides=datatype_to_add.members[i].strides) + if names_of_object_in_parent_sdfg.get(name_mapping[name]) is not None: + if actual_offsets_per_parent_sdfg.get( + names_of_object_in_parent_sdfg[name_mapping[name]] + join_chain + "_" + i) is not None: + actual_offsets_per_sdfg[name_mapping[name] + join_chain + "_" + i] = actual_offsets_per_parent_sdfg[ + names_of_object_in_parent_sdfg[name_mapping[name]] + join_chain + "_" + i] + else: + # print("No offsets in sdfg: ",sdfg.name ," for: ",names_of_object_in_parent_sdfg[name_mapping[name]]+ join_chain + "_" + i) + actual_offsets_per_sdfg[name_mapping[name] + join_chain + "_" + i] = [1] * len( + datatype_to_add.members[i].shape) + name_mapping[name_mapping[name] + join_chain + "_" + i] = name_mapping[name] + join_chain + "_" + i + struct_views[name_mapping[name] + join_chain + "_" + i] = [name_mapping[name]] + chain + [i] + + +def add_deferred_shape_assigns_for_structs(structures: ast_transforms.Structures, + decl: ast_internal_classes.Var_Decl_Node, sdfg: SDFG, + assign_state: SDFGState, name: str, name_: str, placeholders, + placeholders_offsets, object, names_to_replace, actual_offsets_per_sdfg): + if not structures.is_struct(decl.type): + # print("Not adding defferred shape assigns for: ", decl.type,decl.name) + return + + if isinstance(object, dat.ContainerArray): + struct_type = object.stype + else: + struct_type = object + global global_struct_instance_counter + local_counter = global_struct_instance_counter + global_struct_instance_counter += 1 + overall_ast_struct_type = structures.get_definition(decl.type) + counter = 0 + listofmember = list(struct_type.members) + # print("Struct: "+decl.name +" Struct members: "+ str(len(listofmember))+ " Definition members: "+str(len(list(overall_ast_struct_type.vars.items())))) + + for ast_struct_type in overall_ast_struct_type.vars.items(): + ast_struct_type = ast_struct_type[1] + 
var = struct_type.members[ast_struct_type.name] + + if isinstance(var, dat.ContainerArray): + var_type = var.stype + else: + var_type = var + + # print(ast_struct_type.name,var_type.__class__) + if isinstance(object.members[ast_struct_type.name], dat.Structure): + + add_deferred_shape_assigns_for_structs(structures, ast_struct_type, sdfg, assign_state, + f"{name}->{ast_struct_type.name}", f"{ast_struct_type.name}_{name_}", + placeholders, placeholders_offsets, + object.members[ast_struct_type.name], names_to_replace, + actual_offsets_per_sdfg) + elif isinstance(var_type, dat.Structure): + add_deferred_shape_assigns_for_structs(structures, ast_struct_type, sdfg, assign_state, + f"{name}->{ast_struct_type.name}", f"{ast_struct_type.name}_{name_}", + placeholders, placeholders_offsets, var_type, names_to_replace, + actual_offsets_per_sdfg) + # print(ast_struct_type) + # print(ast_struct_type.__class__) + + if ast_struct_type.sizes is None or len(ast_struct_type.sizes) == 0: + continue + offsets_to_replace = [] + sanity_count = 0 + + for offset in ast_struct_type.offsets: + if isinstance(offset, ast_internal_classes.Name_Node): + if hasattr(offset, "name"): + if sdfg.symbols.get(offset.name) is None: + sdfg.add_symbol(offset.name, dtypes.int32) + sanity_count += 1 + if offset.name.startswith('__f2dace_SOA'): + newoffset = offset.name + "_" + name_ + "_" + str(local_counter) + sdfg.append_global_code(f"{dtypes.int32.ctype} {newoffset};\n") + # prog hack + if name.endswith("prog"): + sdfg.append_init_code(f"{newoffset} = {name}[0]->{offset.name};\n") + else: + sdfg.append_init_code(f"{newoffset} = {name}->{offset.name};\n") + + sdfg.add_symbol(newoffset, dtypes.int32) + offsets_to_replace.append(newoffset) + names_to_replace[offset.name] = newoffset + else: + # print("not replacing",offset.name) + offsets_to_replace.append(offset.name) + else: + sanity_count += 1 + # print("not replacing not namenode",offset) + offsets_to_replace.append(offset) + if sanity_count == 
len(ast_struct_type.offsets): + # print("adding offsets for: "+name.replace("->","_")+"_"+ast_struct_type.name) + actual_offsets_per_sdfg[name.replace("->", "_") + "_" + ast_struct_type.name] = offsets_to_replace + + # for assumed shape, all vars starts with the same prefix + for size in ast_struct_type.sizes: + if isinstance(size, ast_internal_classes.Name_Node): # and size.name.startswith('__f2dace_A'): + + if hasattr(size, "name"): + if sdfg.symbols.get(size.name) is None: + # new_name=sdfg._find_new_name(size.name) + sdfg.add_symbol(size.name, dtypes.int32) + + if size.name.startswith('__f2dace_SA'): + # newsize=ast_internal_classes.Name_Node(name=size.name+"_"+str(local_counter),parent=size.parent,type=size.type) + newsize = size.name + "_" + name_ + "_" + str(local_counter) + names_to_replace[size.name] = newsize + # var_type.sizes[var_type.sizes.index(size)]=newsize + sdfg.append_global_code(f"{dtypes.int32.ctype} {newsize};\n") + if name.endswith("prog"): + sdfg.append_init_code(f"{newsize} = {name}[0]->{size.name};\n") + else: + sdfg.append_init_code(f"{newsize} = {name}->{size.name};\n") + sdfg.add_symbol(newsize, dtypes.int32) + if isinstance(object, dat.Structure): + shape2 = dpcp(object.members[ast_struct_type.name].shape) + else: + shape2 = dpcp(object.stype.members[ast_struct_type.name].shape) + shapelist = list(shape2) + shapelist[ast_struct_type.sizes.index(size)] = sym.pystr_to_symbolic(newsize) + shape_replace = tuple(shapelist) + viewname = f"{name}->{ast_struct_type.name}" + + viewname = viewname.replace("->", "_") + # view=sdfg.arrays[viewname] + strides = [dat._prod(shapelist[:i]) for i in range(len(shapelist))] + if isinstance(object.members[ast_struct_type.name], dat.ContainerArray): + tmpobject = dat.ContainerArray(object.members[ast_struct_type.name].stype, shape_replace, + strides=strides) + + + elif isinstance(object.members[ast_struct_type.name], dat.Array): + tmpobject = dat.Array(object.members[ast_struct_type.name].dtype, 
shape_replace,
+                                               strides=strides)
+
+                    else:
+                        # NOTE(review): the original raised with `tmpobject.__class__`, but `tmpobject` is
+                        # unbound in this branch (it is only assigned in the branches above), which would
+                        # raise a NameError instead of the intended ValueError. Report the actual member type.
+                        raise ValueError("Unknown type " + str(object.members[ast_struct_type.name].__class__))
+                    object.members.pop(ast_struct_type.name)
+                    object.members[ast_struct_type.name] = tmpobject
+                    tmpview = dat.View.view(object.members[ast_struct_type.name])
+                    if sdfg.arrays.get(viewname) is not None:
+                        del sdfg.arrays[viewname]
+                    sdfg.arrays[viewname] = tmpview
+                    # if placeholders.get(size.name) is not None:
+                    #    placeholders[newsize]=placeholders[size.name]


 class AST_translator:
     """
     This class is responsible for translating the internal AST into a SDFG.
     """
-    def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_explicit_cf: bool = False):
+
+    def __init__(self, source: str, multiple_sdfgs: bool = False, startpoint=None, sdfg_path=None,
+                 toplevel_subroutine: Optional[str] = None, subroutine_used_names: Optional[Set[str]] = None,
+                 normalize_offsets=False, do_not_make_internal_variables_argument: bool = False):
         """
         :ast: The internal fortran AST to be used for translation
         :source: The source file name from which the AST was generated
+        :do_not_make_internal_variables_argument: Avoid turning internal variables of the entry point into arguments.
+            This essentially avoids the hack with `transient_mode = False`, since we can rely on `startpoint` for
+            arbitrary entry point anyway.
         """
-        self.tables = ast.tables
+        # TODO: Refactor the callers who rely on the hack with `transient_mode = False`, then remove the
+        #  `do_not_make_internal_variables_argument` argument entirely, since we don't need it at that point.
+ self.sdfg_path = sdfg_path + self.count_of_struct_symbols_lifted = 0 + self.registered_types = {} + self.transient_mode = True + self.startpoint = startpoint self.top_level = None self.globalsdfg = None - self.functions_and_subroutines = ast.functions_and_subroutines + self.multiple_sdfgs = multiple_sdfgs self.name_mapping = ast_utils.NameMap() + self.actual_offsets_per_sdfg = {} + self.names_of_object_in_parent_sdfg = {} self.contexts = {} self.views = 0 self.libstates = [] @@ -47,11 +282,25 @@ def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_expl self.all_array_names = [] self.last_sdfg_states = {} self.last_loop_continues = {} + self.last_loop_continues_stack = {} + self.already_has_edge_back_continue = {} self.last_loop_breaks = {} self.last_returns = {} self.module_vars = [] + self.sdfgs_count = 0 self.libraries = {} + self.local_not_transient_because_assign = {} + self.struct_views = {} self.last_call_expression = {} + self.struct_view_count = 0 + self.structures = None + self.placeholders = None + self.placeholders_offsets = None + self.replace_names = {} + self.toplevel_subroutine = toplevel_subroutine + self.subroutine_used_names = subroutine_used_names + self.normalize_offsets = normalize_offsets + self.do_not_make_internal_variables_argument = do_not_make_internal_variables_argument self.ast_elements = { ast_internal_classes.If_Stmt_Node: self.ifstmt2sdfg, ast_internal_classes.For_Stmt_Node: self.forstmt2sdfg, @@ -68,8 +317,11 @@ def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_expl ast_internal_classes.Write_Stmt_Node: self.write2sdfg, ast_internal_classes.Allocate_Stmt_Node: self.allocate2sdfg, ast_internal_classes.Break_Node: self.break2sdfg, + ast_internal_classes.Continue_Node: self.continue2sdfg, + ast_internal_classes.Derived_Type_Def_Node: self.derivedtypedef2sdfg, + ast_internal_classes.Pointer_Assignment_Stmt_Node: self.pointerassignment2sdfg, + ast_internal_classes.While_Stmt_Node: 
self.whilestmt2sdfg, } - self.use_explicit_cf = use_explicit_cf def get_dace_type(self, type): """ @@ -77,7 +329,17 @@ def get_dace_type(self, type): by referencing the ast_utils.fortrantypes2dacetypes dictionary. """ if isinstance(type, str): - return ast_utils.fortrantypes2dacetypes[type] + if type in ast_utils.fortrantypes2dacetypes: + return ast_utils.fortrantypes2dacetypes[type] + elif type in self.registered_types: + return self.registered_types[type] + else: + # TODO: This is bandaid. + if type == "VOID": + return ast_utils.fortrantypes2dacetypes["DOUBLE"] + raise ValueError("Unknown type " + type) + else: + raise ValueError("Unknown type " + type) def get_name_mapping_in_context(self, sdfg: SDFG): """ @@ -119,9 +381,9 @@ def get_memlet_range(self, sdfg: SDFG, variables: List[ast_internal_classes.FNod for o_v in variables: if o_v.name == var_name_tasklet: - return ast_utils.generate_memlet(o_v, sdfg, self) + return ast_utils.generate_memlet(o_v, sdfg, self, self.normalize_offsets) - def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG, cfg: Optional[ControlFlowRegion] = None): + def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG): """ This function is responsible for translating the AST into a SDFG. 
:param node: The node to be translated @@ -130,17 +392,15 @@ def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG, cfg: Optional[ :note: This function will call the appropriate function for the node type :note: The dictionary ast_elements, part of the class itself contains all functions that are called for the different node types """ - if not cfg: - cfg = sdfg if node.__class__ in self.ast_elements: - self.ast_elements[node.__class__](node, sdfg, cfg) + self.ast_elements[node.__class__](node, sdfg) elif isinstance(node, list): for i in node: - self.translate(i, sdfg, cfg) + self.translate(i, sdfg) else: warnings.warn(f"WARNING: {node.__class__.__name__}") - def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): """ This function is responsible for translating the Fortran AST into a SDFG. :param node: The node to be translated @@ -151,28 +411,218 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG, cfg: Con """ self.globalsdfg = sdfg for i in node.modules: - for j in i.specification_part.typedecls: - self.translate(j, sdfg, cfg) - for k in j.vardecl: - self.module_vars.append((k.name, i.name)) - for j in i.specification_part.symbols: - self.translate(j, sdfg, cfg) - for k in j.vardecl: - self.module_vars.append((k.name, i.name)) - for j in i.specification_part.specifications: - self.translate(j, sdfg, cfg) - for k in j.vardecl: - self.module_vars.append((k.name, i.name)) - - for i in node.main_program.specification_part.typedecls: - self.translate(i, sdfg, cfg) - for i in node.main_program.specification_part.symbols: - self.translate(i, sdfg, cfg) - for i in node.main_program.specification_part.specifications: - self.translate(i, sdfg, cfg) - self.translate(node.main_program.execution_part.execution, sdfg, cfg) - - def basicblock2sdfg(self, node: ast_internal_classes.Execution_Part_Node, sdfg: SDFG, cfg: 
ControlFlowRegion): + structs_lister = ast_transforms.StructLister() + if i.specification_part is not None: + structs_lister.visit(i.specification_part) + struct_dep_graph = nx.DiGraph() + for ii, name in zip(structs_lister.structs, structs_lister.names): + if name not in struct_dep_graph.nodes: + struct_dep_graph.add_node(name) + struct_deps_finder = ast_transforms.StructDependencyLister(structs_lister.names) + struct_deps_finder.visit(ii) + struct_deps = struct_deps_finder.structs_used + # print(struct_deps) + for j, pointing, point_name in zip(struct_deps, struct_deps_finder.is_pointer, + struct_deps_finder.pointer_names): + if j not in struct_dep_graph.nodes: + struct_dep_graph.add_node(j) + struct_dep_graph.add_edge(name, j, pointing=pointing, point_name=point_name) + parse_order = list(reversed(list(nx.topological_sort(struct_dep_graph)))) + for jj in parse_order: + for j in i.specification_part.typedecls: + if j.name.name == jj: + self.translate(j, sdfg) + if j.__class__.__name__ != "Derived_Type_Def_Node": + for k in j.vardecl: + self.module_vars.append((k.name, i.name)) + if i.specification_part is not None: + + # this works with CloudSC + # unsure about ICON + self.transient_mode = self.do_not_make_internal_variables_argument + + for j in i.specification_part.symbols: + self.translate(j, sdfg) + if isinstance(j, ast_internal_classes.Symbol_Array_Decl_Node): + self.module_vars.append((j.name, i.name)) + elif isinstance(j, ast_internal_classes.Symbol_Decl_Node): + self.module_vars.append((j.name, i.name)) + else: + raise ValueError("Unknown symbol type") + for j in i.specification_part.specifications: + self.translate(j, sdfg) + for k in j.vardecl: + self.module_vars.append((k.name, i.name)) + # this works with CloudSC + # unsure about ICON + self.transient_mode = True + ast_utils.add_simple_state_to_sdfg(self, sdfg, "GlobalDefEnd") + if self.startpoint is None: + self.startpoint = node.main_program + assert self.startpoint is not None, "No main program or 
start point found" + + if self.startpoint.specification_part is not None: + # this works with CloudSC + # unsure about ICON + self.transient_mode = self.do_not_make_internal_variables_argument + + for i in self.startpoint.specification_part.typedecls: + self.translate(i, sdfg) + for i in self.startpoint.specification_part.symbols: + self.translate(i, sdfg) + + for i in self.startpoint.specification_part.specifications: + self.translate(i, sdfg) + for i in self.startpoint.specification_part.specifications: + ast_utils.add_simple_state_to_sdfg(self, sdfg, "start_struct_size") + assign_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, "assign_struct_sizes") + for decl in i.vardecl: + if decl.name in sdfg.symbols: + continue + add_deferred_shape_assigns_for_structs(self.structures, decl, sdfg, assign_state, decl.name, + decl.name, self.placeholders, + self.placeholders_offsets, + sdfg.arrays[self.name_mapping[sdfg][decl.name]], + self.replace_names, + self.actual_offsets_per_sdfg[sdfg]) + + if not isinstance(self.startpoint, Main_Program_Node): + # this works with CloudSC + # unsure about ICON + arg_names = [ast_utils.get_name(i) for i in self.startpoint.args] + for arr_name, arr in sdfg.arrays.items(): + + if arr.transient and arr_name in arg_names: + print(f"Changing the transient status to false of {arr_name} because it's a function argument") + arr.transient = False + + self.transient_mode = True + self.translate(self.startpoint.execution_part.execution, sdfg) + + def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_Stmt_Node, sdfg: SDFG): + """ + This function is responsible for translating Fortran pointer assignments into a SDFG. 
+ :param node: The node to be translated + :param sdfg: The SDFG to which the node should be translated + """ + if self.name_mapping[sdfg][node.name_pointer.name] in sdfg.arrays: + shapenames = [sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].shape[i] for i in + range(len(sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].shape))] + offsetnames = self.actual_offsets_per_sdfg[sdfg][node.name_pointer.name] + [sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].offset[i] for i in + range(len(sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].offset))] + # for i in shapenames: + # if str(i) in sdfg.symbols: + # sdfg.symbols.pop(str(i)) + # if sdfg.parent_nsdfg_node is not None: + # if str(i) in sdfg.parent_nsdfg_node.symbol_mapping: + # sdfg.parent_nsdfg_node.symbol_mapping.pop(str(i)) + + # for i in offsetnames: + # if str(i) in sdfg.symbols: + # sdfg.symbols.pop(str(i)) + # if sdfg.parent_nsdfg_node is not None: + # if str(i) in sdfg.parent_nsdfg_node.symbol_mapping: + # sdfg.parent_nsdfg_node.symbol_mapping.pop(str(i)) + sdfg.arrays.pop(self.name_mapping[sdfg][node.name_pointer.name]) + if isinstance(node.name_target, ast_internal_classes.Data_Ref_Node): + if node.name_target.parent_ref.name not in self.name_mapping[sdfg]: + raise ValueError("Unknown variable " + node.name_target.name) + if isinstance(node.name_target.part_ref, ast_internal_classes.Data_Ref_Node): + self.name_mapping[sdfg][node.name_pointer.name] = self.name_mapping[sdfg][ + node.name_target.parent_ref.name + "_" + node.name_target.part_ref.parent_ref.name + "_" + node.name_target.part_ref.part_ref.name] + # self.replace_names[node.name_pointer.name]=self.name_mapping[sdfg][node.name_target.parent_ref.name+"_"+node.name_target.part_ref.parent_ref.name+"_"+node.name_target.part_ref.part_ref.name] + target = sdfg.arrays[self.name_mapping[sdfg][ + node.name_target.parent_ref.name + "_" + node.name_target.part_ref.parent_ref.name + "_" + 
node.name_target.part_ref.part_ref.name]]
+                # for i in self.actual_offsets_per_sdfg[sdfg]:
+                #     print(i)
+                actual_offsets = self.actual_offsets_per_sdfg[sdfg][
+                    node.name_target.parent_ref.name + "_" + node.name_target.part_ref.parent_ref.name + "_" + node.name_target.part_ref.part_ref.name]
+
+                for i in shapenames:
+                    self.replace_names[str(i)] = str(target.shape[shapenames.index(i)])
+                for i in offsetnames:
+                    self.replace_names[str(i)] = str(actual_offsets[offsetnames.index(i)])
+            else:
+                self.name_mapping[sdfg][node.name_pointer.name] = self.name_mapping[sdfg][
+                    node.name_target.parent_ref.name + "_" + node.name_target.part_ref.name]
+                self.replace_names[node.name_pointer.name] = self.name_mapping[sdfg][
+                    node.name_target.parent_ref.name + "_" + node.name_target.part_ref.name]
+                target = sdfg.arrays[
+                    self.name_mapping[sdfg][node.name_target.parent_ref.name + "_" + node.name_target.part_ref.name]]
+                actual_offsets = self.actual_offsets_per_sdfg[sdfg][
+                    node.name_target.parent_ref.name + "_" + node.name_target.part_ref.name]
+                for i in shapenames:
+                    self.replace_names[str(i)] = str(target.shape[shapenames.index(i)])
+                for i in offsetnames:
+                    self.replace_names[str(i)] = str(actual_offsets[offsetnames.index(i)])
+
+        elif isinstance(node.name_pointer, ast_internal_classes.Data_Ref_Node):
+            raise ValueError("Not implemented yet")
+
+        else:
+            if node.name_target.name not in self.name_mapping[sdfg]:
+                raise ValueError("Unknown variable " + node.name_target.name)
+            found = False
+            for i in self.unallocated_arrays:
+                if i[0] == node.name_pointer.name:
+                    if found:
+                        raise ValueError("Multiple unallocated arrays with the same name")
+                    found = True
+                    self.unallocated_arrays.remove(i)
+            self.name_mapping[sdfg][node.name_pointer.name] = self.name_mapping[sdfg][node.name_target.name]
+
+    def derivedtypedef2sdfg(self, node: ast_internal_classes.Derived_Type_Def_Node, sdfg: SDFG):
+        """
+        This function is responsible for registering Fortran derived type declarations into a
SDFG as nested data types. + :param node: The node to be translated + :param sdfg: The SDFG to which the node should be translated + """ + name = node.name.name + if node.component_part is None: + components = [] + else: + components = node.component_part.component_def_stmts + dict_setup = {} + for i in components: + j = i.vars + for k in j.vardecl: + complex_datatype = False + datatype = self.get_dace_type(k.type) + if isinstance(datatype, dat.Structure): + complex_datatype = True + if k.sizes is not None: + sizes = [] + offset = [] + offset_value = 0 if self.normalize_offsets else -1 + for i in k.sizes: + tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names) + text = tw.write_code(i) + sizes.append(sym.pystr_to_symbolic(text)) + offset.append(offset_value) + strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] + if not complex_datatype: + dict_setup[k.name] = dat.Array( + datatype, + sizes, + strides=strides, + offset=offset, + ) + else: + dict_setup[k.name] = dat.ContainerArray(datatype, sizes, strides=strides, offset=offset) + + else: + if not complex_datatype: + dict_setup[k.name] = dat.Scalar(datatype) + else: + dict_setup[k.name] = datatype + + structure_obj = Structure(dict_setup, name) + self.registered_types[name] = structure_obj + + def basicblock2sdfg(self, node: ast_internal_classes.Execution_Part_Node, sdfg: SDFG): """ This function is responsible for translating Fortran basic blocks into a SDFG. 
:param node: The node to be translated @@ -180,9 +630,9 @@ def basicblock2sdfg(self, node: ast_internal_classes.Execution_Part_Node, sdfg: """ for i in node.execution: - self.translate(i, sdfg, cfg) + self.translate(i, sdfg) - def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDFG): """ This function is responsible for translating Fortran allocate statements into a SDFG. :param node: The node to be translated @@ -195,11 +645,13 @@ def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDF datatype = j[1] transient = j[3] self.unallocated_arrays.remove(j) - offset_value = -1 + offset_value = 0 if self.normalize_offsets else -1 sizes = [] offset = [] for j in i.shape.shape_list: - tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping) + tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names) text = tw.write_code(j) sizes.append(sym.pystr_to_symbolic(text)) offset.append(offset_value) @@ -218,12 +670,12 @@ def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDF strides=strides, transient=transient) + def write2sdfg(self, node: ast_internal_classes.Write_Stmt_Node, sdfg: SDFG): + # TODO implement + print("Uh oh") + # raise NotImplementedError("Fortran write statements are not implemented yet") - def write2sdfg(self, node: ast_internal_classes.Write_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): - #TODO implement - raise NotImplementedError("Fortran write statements are not implemented yet") - - def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG): """ This function is responsible for translating Fortran if statements into a SDFG. 
:param node: The node to be translated @@ -231,122 +683,158 @@ def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG, cfg: """ name = f"If_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}" - begin_state = ast_utils.add_simple_state_to_sdfg(self, cfg, f"Begin{name}") - guard_substate = cfg.add_state(f"Guard{name}") - cfg.add_edge(begin_state, guard_substate, InterstateEdge()) + begin_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"Begin{name}") + guard_substate = sdfg.add_state(f"Guard{name}") + sdfg.add_edge(begin_state, guard_substate, InterstateEdge()) - condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.cond) + condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping, self.placeholders, self.placeholders_offsets, + self.replace_names).write_code(node.cond) - body_ifstart_state = cfg.add_state(f"BodyIfStart{name}") - self.last_sdfg_states[cfg] = body_ifstart_state - self.translate(node.body, sdfg, cfg) - final_substate = cfg.add_state(f"MergeState{name}") + body_ifstart_state = sdfg.add_state(f"BodyIfStart{name}") + self.last_sdfg_states[sdfg] = body_ifstart_state + self.translate(node.body, sdfg) + final_substate = sdfg.add_state(f"MergeState{name}") - cfg.add_edge(guard_substate, body_ifstart_state, InterstateEdge(condition)) + sdfg.add_edge(guard_substate, body_ifstart_state, InterstateEdge(condition)) - if self.last_sdfg_states[cfg] not in [ - self.last_loop_breaks.get(cfg), - self.last_loop_continues.get(cfg), - self.last_returns.get(cfg) + if self.last_sdfg_states[sdfg] not in [ + self.last_loop_breaks.get(sdfg), + self.last_loop_continues.get(sdfg), + self.last_returns.get(sdfg), + self.already_has_edge_back_continue.get(sdfg) ]: - body_ifend_state = ast_utils.add_simple_state_to_sdfg(self, cfg, f"BodyIfEnd{name}") - cfg.add_edge(body_ifend_state, final_substate, InterstateEdge()) + body_ifend_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"BodyIfEnd{name}") + 
sdfg.add_edge(body_ifend_state, final_substate, InterstateEdge()) if len(node.body_else.execution) > 0: name_else = f"Else_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}" - body_elsestart_state = cfg.add_state("BodyElseStart" + name_else) - self.last_sdfg_states[cfg] = body_elsestart_state - self.translate(node.body_else, sdfg, cfg) - body_elseend_state = ast_utils.add_simple_state_to_sdfg(self, cfg, f"BodyElseEnd{name_else}") - cfg.add_edge(guard_substate, body_elsestart_state, InterstateEdge("not (" + condition + ")")) - cfg.add_edge(body_elseend_state, final_substate, InterstateEdge()) + body_elsestart_state = sdfg.add_state("BodyElseStart" + name_else) + self.last_sdfg_states[sdfg] = body_elsestart_state + self.translate(node.body_else, sdfg) + body_elseend_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"BodyElseEnd{name_else}") + sdfg.add_edge(guard_substate, body_elsestart_state, InterstateEdge("not (" + condition + ")")) + sdfg.add_edge(body_elseend_state, final_substate, InterstateEdge()) else: - cfg.add_edge(guard_substate, final_substate, InterstateEdge("not (" + condition + ")")) - self.last_sdfg_states[cfg] = final_substate + sdfg.add_edge(guard_substate, final_substate, InterstateEdge("not (" + condition + ")")) + self.last_sdfg_states[sdfg] = final_substate + + def whilestmt2sdfg(self, node: ast_internal_classes.While_Stmt_Node, sdfg: SDFG): + + # raise NotImplementedError("Fortran while statements are not implemented yet") + name = "While_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) + begin_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, "Begin" + name) + guard_substate = sdfg.add_state("Guard" + name) + final_substate = sdfg.add_state("Merge" + name) + self.last_sdfg_states[sdfg] = final_substate + + sdfg.add_edge(begin_state, guard_substate, InterstateEdge()) + + condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + 
placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(node.cond) + + + + begin_loop_state = sdfg.add_state("BeginWhile" + name) + end_loop_state = sdfg.add_state("EndWhile" + name) + self.last_sdfg_states[sdfg] = begin_loop_state + self.last_loop_continues[sdfg] = end_loop_state + if self.last_loop_continues_stack.get(sdfg) is None: + self.last_loop_continues_stack[sdfg] = [] + self.last_loop_continues_stack[sdfg].append(end_loop_state) + self.translate(node.body, sdfg) + + sdfg.add_edge(self.last_sdfg_states[sdfg], end_loop_state, InterstateEdge()) + sdfg.add_edge(guard_substate, begin_loop_state, InterstateEdge(condition)) + sdfg.add_edge(end_loop_state, guard_substate, InterstateEdge()) + sdfg.add_edge(guard_substate, final_substate, InterstateEdge(f"not ({condition})")) + self.last_sdfg_states[sdfg] = final_substate + + if len(self.last_loop_continues_stack[sdfg]) > 0: + self.last_loop_continues[sdfg] = self.last_loop_continues_stack[sdfg][-1] + else: + self.last_loop_continues[sdfg] = None + - def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG): """ This function is responsible for translating Fortran for statements into a SDFG. 
:param node: The node to be translated :param sdfg: The SDFG to which the node should be translated """ - if not self.use_explicit_cf: - declloop = False - name = "FOR_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) - begin_state = ast_utils.add_simple_state_to_sdfg(self, cfg, "Begin" + name) - guard_substate = cfg.add_state("Guard" + name) - final_substate = cfg.add_state("Merge" + name) - self.last_sdfg_states[cfg] = final_substate - decl_node = node.init - entry = {} - if isinstance(decl_node, ast_internal_classes.BinOp_Node): - if sdfg.symbols.get(decl_node.lval.name) is not None: - iter_name = decl_node.lval.name - elif self.name_mapping[sdfg].get(decl_node.lval.name) is not None: - iter_name = self.name_mapping[sdfg][decl_node.lval.name] - else: - raise ValueError("Unknown variable " + decl_node.lval.name) - entry[iter_name] = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(decl_node.rval) - - cfg.add_edge(begin_state, guard_substate, InterstateEdge(assignments=entry)) - - condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.cond) - - increment = "i+0+1" - if isinstance(node.iter, ast_internal_classes.BinOp_Node): - increment = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.iter.rval) - entry = {iter_name: increment} - - begin_loop_state = cfg.add_state("BeginLoop" + name) - end_loop_state = cfg.add_state("EndLoop" + name) - self.last_sdfg_states[cfg] = begin_loop_state - self.last_loop_continues[cfg] = final_substate - self.translate(node.body, sdfg, cfg) - - cfg.add_edge(self.last_sdfg_states[cfg], end_loop_state, InterstateEdge()) - cfg.add_edge(guard_substate, begin_loop_state, InterstateEdge(condition)) - cfg.add_edge(end_loop_state, guard_substate, InterstateEdge(assignments=entry)) - cfg.add_edge(guard_substate, final_substate, InterstateEdge(f"not ({condition})")) - self.last_sdfg_states[cfg] = final_substate + declloop = False + name = "FOR_l_" + str(node.line_number[0]) + 
"_c_" + str(node.line_number[1]) + begin_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, "Begin" + name) + guard_substate = sdfg.add_state("Guard" + name) + final_substate = sdfg.add_state("Merge" + name) + self.last_sdfg_states[sdfg] = final_substate + decl_node = node.init + entry = {} + if isinstance(decl_node, ast_internal_classes.BinOp_Node): + if sdfg.symbols.get(decl_node.lval.name) is not None: + iter_name = decl_node.lval.name + elif self.name_mapping[sdfg].get(decl_node.lval.name) is not None: + iter_name = self.name_mapping[sdfg][decl_node.lval.name] + else: + raise ValueError("Unknown variable " + decl_node.lval.name) + entry[iter_name] = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(decl_node.rval) + + sdfg.add_edge(begin_state, guard_substate, InterstateEdge(assignments=entry)) + + condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(node.cond) + + increment = "i+0+1" + if isinstance(node.iter, ast_internal_classes.BinOp_Node): + increment = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(node.iter.rval) + entry = {iter_name: increment} + + begin_loop_state = sdfg.add_state("BeginLoop" + name) + end_loop_state = sdfg.add_state("EndLoop" + name) + self.last_sdfg_states[sdfg] = begin_loop_state + self.last_loop_continues[sdfg] = end_loop_state + if self.last_loop_continues_stack.get(sdfg) is None: + self.last_loop_continues_stack[sdfg] = [] + self.last_loop_continues_stack[sdfg].append(end_loop_state) + self.translate(node.body, sdfg) + + sdfg.add_edge(self.last_sdfg_states[sdfg], end_loop_state, InterstateEdge()) + sdfg.add_edge(guard_substate, 
begin_loop_state, InterstateEdge(condition)) + sdfg.add_edge(end_loop_state, guard_substate, InterstateEdge(assignments=entry)) + sdfg.add_edge(guard_substate, final_substate, InterstateEdge(f"not ({condition})")) + self.last_sdfg_states[sdfg] = final_substate + self.last_loop_continues_stack[sdfg].pop() + if len(self.last_loop_continues_stack[sdfg]) > 0: + self.last_loop_continues[sdfg] = self.last_loop_continues_stack[sdfg][-1] else: - name = "FOR_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) - decl_node = node.init - entry = {} - if isinstance(decl_node, ast_internal_classes.BinOp_Node): - if sdfg.symbols.get(decl_node.lval.name) is not None: - iter_name = decl_node.lval.name - elif self.name_mapping[sdfg].get(decl_node.lval.name) is not None: - iter_name = self.name_mapping[sdfg][decl_node.lval.name] - else: - raise ValueError("Unknown variable " + decl_node.lval.name) - entry[iter_name] = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(decl_node.rval) - - condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.cond) - - increment = "i+0+1" - if isinstance(node.iter, ast_internal_classes.BinOp_Node): - increment = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.iter.rval) - - loop_region = LoopRegion(name, condition, iter_name, f"{iter_name} = {entry[iter_name]}", - f"{iter_name} = {increment}") - is_start = self.last_sdfg_states.get(cfg) is None - cfg.add_node(loop_region, is_start_block=is_start) - if not is_start: - cfg.add_edge(self.last_sdfg_states[cfg], loop_region, InterstateEdge()) - self.last_sdfg_states[cfg] = loop_region + self.last_loop_continues[sdfg] = None - begin_loop_state = loop_region.add_state("BeginLoop" + name, is_start_block=True) - self.last_sdfg_states[loop_region] = begin_loop_state - - self.translate(node.body, sdfg, loop_region) - - def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def symbol2sdfg(self, 
node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG): """ This function is responsible for translating Fortran symbol declarations into a SDFG. :param node: The node to be translated :param sdfg: The SDFG to which the node should be translated """ + if node.name == "modname": return + + if node.name.startswith("__f2dace_A_"): + # separate name by removing the prefix and the suffix starting with _d_ + array_name = node.name[11:] + array_name = array_name[:array_name.index("_d_")] + if array_name in sdfg.arrays: + return # already declared + if node.name.startswith("__f2dace_OA_"): + # separate name by removing the prefix and the suffix starting with _d_ + array_name = node.name[12:] + array_name = array_name[:array_name.index("_d_")] + if array_name in sdfg.arrays: + return if self.contexts.get(sdfg.name) is None: self.contexts[sdfg.name] = ast_utils.Context(name=sdfg.name) @@ -354,30 +842,39 @@ def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG, c if isinstance(node.init, ast_internal_classes.Int_Literal_Node) or isinstance( node.init, ast_internal_classes.Real_Literal_Node): self.contexts[sdfg.name].constants[node.name] = node.init.value - if isinstance(node.init, ast_internal_classes.Name_Node): + elif isinstance(node.init, ast_internal_classes.Name_Node): self.contexts[sdfg.name].constants[node.name] = self.contexts[sdfg.name].constants[node.init.name] + else: + tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names) + if node.init is not None: + text = tw.write_code(node.init) + self.contexts[sdfg.name].constants[node.name] = sym.pystr_to_symbolic(text) + datatype = self.get_dace_type(node.type) if node.name not in sdfg.symbols: sdfg.add_symbol(node.name, datatype) - if self.last_sdfg_states.get(cfg) is None: - bstate = cfg.add_state("SDFGbegin", is_start_state=True) - self.last_sdfg_states[cfg] = bstate + if 
self.last_sdfg_states.get(sdfg) is None: + bstate = sdfg.add_state("SDFGbegin", is_start_state=True) + self.last_sdfg_states[sdfg] = bstate if node.init is not None: - substate = cfg.add_state(f"Dummystate_{node.name}") - increment = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping).write_code(node.init) + substate = sdfg.add_state(f"Dummystate_{node.name}") + increment = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(node.init) entry = {node.name: increment} - cfg.add_edge(self.last_sdfg_states[cfg], substate, InterstateEdge(assignments=entry)) - self.last_sdfg_states[cfg] = substate + sdfg.add_edge(self.last_sdfg_states[sdfg], substate, InterstateEdge(assignments=entry)) + self.last_sdfg_states[sdfg] = substate - def symbolarray2sdfg(self, node: ast_internal_classes.Symbol_Array_Decl_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def symbolarray2sdfg(self, node: ast_internal_classes.Symbol_Array_Decl_Node, sdfg: SDFG): return NotImplementedError( "Symbol_Decl_Node not implemented. This should be done via a transformation that itemizes the constant array." ) - def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, sdfg: SDFG, - cfg: ControlFlowRegion): + def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, sdfg: SDFG): """ This function is responsible for translating Fortran subroutine declarations into a SDFG. 
:param node: The node to be translated @@ -386,12 +883,16 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, if node.execution_part is None: return + if len(node.execution_part.execution) == 0: + return + + print("TRANSLATE SUBROUTINE", node.name.name) # First get the list of read and written variables inputnodefinder = ast_transforms.FindInputs() inputnodefinder.visit(node) input_vars = inputnodefinder.nodes - outputnodefinder = ast_transforms.FindOutputs() + outputnodefinder = ast_transforms.FindOutputs(thourough=True) outputnodefinder.visit(node) output_vars = outputnodefinder.nodes write_names = list(dict.fromkeys([i.name for i in output_vars])) @@ -399,9 +900,13 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, # Collect the parameters and the function signature to comnpare and link parameters = node.args.copy() + my_name_sdfg = node.name.name + str(self.sdfgs_count) + new_sdfg = SDFG(my_name_sdfg) + self.sdfgs_count += 1 + self.actual_offsets_per_sdfg[new_sdfg] = {} + self.names_of_object_in_parent_sdfg[new_sdfg] = {} + substate = ast_utils.add_simple_state_to_sdfg(self, sdfg, "state" + my_name_sdfg) - new_sdfg = SDFG(node.name.name) - substate = ast_utils.add_simple_state_to_sdfg(self, cfg, "state" + node.name.name) variables_in_call = [] if self.last_call_expression.get(sdfg) is not None: variables_in_call = self.last_call_expression[sdfg] @@ -410,10 +915,13 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, if not ((len(variables_in_call) == len(parameters)) or (len(variables_in_call) == len(parameters) + 1 and not isinstance(node.result_type, ast_internal_classes.Void))): - for i in variables_in_call: - print("VAR CALL: ", i.name) - for j in parameters: - print("LOCAL TO UPDATE: ", j.name) + print("Subroutine", node.name.name) + print('Variables in call', len(variables_in_call)) + print('Parameters', len(parameters)) + # for i in variables_in_call: + # 
print("VAR CALL: ", i.name) + # for j in parameters: + # print("LOCAL TO UPDATE: ", j.name) raise ValueError("number of parameters does not match the function signature") # creating new arrays for nested sdfg @@ -427,15 +935,22 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, literals = [] literal_values = [] par2 = [] - + to_fix = [] symbol_arguments = [] # First we need to check if the parameters are literals or variables for arg_i, variable in enumerate(variables_in_call): if isinstance(variable, ast_internal_classes.Name_Node): varname = variable.name + elif isinstance(variable, ast_internal_classes.Actual_Arg_Spec_Node): + varname = variable.arg_name.name elif isinstance(variable, ast_internal_classes.Array_Subscript_Node): varname = variable.name.name + elif isinstance(variable, ast_internal_classes.Data_Ref_Node): + varname = ast_utils.get_name(variable) + elif isinstance(variable, ast_internal_classes.BinOp_Node): + varname = variable.rval.name + if isinstance(variable, ast_internal_classes.Literal) or varname == "LITERAL": literals.append(parameters[arg_i]) literal_values.append(variable) @@ -447,12 +962,14 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, par2.append(parameters[arg_i]) var2.append(variable) - #This handles the case where the function is called with literals + # This handles the case where the function is called with literals variables_in_call = var2 parameters = par2 assigns = [] + self.local_not_transient_because_assign[my_name_sdfg] = [] for lit, litval in zip(literals, literal_values): local_name = lit + self.local_not_transient_because_assign[my_name_sdfg].append(local_name.name) assigns.append( ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name=local_name.name), rval=litval, @@ -462,6 +979,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, # This handles the case where the function is called with symbols for 
parameter, symbol in symbol_arguments: if parameter.name != symbol.name: + self.local_not_transient_because_assign[my_name_sdfg].append(parameter.name) assigns.append( ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name=parameter.name), rval=ast_internal_classes.Name_Node(name=symbol.name), @@ -469,16 +987,22 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, line_number=node.line_number)) # This handles the case where the function is called with variables starting with the case that the variable is local to the calling SDFG + needs_replacement = {} + substate_sources = [] + substate_destinations = [] for variable_in_call in variables_in_call: all_arrays = self.get_arrays_in_context(sdfg) sdfg_name = self.name_mapping.get(sdfg).get(ast_utils.get_name(variable_in_call)) globalsdfg_name = self.name_mapping.get(self.globalsdfg).get(ast_utils.get_name(variable_in_call)) matched = False + view_ranges = {} for array_name, array in all_arrays.items(): + if array_name in [sdfg_name]: matched = True local_name = parameters[variables_in_call.index(variable_in_call)] + self.names_of_object_in_parent_sdfg[new_sdfg][local_name.name] = sdfg_name self.name_mapping[new_sdfg][local_name.name] = new_sdfg._find_new_name(local_name.name) self.all_array_names.append(self.name_mapping[new_sdfg][local_name.name]) if local_name.name in read_names: @@ -494,70 +1018,731 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, offsets = list(array.offset) mysize = 1 - if isinstance(variable_in_call, ast_internal_classes.Array_Subscript_Node): - changed_indices = 0 - for i in variable_in_call.indices: - if isinstance(i, ast_internal_classes.ParDecl_Node): - if i.type == "ALL": - shape.append(array.shape[indices]) - mysize = mysize * array.shape[indices] - index_list.append(None) + if isinstance(variable_in_call, ast_internal_classes.Data_Ref_Node): + done = False + bonus_step = False + depth = 0 + tmpvar = 
variable_in_call + local_name = parameters[variables_in_call.index(variable_in_call)] + top_structure_name = self.name_mapping[sdfg][ast_utils.get_name(tmpvar.parent_ref)] + top_structure = sdfg.arrays[top_structure_name] + current_parent_structure = top_structure + current_parent_structure_name = top_structure_name + name_chain = [top_structure_name] + while not done: + if isinstance(tmpvar.part_ref, ast_internal_classes.Data_Ref_Node): + + tmpvar = tmpvar.part_ref + depth += 1 + current_member_name = ast_utils.get_name(tmpvar.parent_ref) + if isinstance(tmpvar.parent_ref, ast_internal_classes.Array_Subscript_Node): + print("Array Subscript Node") + if bonus_step == True: + print("Bonus Step") + current_member = current_parent_structure.members[current_member_name] + concatenated_name = "_".join(name_chain) + local_shape = current_member.shape + new_shape = [] + local_indices = 0 + local_strides = list(current_member.strides) + local_offsets = list(current_member.offset) + local_index_list = [] + local_size = 1 + if isinstance(tmpvar.parent_ref, ast_internal_classes.Array_Subscript_Node): + changed_indices = 0 + for i in tmpvar.parent_ref.indices: + if isinstance(i, ast_internal_classes.ParDecl_Node): + if i.type == "ALL": + new_shape.append(local_shape[local_indices]) + local_size = local_size * local_shape[local_indices] + local_index_list.append(None) + else: + raise NotImplementedError("Index in ParDecl should be ALL") + else: + + text = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code( + i) + local_index_list.append(sym.pystr_to_symbolic(text)) + local_strides.pop(local_indices - changed_indices) + local_offsets.pop(local_indices - changed_indices) + changed_indices += 1 + local_indices = local_indices + 1 + local_all_indices = [None] * ( + len(local_shape) - len(local_index_list)) + local_index_list + if self.normalize_offsets: + 
subset = subs.Range([(i, i, 1) if i is not None else (0, s - 1, 1) + for i, s in zip(local_all_indices, local_shape)]) + else: + subset = subs.Range([(i, i, 1) if i is not None else (1, s, 1) + for i, s in zip(local_all_indices, local_shape)]) + smallsubset = subs.Range([(0, s - 1, 1) for s in new_shape]) + bonus_step = False + if isinstance(current_member, dat.ContainerArray): + if len(new_shape) == 0: + stype = current_member.stype + view_to_container = dat.View.view(current_member) + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)] = view_to_container + while isinstance(stype, dat.ContainerArray): + stype = stype.stype + bonus_step = True + # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.shape,current_member.dtype) + view_to_member = dat.View.view(stype) + sdfg.arrays[concatenated_name + "_" + current_member_name + "_m_" + str( + self.struct_view_count)] = view_to_member + # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.stype.dtype) else: - raise NotImplementedError("Index in ParDecl should be ALL") + view_to_member = dat.View.view(current_member) + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)] = view_to_member + # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.shape,current_member.dtype,strides=current_member.strides,offset=current_member.offset) + + already_there_1 = False + already_there_2 = False + already_there_22 = False + already_there_3 = False + already_there_33 = False + already_there_4 = False + re = None + wv = None + wr = None + rv = None + wv2 = None + wr2 = None + if current_parent_structure_name == top_structure_name: + top_level = True + else: + top_level = False + if local_name.name in read_names: + + re, already_there_1 = find_access_in_sources(substate, substate_sources, + 
current_parent_structure_name) + wv, already_there_2 = find_access_in_destinations(substate, substate_destinations, + concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)) + + if not bonus_step: + mem = Memlet.simple(current_parent_structure_name + "." + current_member_name, + subset) + substate.add_edge(re, None, wv, "views", dpcp(mem)) + else: + firstmem = Memlet.simple( + current_parent_structure_name + "." + current_member_name, + subs.Range.from_array(sdfg.arrays[ + concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)])) + wv2, already_there_22 = find_access_in_destinations(substate, + substate_destinations, + concatenated_name + "_" + current_member_name + "_m_" + str( + self.struct_view_count)) + mem = Memlet.simple(concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count), subset) + substate.add_edge(re, None, wv, "views", dpcp(firstmem)) + substate.add_edge(wv, None, wv2, "views", dpcp(mem)) + + if local_name.name in write_names: + + wr, already_there_3 = find_access_in_destinations(substate, substate_destinations, + current_parent_structure_name) + rv, already_there_4 = find_access_in_sources(substate, substate_sources, + concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)) + + if not bonus_step: + mem2 = Memlet.simple(current_parent_structure_name + "." + current_member_name, + subset) + substate.add_edge(rv, "views", wr, None, dpcp(mem2)) + else: + firstmem = Memlet.simple( + current_parent_structure_name + "." 
+ current_member_name, + subs.Range.from_array(sdfg.arrays[ + concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)])) + wr2, already_there_33 = find_access_in_sources(substate, substate_sources, + concatenated_name + "_" + current_member_name + "_m_" + str( + self.struct_view_count)) + mem2 = Memlet.simple(concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count), subset) + substate.add_edge(wr2, "views", rv, None, dpcp(mem2)) + substate.add_edge(rv, "views", wr, None, dpcp(firstmem)) + + if not already_there_1: + if re is not None: + if not top_level: + substate_sources.append(re) + else: + substate_destinations.append(re) + + if not already_there_2: + if wv is not None: + substate_destinations.append(wv) + + if not already_there_3: + if wr is not None: + if not top_level: + substate_destinations.append(wr) + else: + substate_sources.append(wr) + if not already_there_4: + if rv is not None: + substate_sources.append(rv) + + if bonus_step == True: + if not already_there_22: + if wv2 is not None: + substate_destinations.append(wv2) + if not already_there_33: + if wr2 is not None: + substate_sources.append(wr2) + + if not bonus_step: + current_parent_structure_name = concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count) + else: + current_parent_structure_name = concatenated_name + "_" + current_member_name + "_m_" + str( + self.struct_view_count) + current_parent_structure = current_parent_structure.members[current_member_name] + self.struct_view_count += 1 + name_chain.append(current_member_name) else: - text = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(i) - index_list.append(sym.pystr_to_symbolic(text)) - strides.pop(indices - changed_indices) - offsets.pop(indices - changed_indices) - changed_indices += 1 - indices = indices + 1 - - if isinstance(variable_in_call, ast_internal_classes.Name_Node): - shape = list(array.shape) - # Functionally, this identifies 
the case where the array is in fact a scalar - if shape == () or shape == (1, ) or shape == [] or shape == [1]: - new_sdfg.add_scalar(self.name_mapping[new_sdfg][local_name.name], array.dtype, array.storage) + done = True + tmpvar = tmpvar.part_ref + concatenated_name = "_".join(name_chain) + array_name = ast_utils.get_name(tmpvar) + member_name = ast_utils.get_name(tmpvar) + if bonus_step == True: + print("Bonus Step") + last_view_name = concatenated_name + "_m_" + str(self.struct_view_count - 1) + else: + if depth > 0: + last_view_name = concatenated_name + "_" + str(self.struct_view_count - 1) + else: + last_view_name = concatenated_name + if isinstance(current_parent_structure, dat.ContainerArray): + stype = current_parent_structure.stype + while isinstance(stype, dat.ContainerArray): + stype = stype.stype + + array = stype.members[ast_utils.get_name(tmpvar)] + + else: + array = current_parent_structure.members[ast_utils.get_name(tmpvar)] # FLAG + + if isinstance(array, dat.ContainerArray): + view_to_member = dat.View.view(array) + sdfg.arrays[concatenated_name + "_" + array_name + "_" + str( + self.struct_view_count)] = view_to_member + + else: + view_to_member = dat.View.view(array) + sdfg.arrays[concatenated_name + "_" + array_name + "_" + str( + self.struct_view_count)] = view_to_member + + # sdfg.add_view(concatenated_name+"_"+array_name+"_"+str(self.struct_view_count),array.shape,array.dtype,strides=array.strides,offset=array.offset) + last_view_name_read = None + re = None + wv = None + wr = None + rv = None + already_there_1 = False + already_there_2 = False + already_there_3 = False + already_there_4 = False + if current_parent_structure_name == top_structure_name: + top_level = True + else: + top_level = False + if local_name.name in read_names: + for i in substate_destinations: + if i.data == last_view_name: + re = i + already_there_1 = True + break + if not already_there_1: + re = substate.add_read(last_view_name) + + for i in substate_sources: + if 
i.data == concatenated_name + "_" + array_name + "_" + str( + self.struct_view_count): + wv = i + already_there_2 = True + break + if not already_there_2: + wv = substate.add_write( + concatenated_name + "_" + array_name + "_" + str(self.struct_view_count)) + + mem = Memlet.from_array(last_view_name + "." + member_name, array) + substate.add_edge(re, None, wv, "views", dpcp(mem)) + last_view_name_read = concatenated_name + "_" + array_name + "_" + str( + self.struct_view_count) + last_view_name_write = None + if local_name.name in write_names: + for i in substate_sources: + if i.data == last_view_name: + wr = i + already_there_3 = True + break + if not already_there_3: + wr = substate.add_write(last_view_name) + for i in substate_destinations: + if i.data == concatenated_name + "_" + array_name + "_" + str( + self.struct_view_count): + rv = i + already_there_4 = True + break + if not already_there_4: + rv = substate.add_read( + concatenated_name + "_" + array_name + "_" + str(self.struct_view_count)) + + mem2 = Memlet.from_array(last_view_name + "." 
+ member_name, array) + substate.add_edge(rv, "views", wr, None, dpcp(mem2)) + last_view_name_write = concatenated_name + "_" + array_name + "_" + str( + self.struct_view_count) + if not already_there_1: + if re is not None: + if not top_level: + substate_sources.append(re) + else: + substate_destinations.append(re) + if not already_there_2: + if wv is not None: + substate_destinations.append(wv) + if not already_there_3: + if wr is not None: + if not top_level: + substate_destinations.append(wr) + else: + substate_sources.append(wr) + if not already_there_4: + if rv is not None: + substate_sources.append(rv) + mapped_name_overwrite = concatenated_name + "_" + array_name + self.views = self.views + 1 + views.append([mapped_name_overwrite, wv, rv, variables_in_call.index(variable_in_call)]) + + if last_view_name_write is not None and last_view_name_read is not None: + if last_view_name_read != last_view_name_write: + raise NotImplementedError("Read and write views should be the same") + else: + last_view_name = last_view_name_read + if last_view_name_read is not None and last_view_name_write is None: + last_view_name = last_view_name_read + if last_view_name_write is not None and last_view_name_read is None: + last_view_name = last_view_name_write + mapped_name_overwrite = concatenated_name + "_" + array_name + strides = list(array.strides) + offsets = list(array.offset) + self.struct_view_count += 1 + + if isinstance(array, dat.ContainerArray) and isinstance(tmpvar, + ast_internal_classes.Array_Subscript_Node): + current_member_name = ast_utils.get_name(tmpvar) + current_member = current_parent_structure.members[current_member_name] + concatenated_name = "_".join(name_chain) + local_shape = current_member.shape + new_shape = [] + local_indices = 0 + local_strides = list(current_member.strides) + local_offsets = list(current_member.offset) + local_index_list = [] + local_size = 1 + changed_indices = 0 + for i in tmpvar.indices: + if isinstance(i, 
ast_internal_classes.ParDecl_Node): + if i.type == "ALL": + new_shape.append(local_shape[local_indices]) + local_size = local_size * local_shape[local_indices] + local_index_list.append(None) + else: + raise NotImplementedError("Index in ParDecl should be ALL") + else: + text = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code( + i) + local_index_list.append(sym.pystr_to_symbolic(text)) + local_strides.pop(local_indices - changed_indices) + local_offsets.pop(local_indices - changed_indices) + changed_indices += 1 + local_indices = local_indices + 1 + local_all_indices = [None] * ( + len(local_shape) - len(local_index_list)) + local_index_list + if self.normalize_offsets: + subset = subs.Range([(i, i, 1) if i is not None else (0, s - 1, 1) + for i, s in zip(local_all_indices, local_shape)]) + else: + subset = subs.Range([(i, i, 1) if i is not None else (1, s, 1) + for i, s in zip(local_all_indices, local_shape)]) + smallsubset = subs.Range([(0, s - 1, 1) for s in new_shape]) + if isinstance(current_member, dat.ContainerArray): + if len(new_shape) == 0: + stype = current_member.stype + while isinstance(stype, dat.ContainerArray): + stype = stype.stype + bonus_step = True + # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.shape,current_member.dtype) + view_to_member = dat.View.view(stype) + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)] = view_to_member + # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.stype.dtype) + else: + view_to_member = dat.View.view(current_member) + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)] = view_to_member + + # 
sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.shape,current_member.dtype,strides=current_member.strides,offset=current_member.offset) + already_there_1 = False + already_there_2 = False + already_there_3 = False + already_there_4 = False + re = None + wv = None + wr = None + rv = None + if current_parent_structure_name == top_structure_name: + top_level = True + else: + top_level = False + if local_name.name in read_names: + for i in substate_destinations: + if i.data == last_view_name: + re = i + already_there_1 = True + break + if not already_there_1: + re = substate.add_read(last_view_name) + + for i in substate_sources: + if i.data == concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count): + wv = i + already_there_2 = True + break + if not already_there_2: + wv = substate.add_write( + concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)) + + if isinstance(current_member, dat.ContainerArray): + mem = Memlet.simple(last_view_name, subset) + else: + mem = Memlet.simple( + current_parent_structure_name + "." + current_member_name, subset) + substate.add_edge(re, None, wv, "views", dpcp(mem)) + + if local_name.name in write_names: + for i in substate_sources: + if i.data == last_view_name: + wr = i + already_there_3 = True + break + if not already_there_3: + wr = substate.add_write(last_view_name) + + for i in substate_destinations: + if i.data == concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count): + rv = i + already_there_4 = True + break + if not already_there_4: + rv = substate.add_read( + concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count)) + + if isinstance(current_member, dat.ContainerArray): + mem2 = Memlet.simple(last_view_name, subset) + else: + mem2 = Memlet.simple( + current_parent_structure_name + "." 
+ current_member_name, subset) + + substate.add_edge(rv, "views", wr, None, dpcp(mem2)) + if not already_there_1: + if re is not None: + if not top_level: + substate_sources.append(re) + else: + substate_destinations.append(re) + if not already_there_2: + if wv is not None: + substate_destinations.append(wv) + if not already_there_3: + if wr is not None: + if not top_level: + substate_destinations.append(wr) + else: + substate_sources.append(wr) + if not already_there_4: + if rv is not None: + substate_sources.append(rv) + last_view_name = concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count) + if not isinstance(current_member, dat.ContainerArray): + mapped_name_overwrite = concatenated_name + "_" + current_member_name + needs_replacement[mapped_name_overwrite] = last_view_name + else: + mapped_name_overwrite = concatenated_name + "_" + current_member_name + needs_replacement[mapped_name_overwrite] = last_view_name + mapped_name_overwrite = concatenated_name + "_" + current_member_name + "_" + str( + self.struct_view_count) + self.views = self.views + 1 + views.append( + [mapped_name_overwrite, wv, rv, variables_in_call.index(variable_in_call)]) + + strides = list(view_to_member.strides) + offsets = list(view_to_member.offset) + self.struct_view_count += 1 + + if isinstance(tmpvar, ast_internal_classes.Array_Subscript_Node): + + changed_indices = 0 + for i in tmpvar.indices: + if isinstance(i, ast_internal_classes.ParDecl_Node): + if i.type == "ALL": + shape.append(array.shape[indices]) + mysize = mysize * array.shape[indices] + index_list.append(None) + else: + start = i.range[0] + stop = i.range[1] + text_start = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(start) + text_stop = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + placeholders=self.placeholders, + 
placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(stop) + shape.append("( "+ text_stop + ") - ( "+ text_start + ") ") + mysize=mysize*sym.pystr_to_symbolic("( "+ text_stop + ") - ( "+ text_start + ") ") + index_list.append(None) + # raise NotImplementedError("Index in ParDecl should be ALL") + else: + text = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(i) + index_list.append(sym.pystr_to_symbolic(text)) + strides.pop(indices - changed_indices) + offsets.pop(indices - changed_indices) + changed_indices += 1 + indices = indices + 1 + + + + elif isinstance(tmpvar, ast_internal_classes.Name_Node): + shape = list(array.shape) + else: + raise NotImplementedError("Unknown part_ref type") + + if shape == () or shape == (1,) or shape == [] or shape == [1]: + # FIXME 6.03.2024 + # print(array,array.__class__.__name__) + if isinstance(array, dat.ContainerArray): + if isinstance(array.stype, dat.ContainerArray): + if isinstance(array.stype.stype, dat.Structure): + element_type = array.stype.stype + else: + element_type = array.stype.stype.dtype + + elif isinstance(array.stype, dat.Structure): + element_type = array.stype + else: + element_type = array.stype.dtype + # print(element_type,element_type.__class__.__name__) + # print(array.base_type,array.base_type.__class__.__name__) + elif isinstance(array, dat.Structure): + element_type = array + elif isinstance(array, pointer): + if hasattr(array, "stype"): + if hasattr(array.stype, "free_symbols"): + element_type = array.stype + # print("get stype") + elif isinstance(array, dat.Array): + element_type = array.dtype + elif isinstance(array, dat.Scalar): + element_type = array.dtype + + else: + if hasattr(array, "dtype"): + if hasattr(array.dtype, "free_symbols"): + element_type = array.dtype + # print("get dtype") + + if isinstance(element_type, pointer): 
+ # print("pointer-ized") + found = False + if hasattr(element_type, "dtype"): + if hasattr(element_type.dtype, "free_symbols"): + element_type = element_type.dtype + found = True + # print("get dtype") + if hasattr(element_type, "stype"): + if hasattr(element_type.stype, "free_symbols"): + element_type = element_type.stype + found = True + # print("get stype") + if hasattr(element_type, "base_type"): + if hasattr(element_type.base_type, "free_symbols"): + element_type = element_type.base_type + found = True + # print("get base_type") + # if not found: + # print(dir(element_type)) + # print("array info: "+str(array),array.__class__.__name__) + # print(element_type,element_type.__class__.__name__) + if hasattr(element_type, "name") and element_type.name in self.registered_types: + datatype = self.get_dace_type(element_type.name) + datatype_to_add = copy.deepcopy(datatype) + datatype_to_add.transient = False + # print(datatype_to_add,datatype_to_add.__class__.__name__) + new_sdfg.add_datadesc(self.name_mapping[new_sdfg][local_name.name], datatype_to_add) + + if self.struct_views.get(new_sdfg) is None: + self.struct_views[new_sdfg] = {} + + add_views_recursive(new_sdfg, local_name.name, datatype_to_add, + self.struct_views[new_sdfg], self.name_mapping[new_sdfg], + self.registered_types, [], self.actual_offsets_per_sdfg[new_sdfg], + self.names_of_object_in_parent_sdfg[new_sdfg], + self.actual_offsets_per_sdfg[sdfg]) + + else: + new_sdfg.add_scalar(self.name_mapping[new_sdfg][local_name.name], array.dtype, + array.storage) + else: + element_type = array.dtype.base_type + if element_type in self.registered_types: + raise NotImplementedError("Nested derived types not implemented") + datatype_to_add = copy.deepcopy(element_type) + datatype_to_add.transient = False + new_sdfg.add_datadesc(self.name_mapping[new_sdfg][local_name.name], datatype_to_add) + # arr_dtype = datatype[sizes] + # arr_dtype.offset = [offset_value for _ in sizes] + # 
sdfg.add_datadesc(self.name_mapping[sdfg][node.name], arr_dtype) + else: + + new_sdfg.add_array(self.name_mapping[new_sdfg][local_name.name], + shape, + array.dtype, + array.storage, + strides=strides, + offset=offsets) else: - # This is the case where the array is not a scalar and we need to create a view - if not isinstance(variable_in_call, ast_internal_classes.Name_Node): - offsets_zero = [] - for index in offsets: - offsets_zero.append(0) - viewname, view = sdfg.add_view(array_name + "_view_" + str(self.views), - shape, - array.dtype, - storage=array.storage, - strides=strides, - offset=offsets_zero) - from dace import subsets - - all_indices = [None] * (len(array.shape) - len(index_list)) + index_list - subset = subsets.Range([(i, i, 1) if i is not None else (1, s, 1) - for i, s in zip(all_indices, array.shape)]) - smallsubset = subsets.Range([(0, s - 1, 1) for s in shape]) - - memlet = Memlet(f'{array_name}[{subset}]->[{smallsubset}]') - memlet2 = Memlet(f'{viewname}[{smallsubset}]->[{subset}]') - wv = None - rv = None - if local_name.name in read_names: - r = substate.add_read(array_name) - wv = substate.add_write(viewname) - substate.add_edge(r, None, wv, 'views', dpcp(memlet)) - if local_name.name in write_names: - rv = substate.add_read(viewname) - w = substate.add_write(array_name) - substate.add_edge(rv, 'views2', w, None, dpcp(memlet2)) - - self.views = self.views + 1 - views.append([array_name, wv, rv, variables_in_call.index(variable_in_call)]) - - new_sdfg.add_array(self.name_mapping[new_sdfg][local_name.name], - shape, - array.dtype, - array.storage, - strides=strides, - offset=offsets) + + if isinstance(variable_in_call, ast_internal_classes.Array_Subscript_Node): + changed_indices = 0 + for i in variable_in_call.indices: + if isinstance(i, ast_internal_classes.ParDecl_Node): + if i.type == "ALL": + shape.append(array.shape[indices]) + mysize = mysize * array.shape[indices] + index_list.append(None) + else: + start = i.range[0] + stop = 
i.range[1] + text_start = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code( + start) + text_stop = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code( + stop) + symb_size = sym.pystr_to_symbolic(text_stop + " - ( " + text_start + " )") + shape.append(symb_size) + mysize = mysize * symb_size + index_list.append( + [sym.pystr_to_symbolic(text_start), sym.pystr_to_symbolic(text_stop)]) + # raise NotImplementedError("Index in ParDecl should be ALL") + else: + text = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(i) + index_list.append([sym.pystr_to_symbolic(text), sym.pystr_to_symbolic(text)]) + strides.pop(indices - changed_indices) + offsets.pop(indices - changed_indices) + changed_indices += 1 + indices = indices + 1 + + if isinstance(variable_in_call, ast_internal_classes.Name_Node): + shape = list(array.shape) + + # print("Data_Ref_Node") + # Functionally, this identifies the case where the array is in fact a scalar + if shape == () or shape == (1,) or shape == [] or shape == [1]: + if hasattr(array, "name") and array.name in self.registered_types: + datatype = self.get_dace_type(array.name) + datatype_to_add = copy.deepcopy(array) + datatype_to_add.transient = False + new_sdfg.add_datadesc(self.name_mapping[new_sdfg][local_name.name], datatype_to_add) + + if self.struct_views.get(new_sdfg) is None: + self.struct_views[new_sdfg] = {} + add_views_recursive(new_sdfg, local_name.name, datatype_to_add, + self.struct_views[new_sdfg], self.name_mapping[new_sdfg], + self.registered_types, [], self.actual_offsets_per_sdfg[new_sdfg], + self.names_of_object_in_parent_sdfg[new_sdfg], + 
self.actual_offsets_per_sdfg[sdfg]) + + else: + new_sdfg.add_scalar(self.name_mapping[new_sdfg][local_name.name], array.dtype, + array.storage) + else: + # This is the case where the array is not a scalar and we need to create a view + if not (shape == () or shape == (1,) or shape == [] or shape == [1]): + offsets_zero = [] + for index in offsets: + offsets_zero.append(0) + viewname, view = sdfg.add_view(array_name + "_view_" + str(self.views), + shape, + array.dtype, + storage=array.storage, + strides=strides, + offset=offsets_zero) + from dace import subsets + + all_indices = [None] * (len(array.shape) - len(index_list)) + index_list + if self.normalize_offsets: + subset = subsets.Range([(i[0] - 1, i[1] - 1, 1) if i is not None else (0, s - 1, 1) + for i, s in zip(all_indices, array.shape)]) + else: + subset = subsets.Range([(i[0], i[1], 1) if i is not None else (1, s, 1) + for i, s in zip(all_indices, array.shape)]) + smallsubset = subsets.Range([(0, s - 1, 1) for s in shape]) + + # memlet = Memlet(f'{array_name}[{subset}]->{smallsubset}') + # memlet2 = Memlet(f'{viewname}[{smallsubset}]->{subset}') + memlet = Memlet(f'{array_name}[{subset}]') + memlet2 = Memlet(f'{array_name}[{subset}]') + wv = None + rv = None + if local_name.name in read_names: + r = substate.add_read(array_name) + wv = substate.add_write(viewname) + substate.add_edge(r, None, wv, 'views', dpcp(memlet)) + if local_name.name in write_names: + rv = substate.add_read(viewname) + w = substate.add_write(array_name) + substate.add_edge(rv, 'views', w, None, dpcp(memlet2)) + + self.views = self.views + 1 + views.append([array_name, wv, rv, variables_in_call.index(variable_in_call)]) + + new_sdfg.add_array(self.name_mapping[new_sdfg][local_name.name], + shape, + array.dtype, + array.storage, + strides=strides, + offset=offsets) + if not matched: # This handles the case where the function is called with global variables for array_name, array in all_arrays.items(): @@ -576,7 +1761,7 @@ def 
subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, shape = array.shape[indices:] - if shape == () or shape == (1, ): + if shape == () or shape == (1,): new_sdfg.add_scalar(self.name_mapping[new_sdfg][local_name.name], array.dtype, array.storage) else: @@ -612,10 +1797,21 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, outs_in_new_sdfg.append(self.name_mapping[new_sdfg][i]) new_sdfg.add_scalar(self.name_mapping[new_sdfg][i], dtypes.int32, transient=False) addedmemlets = [] + globalmemlets = [] + names_list = [] + if node.specification_part is not None: + if node.specification_part.specifications is not None: + namefinder = ast_transforms.FindDefinedNames() + for i in node.specification_part.specifications: + namefinder.visit(i) + names_list = namefinder.names # This handles the case where the function is called with read variables found in a module + cached_names=[a[0] for a in self.module_vars] for i in not_found_read_names: - if i in [a[0] for a in self.module_vars]: + if i in names_list: + continue + if i in cached_names: if self.name_mapping[sdfg].get(i) is not None: self.name_mapping[new_sdfg][i] = new_sdfg._find_new_name(i) addedmemlets.append(i) @@ -627,7 +1823,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, array_in_global = sdfg.arrays[self.name_mapping[sdfg][i]] if isinstance(array_in_global, Scalar): new_sdfg.add_scalar(self.name_mapping[new_sdfg][i], array_in_global.dtype, transient=False) - elif array_in_global.type == "Array": + elif (hasattr(array_in_global, 'type') and array_in_global.type == "Array") or isinstance( + array_in_global, dat.Array): new_sdfg.add_array(self.name_mapping[new_sdfg][i], array_in_global.shape, array_in_global.dtype, @@ -647,7 +1844,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, array_in_global = self.globalsdfg.arrays[self.name_mapping[self.globalsdfg][i]] if isinstance(array_in_global, 
Scalar): new_sdfg.add_scalar(self.name_mapping[new_sdfg][i], array_in_global.dtype, transient=False) - elif array_in_global.type == "Array": + elif (hasattr(array_in_global, 'type') and array_in_global.type == "Array") or isinstance( + array_in_global, dat.Array): new_sdfg.add_array(self.name_mapping[new_sdfg][i], array_in_global.shape, array_in_global.dtype, @@ -659,6 +1857,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, for i in not_found_write_names: if i in not_found_read_names: continue + if i in names_list: + continue if i in [a[0] for a in self.module_vars]: if self.name_mapping[sdfg].get(i) is not None: self.name_mapping[new_sdfg][i] = new_sdfg._find_new_name(i) @@ -672,7 +1872,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, array = sdfg.arrays[self.name_mapping[sdfg][i]] if isinstance(array_in_global, Scalar): new_sdfg.add_scalar(self.name_mapping[new_sdfg][i], array_in_global.dtype, transient=False) - elif array_in_global.type == "Array": + elif (hasattr(array_in_global, 'type') and array_in_global.type == "Array") or isinstance( + array_in_global, dat.Array): new_sdfg.add_array(self.name_mapping[new_sdfg][i], array_in_global.shape, array_in_global.dtype, @@ -692,7 +1893,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, array = self.globalsdfg.arrays[self.name_mapping[self.globalsdfg][i]] if isinstance(array_in_global, Scalar): new_sdfg.add_scalar(self.name_mapping[new_sdfg][i], array_in_global.dtype, transient=False) - elif array_in_global.type == "Array": + elif (hasattr(array_in_global, 'type') and array_in_global.type == "Array") or isinstance( + array_in_global, dat.Array): new_sdfg.add_array(self.name_mapping[new_sdfg][i], array_in_global.shape, array_in_global.dtype, @@ -700,14 +1902,31 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, transient=False, strides=array_in_global.strides, 
offset=array_in_global.offset) + all_symbols = new_sdfg.free_symbols + missing_symbols = [s for s in all_symbols if s not in sym_dict] + for i in missing_symbols: + if i in sdfg.arrays: + sym_dict[i] = i + else: + print("Symbol not found in sdfg arrays: ", i) + if self.multiple_sdfgs == False: + # print("Adding nested sdfg", new_sdfg.name, "to", sdfg.name) + # print(sym_dict) + internal_sdfg = substate.add_nested_sdfg(new_sdfg, + sdfg, + ins_in_new_sdfg, + outs_in_new_sdfg, + symbol_mapping=sym_dict) + else: + internal_sdfg = substate.add_nested_sdfg(None, + sdfg, + ins_in_new_sdfg, + outs_in_new_sdfg, + symbol_mapping=sym_dict, + name="External_nested_" + new_sdfg.name) + # if self.multiple_sdfgs==False: + # Now adding memlets - internal_sdfg = substate.add_nested_sdfg(new_sdfg, - sdfg, - ins_in_new_sdfg, - outs_in_new_sdfg, - symbol_mapping=sym_dict) - - # Now adding memlets for i in self.libstates: memlet = "0" if i in write_names: @@ -723,6 +1942,9 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, if self.name_mapping.get(sdfg).get(ast_utils.get_name(i)) is not None: var = sdfg.arrays.get(self.name_mapping[sdfg][ast_utils.get_name(i)]) mapped_name = self.name_mapping[sdfg][ast_utils.get_name(i)] + if needs_replacement.get(mapped_name) is not None: + mapped_name = needs_replacement[mapped_name] + var = sdfg.arrays[mapped_name] # TODO: FIx symbols in function calls elif ast_utils.get_name(i) in sdfg.symbols: var = ast_utils.get_name(i) @@ -738,7 +1960,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, elif (len(var.shape) == 1 and var.shape[0] == 1): memlet = "0" else: - memlet = ast_utils.generate_memlet(i, sdfg, self) + memlet = ast_utils.generate_memlet(i, sdfg, self, self.normalize_offsets) found = False for elem in views: @@ -746,17 +1968,19 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, found = True if local_name.name in write_names: - memlet = 
subsets.Range([(0, s - 1, 1) for s in sdfg.arrays[elem[2].label].shape]) + memlet = subs.Range([(0, s - 1, 1) for s in sdfg.arrays[elem[2].label].shape]) substate.add_memlet_path(internal_sdfg, elem[2], src_conn=self.name_mapping[new_sdfg][local_name.name], memlet=Memlet(expr=elem[2].label, subset=memlet)) if local_name.name in read_names: - memlet = subsets.Range([(0, s - 1, 1) for s in sdfg.arrays[elem[1].label].shape]) + memlet = subs.Range([(0, s - 1, 1) for s in sdfg.arrays[elem[1].label].shape]) substate.add_memlet_path(elem[1], internal_sdfg, dst_conn=self.name_mapping[new_sdfg][local_name.name], memlet=Memlet(expr=elem[1].label, subset=memlet)) + if found: + break if not found: if local_name.name in write_names: @@ -767,8 +1991,9 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, self.name_mapping[new_sdfg][local_name.name], memlet) for i in addedmemlets: - - memlet = ast_utils.generate_memlet(ast_internal_classes.Name_Node(name=i), sdfg, self) + local_name = ast_internal_classes.Name_Node(name=i) + memlet = ast_utils.generate_memlet(ast_internal_classes.Name_Node(name=i), sdfg, self, + self.normalize_offsets) if local_name.name in write_names: ast_utils.add_memlet_write(substate, self.name_mapping[sdfg][i], internal_sdfg, self.name_mapping[new_sdfg][i], memlet) @@ -776,36 +2001,91 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, ast_utils.add_memlet_read(substate, self.name_mapping[sdfg][i], internal_sdfg, self.name_mapping[new_sdfg][i], memlet) for i in globalmemlets: + local_name = ast_internal_classes.Name_Node(name=i) + found = False + parent_sdfg = sdfg + nested_sdfg = new_sdfg + first = True + while not found and parent_sdfg is not None: + if self.name_mapping.get(parent_sdfg).get(i) is not None: + found = True + else: + self.name_mapping[parent_sdfg][i] = parent_sdfg._find_new_name(i) + self.all_array_names.append(self.name_mapping[parent_sdfg][i]) + array_in_global = 
self.globalsdfg.arrays[self.name_mapping[self.globalsdfg][i]] + if isinstance(array_in_global, Scalar): + parent_sdfg.add_scalar(self.name_mapping[parent_sdfg][i], array_in_global.dtype, + transient=False) + elif (hasattr(array_in_global, 'type') and array_in_global.type == "Array") or isinstance( + array_in_global, dat.Array): + parent_sdfg.add_array(self.name_mapping[parent_sdfg][i], + array_in_global.shape, + array_in_global.dtype, + array_in_global.storage, + transient=False, + strides=array_in_global.strides, + offset=array_in_global.offset) + + if first: + first = False + else: + if local_name.name in write_names: + nested_sdfg.parent_nsdfg_node.add_out_connector(self.name_mapping[parent_sdfg][i], force=True) + if local_name.name in read_names: + nested_sdfg.parent_nsdfg_node.add_in_connector(self.name_mapping[parent_sdfg][i], force=True) - memlet = ast_utils.generate_memlet(ast_internal_classes.Name_Node(name=i), sdfg, self) - if local_name.name in write_names: - ast_utils.add_memlet_write(substate, self.name_mapping[self.globalsdfg][i], internal_sdfg, - self.name_mapping[new_sdfg][i], memlet) - if local_name.name in read_names: - ast_utils.add_memlet_read(substate, self.name_mapping[self.globalsdfg][i], internal_sdfg, - self.name_mapping[new_sdfg][i], memlet) + memlet = ast_utils.generate_memlet(ast_internal_classes.Name_Node(name=i), parent_sdfg, self, + self.normalize_offsets) + if local_name.name in write_names: + ast_utils.add_memlet_write(nested_sdfg.parent, self.name_mapping[parent_sdfg][i], + nested_sdfg.parent_nsdfg_node, + self.name_mapping[nested_sdfg][i], memlet) + if local_name.name in read_names: + ast_utils.add_memlet_read(nested_sdfg.parent, self.name_mapping[parent_sdfg][i], + nested_sdfg.parent_nsdfg_node, + self.name_mapping[nested_sdfg][i], memlet) + if not found: + nested_sdfg = parent_sdfg + parent_sdfg = parent_sdfg.parent_sdfg + + if self.multiple_sdfgs == False: + if node.execution_part is not None: + if node.specification_part is 
not None and node.specification_part.uses is not None: + for j in node.specification_part.uses: + for k in j.list: + if self.contexts.get(new_sdfg.name) is None: + self.contexts[new_sdfg.name] = ast_utils.Context(name=new_sdfg.name) + if self.contexts[new_sdfg.name].constants.get( + ast_utils.get_name(k)) is None and self.contexts[ + self.globalsdfg.name].constants.get( + ast_utils.get_name(k)) is not None: + self.contexts[new_sdfg.name].constants[ast_utils.get_name(k)] = self.contexts[ + self.globalsdfg.name].constants[ast_utils.get_name(k)] - #Finally, now that the nested sdfg is built and the memlets are added, we can parse the internal of the subroutine and add it to the SDFG. + pass - if node.execution_part is not None: - for j in node.specification_part.uses: - for k in j.list: - if self.contexts.get(new_sdfg.name) is None: - self.contexts[new_sdfg.name] = ast_utils.Context(name=new_sdfg.name) - if self.contexts[new_sdfg.name].constants.get( - ast_utils.get_name(k)) is None and self.contexts[self.globalsdfg.name].constants.get( - ast_utils.get_name(k)) is not None: - self.contexts[new_sdfg.name].constants[ast_utils.get_name(k)] = self.contexts[ - self.globalsdfg.name].constants[ast_utils.get_name(k)] + old_mode = self.transient_mode + # print("For ",sdfg_name," old mode is ",old_mode) + self.transient_mode = True + for j in node.specification_part.symbols: + if isinstance(j, ast_internal_classes.Symbol_Decl_Node): + self.symbol2sdfg(j, new_sdfg) + else: + raise NotImplementedError("Symbol not implemented") + + for j in node.specification_part.specifications: + self.declstmt2sdfg(j, new_sdfg) + self.transient_mode = old_mode + + for i in assigns: + self.translate(i, new_sdfg) + self.translate(node.execution_part, new_sdfg) - pass - for j in node.specification_part.specifications: - self.declstmt2sdfg(j, new_sdfg, new_sdfg) - for i in assigns: - self.translate(i, new_sdfg, new_sdfg) - self.translate(node.execution_part, new_sdfg, new_sdfg) + if 
self.multiple_sdfgs == True: + internal_sdfg.path = self.sdfg_path + new_sdfg.name + ".sdfg" + # new_sdfg.save(path.join(self.sdfg_path, new_sdfg.name + ".sdfg")) - def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): """ This parses binary operations to tasklets in a new state or creates a function call with a nested SDFG if the operation is a function @@ -819,13 +2099,14 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: Con if len(calls.nodes) == 1: augmented_call = calls.nodes[0] from dace.frontend.fortran.intrinsics import FortranIntrinsics - if augmented_call.name.name not in ["pow", "atan2", "tanh", "__dace_epsilon", *FortranIntrinsics.retained_function_names()]: + if augmented_call.name.name not in ["pow", "atan2", "tanh", "__dace_epsilon", + *FortranIntrinsics.retained_function_names()]: augmented_call.args.append(node.lval) augmented_call.hasret = True - self.call2sdfg(augmented_call, sdfg, cfg) + self.call2sdfg(augmented_call, sdfg) return - outputnodefinder = ast_transforms.FindOutputs() + outputnodefinder = ast_transforms.FindOutputs(thourough=False) outputnodefinder.visit(node) output_vars = outputnodefinder.nodes output_names = [] @@ -856,7 +2137,7 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: Con input_names_tasklet.append(i.name + "_" + str(count) + "_in") substate = ast_utils.add_simple_state_to_sdfg( - self, cfg, "_state_l" + str(node.line_number[0]) + "_c" + str(node.line_number[1])) + self, sdfg, "_state_l" + str(node.line_number[0]) + "_c" + str(node.line_number[1])) output_names_changed = [o_t + "_out" for o_t in output_names] @@ -866,19 +2147,33 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: Con for i, j in zip(input_names, input_names_tasklet): memlet_range = self.get_memlet_range(sdfg, input_vars, i, j) - 
ast_utils.add_memlet_read(substate, i, tasklet, j, memlet_range) + src = ast_utils.add_memlet_read(substate, i, tasklet, j, memlet_range) + # if self.struct_views.get(sdfg) is not None: + # if self.struct_views[sdfg].get(i) is not None: + # chain= self.struct_views[sdfg][i] + # access_parent=substate.add_access(chain[0]) + # name=chain[0] + # for i in range(1,len(chain)): + # view_name=name+"_"+chain[i] + # access_child=substate.add_access(view_name) + # substate.add_edge(access_parent, None,access_child, 'views', Memlet.simple(name+"."+chain[i],subs.Range.from_array(sdfg.arrays[view_name]))) + # name=view_name + # access_parent=access_child + + # substate.add_edge(access_parent, None,src,'views', Memlet(data=name, subset=memlet_range)) for i, j, k in zip(output_names, output_names_tasklet, output_names_changed): - memlet_range = self.get_memlet_range(sdfg, output_vars, i, j) ast_utils.add_memlet_write(substate, i, tasklet, k, memlet_range) tw = ast_utils.TaskletWriter(output_names, output_names_changed, sdfg, self.name_mapping, input_names, - input_names_tasklet) + input_names_tasklet, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names) text = tw.write_code(node) + # print(sdfg.name,node.line_number,output_names,output_names_changed,input_names,input_names_tasklet) tasklet.code = CodeBlock(text, lang.Python) - def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG): """ This parses function calls to a nested SDFG or creates a tasklet with an external library call. 
@@ -890,24 +2185,26 @@ def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG, cfg: match_found = False rettype = "INTEGER" hasret = False - if node.name in self.functions_and_subroutines: - for i in self.top_level.function_definitions: - if i.name == node.name: - self.function2sdfg(i, sdfg, cfg) - return - for i in self.top_level.subroutine_definitions: - if i.name == node.name: - self.subroutine2sdfg(i, sdfg, cfg) - return - for j in self.top_level.modules: - for i in j.function_definitions: - if i.name == node.name: - self.function2sdfg(i, sdfg, cfg) + for fsname in self.functions_and_subroutines: + if fsname.name == node.name.name: + + for i in self.top_level.function_definitions: + if i.name.name == node.name.name: + self.function2sdfg(i, sdfg) return - for i in j.subroutine_definitions: - if i.name == node.name: - self.subroutine2sdfg(i, sdfg, cfg) + for i in self.top_level.subroutine_definitions: + if i.name.name == node.name.name: + self.subroutine2sdfg(i, sdfg) return + for j in self.top_level.modules: + for i in j.function_definitions: + if i.name.name == node.name.name: + self.function2sdfg(i, sdfg) + return + for i in j.subroutine_definitions: + if i.name.name == node.name.name: + self.subroutine2sdfg(i, sdfg) + return else: # This part handles the case that it's an external library call libstate = self.libraries.get(node.name.name) @@ -952,16 +2249,24 @@ def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG, cfg: output_names_changed.append(o_t + "_out") tw = ast_utils.TaskletWriter(output_names_tasklet.copy(), output_names_changed.copy(), sdfg, - self.name_mapping) + self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names) if not isinstance(rettype, ast_internal_classes.Void) and hasret: - special_list_in[retval.name] = pointer(self.get_dace_type(rettype)) - special_list_out.append(retval.name + "_out") + if isinstance(retval, 
ast_internal_classes.Name_Node): + special_list_in[retval.name] = pointer(self.get_dace_type(rettype)) + special_list_out.append(retval.name + "_out") + elif isinstance(retval, ast_internal_classes.Array_Subscript_Node): + special_list_in[retval.name.name] = pointer(self.get_dace_type(rettype)) + special_list_out.append(retval.name.name + "_out") + else: + raise NotImplementedError("Return type not implemented") + text = tw.write_code( ast_internal_classes.BinOp_Node(lval=retval, op="=", rval=node, line_number=node.line_number)) else: text = tw.write_code(node) - substate = ast_utils.add_simple_state_to_sdfg(self, cfg, "_state" + str(node.line_number[0])) + substate = ast_utils.add_simple_state_to_sdfg(self, sdfg, "_state" + str(node.line_number[0])) tasklet = ast_utils.add_tasklet(substate, str(node.line_number[0]), { **input_names_tasklet, @@ -974,23 +2279,29 @@ def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG, cfg: ast_utils.add_memlet_write(substate, self.name_mapping[sdfg][libstate], tasklet, self.name_mapping[sdfg][libstate] + "_task_out", "0") if not isinstance(rettype, ast_internal_classes.Void) and hasret: - ast_utils.add_memlet_read(substate, self.name_mapping[sdfg][retval.name], tasklet, retval.name, "0") + if isinstance(retval, ast_internal_classes.Name_Node): + ast_utils.add_memlet_read(substate, self.name_mapping[sdfg][retval.name], tasklet, retval.name, "0") - ast_utils.add_memlet_write(substate, self.name_mapping[sdfg][retval.name], tasklet, - retval.name + "_out", "0") + ast_utils.add_memlet_write(substate, self.name_mapping[sdfg][retval.name], tasklet, + retval.name + "_out", "0") + if isinstance(retval, ast_internal_classes.Array_Subscript_Node): + ast_utils.add_memlet_read(substate, self.name_mapping[sdfg][retval.name.name], tasklet, + retval.name.name, "0") + + ast_utils.add_memlet_write(substate, self.name_mapping[sdfg][retval.name.name], tasklet, + retval.name.name + "_out", "0") for i, j in zip(input_names, 
input_names_tasklet): memlet_range = self.get_memlet_range(sdfg, used_vars, i, j) ast_utils.add_memlet_read(substate, i, tasklet, j, memlet_range) for i, j, k in zip(output_names, output_names_tasklet, output_names_changed): - memlet_range = self.get_memlet_range(sdfg, used_vars, i, j) ast_utils.add_memlet_write(substate, i, tasklet, k, memlet_range) setattr(tasklet, "code", CodeBlock(text, lang.Python)) - def declstmt2sdfg(self, node: ast_internal_classes.Decl_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def declstmt2sdfg(self, node: ast_internal_classes.Decl_Stmt_Node, sdfg: SDFG): """ This function translates a variable declaration statement to an access node on the sdfg :param node: The node to translate @@ -998,38 +2309,171 @@ def declstmt2sdfg(self, node: ast_internal_classes.Decl_Stmt_Node, sdfg: SDFG, c :note This function is the top level of the declaration, most implementation is in vardecl2sdfg """ for i in node.vardecl: - self.translate(i, sdfg, cfg) + self.translate(i, sdfg) - def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG, cfg: ControlFlowRegion): + def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): """ This function translates a variable declaration to an access node on the sdfg :param node: The node to translate :param sdfg: The sdfg to attach the access node to """ - #if the sdfg is the toplevel-sdfg, the variable is a global variable - transient = True + if node.name == "modname": return + + # if the sdfg is the toplevel-sdfg, the variable is a global variable + is_arg = False + if isinstance(node.parent, + (ast_internal_classes.Subroutine_Subprogram_Node, ast_internal_classes.Function_Subprogram_Node)): + if hasattr(node.parent, "args"): + for i in node.parent.args: + name = ast_utils.get_name(i) + if name == node.name: + is_arg = True + if self.local_not_transient_because_assign.get(sdfg.name) is not None: + if name in self.local_not_transient_because_assign[sdfg.name]: + is_arg = 
False + break + + # if this is a variable declared in the module, + # then we will not add it unless it is used by the functions. + # It would be sufficient to check the main entry function, + # since it must pass this variable through call + # to other functions. + # However, I am not completely sure how to determine which function is the main one. + # + # we ignore the variable that is not used at all in all functions + # this is a module variaable that can be removed + if not is_arg: + if self.subroutine_used_names is not None: + + if node.name not in self.subroutine_used_names: + print( + f"Ignoring module variable {node.name} because it is not used in the the top level subroutine") + return + + if is_arg: + transient = False + else: + transient = self.transient_mode # find the type datatype = self.get_dace_type(node.type) - if hasattr(node, "alloc"): - if node.alloc: - self.unallocated_arrays.append([node.name, datatype, sdfg, transient]) - return + # if hasattr(node, "alloc"): + # if node.alloc: + # self.unallocated_arrays.append([node.name, datatype, sdfg, transient]) + # return # get the dimensions + # print(node.name) if node.sizes is not None: sizes = [] offset = [] - offset_value = -1 + actual_offsets = [] + offset_value = 0 if self.normalize_offsets else -1 for i in node.sizes: - tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping) - text = tw.write_code(i) - sizes.append(sym.pystr_to_symbolic(text)) - offset.append(offset_value) + stuff = [ii for ii in ast_transforms.mywalk(i) if isinstance(ii, ast_internal_classes.Data_Ref_Node)] + if len(stuff) > 0: + count = self.count_of_struct_symbols_lifted + sdfg.add_symbol("tmp_struct_symbol_" + str(count), dtypes.int32) + symname = "tmp_struct_symbol_" + str(count) + if sdfg.parent_sdfg is not None: + sdfg.parent_sdfg.add_symbol("tmp_struct_symbol_" + str(count), dtypes.int32) + sdfg.parent_nsdfg_node.symbol_mapping[ + "tmp_struct_symbol_" + str(count)] = "tmp_struct_symbol_" + str(count) + for edge 
in sdfg.parent.parent_graph.in_edges(sdfg.parent): + assign = ast_utils.ProcessedWriter(sdfg.parent_sdfg, self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(i) + edge.data.assignments["tmp_struct_symbol_" + str(count)] = assign + # print(edge) + else: + assign = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(i) + + sdfg.append_global_code(f"{dtypes.int32.ctype} {symname};\n") + sdfg.append_init_code( + "tmp_struct_symbol_" + str(count) + "=" + assign.replace(".", "->") + ";\n") + tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names) + text = tw.write_code( + ast_internal_classes.Name_Node(name="tmp_struct_symbol_" + str(count), type="INTEGER", + line_number=node.line_number)) + sizes.append(sym.pystr_to_symbolic(text)) + actual_offset_value = node.offsets[node.sizes.index(i)] + if isinstance(actual_offset_value, ast_internal_classes.Array_Subscript_Node): + # print(node.name,actual_offset_value.name.name) + raise NotImplementedError("Array subscript in offset not implemented") + if isinstance(actual_offset_value, int): + actual_offset_value = ast_internal_classes.Int_Literal_Node(value=str(actual_offset_value)) + aotext = tw.write_code(actual_offset_value) + actual_offsets.append(str(sym.pystr_to_symbolic(aotext))) + + self.actual_offsets_per_sdfg[sdfg][node.name] = actual_offsets + # otext = tw.write_code(offset_value) + + # TODO: shouldn't this use node.offset?? 
+ offset.append(offset_value) + self.count_of_struct_symbols_lifted += 1 + else: + tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names) + text = tw.write_code(i) + actual_offset_value = node.offsets[node.sizes.index(i)] + if isinstance(actual_offset_value, int): + actual_offset_value = ast_internal_classes.Int_Literal_Node(value=str(actual_offset_value)) + aotext = tw.write_code(actual_offset_value) + actual_offsets.append(str(sym.pystr_to_symbolic(aotext))) + # otext = tw.write_code(offset_value) + sizes.append(sym.pystr_to_symbolic(text)) + offset.append(offset_value) + self.actual_offsets_per_sdfg[sdfg][node.name] = actual_offsets else: sizes = None # create and check name - if variable is already defined (function argument and defined in declaration part) simply stop if self.name_mapping[sdfg].get(node.name) is not None: + # here we must replace local placeholder sizes that have already made it to tasklets via size and ubound calls + if sizes is not None: + actual_sizes = sdfg.arrays[self.name_mapping[sdfg][node.name]].shape + # print(node.name,sdfg.name,self.names_of_object_in_parent_sdfg.get(sdfg).get(node.name)) + # print(sdfg.parent_sdfg.name,self.actual_offsets_per_sdfg[sdfg.parent_sdfg].get(self.names_of_object_in_parent_sdfg[sdfg][node.name])) + # print(sdfg.parent_sdfg.arrays.get(self.name_mapping[sdfg.parent_sdfg].get(self.names_of_object_in_parent_sdfg.get(sdfg).get(node.name)))) + if self.actual_offsets_per_sdfg[sdfg.parent_sdfg].get( + self.names_of_object_in_parent_sdfg[sdfg][node.name]) is not None: + actual_offsets = self.actual_offsets_per_sdfg[sdfg.parent_sdfg][ + self.names_of_object_in_parent_sdfg[sdfg][node.name]] + else: + actual_offsets = [1] * len(actual_sizes) + + index = 0 + for i in node.sizes: + if isinstance(i, ast_internal_classes.Name_Node): + if i.name.startswith("__f2dace_A"): + self.replace_names[i.name] = 
str(actual_sizes[index]) + # node.parent.execution_part=ast_transforms.RenameVar(i.name,str(actual_sizes[index])).visit(node.parent.execution_part) + index += 1 + index = 0 + for i in node.offsets: + if isinstance(i, ast_internal_classes.Name_Node): + if i.name.startswith("__f2dace_OA"): + self.replace_names[i.name] = str(actual_offsets[index]) + # node.parent.execution_part=ast_transforms.RenameVar(i.name,str(actual_offsets[index])).visit(node.parent.execution_part) + index += 1 + elif sizes is None: + if isinstance(datatype, Structure): + datatype_to_add = copy.deepcopy(datatype) + datatype_to_add.transient = transient + # if node.name=="p_nh": + # print("Adding local struct",self.name_mapping[sdfg][node.name],datatype_to_add) + if self.struct_views.get(sdfg) is None: + self.struct_views[sdfg] = {} + add_views_recursive(sdfg, node.name, datatype_to_add, self.struct_views[sdfg], + self.name_mapping[sdfg], self.registered_types, [], + self.actual_offsets_per_sdfg[sdfg], self.names_of_object_in_parent_sdfg[sdfg], + self.actual_offsets_per_sdfg[sdfg.parent_sdfg]) + return if node.name in sdfg.symbols: @@ -1038,15 +2482,52 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG, cfg self.name_mapping[sdfg][node.name] = sdfg._find_new_name(node.name) if sizes is None: - sdfg.add_scalar(self.name_mapping[sdfg][node.name], dtype=datatype, transient=transient) + if isinstance(datatype, Structure): + datatype_to_add = copy.deepcopy(datatype) + datatype_to_add.transient = transient + # if node.name=="p_nh": + # print("Adding local struct",self.name_mapping[sdfg][node.name],datatype_to_add) + sdfg.add_datadesc(self.name_mapping[sdfg][node.name], datatype_to_add) + if self.struct_views.get(sdfg) is None: + self.struct_views[sdfg] = {} + add_views_recursive(sdfg, node.name, datatype_to_add, self.struct_views[sdfg], self.name_mapping[sdfg], + self.registered_types, [], self.actual_offsets_per_sdfg[sdfg], {}, {}) + # for i in datatype_to_add.members: + # 
current_dtype=datatype_to_add.members[i].dtype + # for other_type in self.registered_types: + # if current_dtype.dtype==self.registered_types[other_type].dtype: + # other_type_obj=self.registered_types[other_type] + # for j in other_type_obj.members: + # sdfg.add_view(self.name_mapping[sdfg][node.name] + "_" + i +"_"+ j,other_type_obj.members[j].shape,other_type_obj.members[j].dtype) + # self.name_mapping[sdfg][node.name + "_" + i +"_"+ j] = self.name_mapping[sdfg][node.name] + "_" + i +"_"+ j + # self.struct_views[sdfg][self.name_mapping[sdfg][node.name] + "_" + i+"_"+ j]=[self.name_mapping[sdfg][node.name],j] + # sdfg.add_view(self.name_mapping[sdfg][node.name] + "_" + i,datatype_to_add.members[i].shape,datatype_to_add.members[i].dtype) + # self.name_mapping[sdfg][node.name + "_" + i] = self.name_mapping[sdfg][node.name] + "_" + i + # self.struct_views[sdfg][self.name_mapping[sdfg][node.name] + "_" + i]=[self.name_mapping[sdfg][node.name],i] + + else: + + sdfg.add_scalar(self.name_mapping[sdfg][node.name], dtype=datatype, transient=transient) else: strides = [dat._prod(sizes[:i]) for i in range(len(sizes))] - sdfg.add_array(self.name_mapping[sdfg][node.name], - shape=sizes, - dtype=datatype, - offset=offset, - strides=strides, - transient=transient) + + if isinstance(datatype, Structure): + datatype.transient = transient + arr_dtype = datatype[sizes] + arr_dtype.offset = [offset_value for _ in sizes] + container = dat.ContainerArray(stype=datatype, shape=sizes, offset=offset, transient=transient) + # print("Adding local container array",self.name_mapping[sdfg][node.name],sizes,datatype,offset,strides,transient) + sdfg.arrays[self.name_mapping[sdfg][node.name]] = container + # sdfg.add_datadesc(self.name_mapping[sdfg][node.name], arr_dtype) + + else: + # print("Adding local array",self.name_mapping[sdfg][node.name],sizes,datatype,offset,strides,transient) + sdfg.add_array(self.name_mapping[sdfg][node.name], + shape=sizes, + dtype=datatype, + offset=offset, + 
strides=strides, + transient=transient) self.all_array_names.append(self.name_mapping[sdfg][node.name]) if self.contexts.get(sdfg.name) is None: @@ -1054,16 +2535,28 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG, cfg if node.name not in self.contexts[sdfg.name].containers: self.contexts[sdfg.name].containers.append(node.name) - def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG, cfg: ControlFlowRegion): + if hasattr(node, "init") and node.init is not None: + self.translate( + ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name=node.name, type=node.type), + op="=", rval=node.init, line_number=node.line_number), sdfg) + + def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG): + + self.last_loop_breaks[sdfg] = self.last_sdfg_states[sdfg] + sdfg.add_edge(self.last_sdfg_states[sdfg], self.last_loop_continues.get(sdfg), InterstateEdge()) + + def continue2sdfg(self, node: ast_internal_classes.Continue_Node, sdfg: SDFG): + # + sdfg.add_edge(self.last_sdfg_states[sdfg], self.last_loop_continues.get(sdfg), InterstateEdge()) + self.already_has_edge_back_continue[sdfg] = self.last_sdfg_states[sdfg] - self.last_loop_breaks[cfg] = self.last_sdfg_states[cfg] - cfg.add_edge(self.last_sdfg_states[cfg], self.last_loop_continues.get(cfg), InterstateEdge()) def create_ast_from_string( - source_string: str, - sdfg_name: str, - transform: bool = False, - normalize_offsets: bool = False + source_string: str, + sdfg_name: str, + transform: bool = False, + normalize_offsets: bool = False, + multiple_sdfgs: bool = False ): """ Creates an AST from a Fortran file in a string @@ -1076,16 +2569,32 @@ def create_ast_from_string( reader = fsr(source_string) ast = parser(reader) tables = SymbolTable - own_ast = ast_components.InternalFortranAst(ast, tables) + own_ast = ast_components.InternalFortranAst() program = own_ast.create_ast(ast) + structs_lister = ast_transforms.StructLister() + 
structs_lister.visit(program) + struct_dep_graph = nx.DiGraph() + for i, name in zip(structs_lister.structs, structs_lister.names): + if name not in struct_dep_graph.nodes: + struct_dep_graph.add_node(name) + struct_deps_finder = ast_transforms.StructDependencyLister(structs_lister.names) + struct_deps_finder.visit(i) + struct_deps = struct_deps_finder.structs_used + for j, pointing, point_name in zip(struct_deps, struct_deps_finder.is_pointer, + struct_deps_finder.pointer_names): + if j not in struct_dep_graph.nodes: + struct_dep_graph.add_node(j) + struct_dep_graph.add_edge(name, j, pointing=pointing, point_name=point_name) + + program.structures = ast_transforms.Structures(structs_lister.structs) + functions_and_subroutines_builder = ast_transforms.FindFunctionAndSubroutines() functions_and_subroutines_builder.visit(program) - functions_and_subroutines = functions_and_subroutines_builder.nodes if transform: program = ast_transforms.functionStatementEliminator(program) - program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program) + program = ast_transforms.CallToArray(functions_and_subroutines_builder).visit(program) program = ast_transforms.CallExtractor().visit(program) program = ast_transforms.SignToIf().visit(program) program = ast_transforms.ArrayToLoop(program).visit(program) @@ -1097,33 +2606,279 @@ def create_ast_from_string( program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) - return (program, own_ast) + program = ast_transforms.optionalArgsExpander(program) + + return program, own_ast + + +class ParseConfig: + def __init__(self, + main: Union[None, Path, str] = None, + sources: Union[None, List[Path], Dict[str, str]] = None, + includes: Union[None, List[Path], Dict[str, str]] = None, + entry_points: Union[None, SPEC, List[SPEC]] = None): + # Make the configs canonical, by processing the various types upfront. 
+ if isinstance(main, Path): + main = main.read_text() + main = FortranStringReader(main) + if not sources: + sources: Dict[str, str] = {} + elif isinstance(sources, list): + sources: Dict[str, str] = {str(p): p.read_text() for p in sources} + if not includes: + includes: List[Path] = [] + if not entry_points: + entry_points = [] + elif isinstance(entry_points, tuple): + entry_points = [entry_points] + + self.main = main + self.sources = sources + self.includes = includes + self.entry_points = entry_points + + +def create_fparser_ast(cfg: ParseConfig) -> Program: + parser = ParserFactory().create(std="f2008") + ast = parser(cfg.main) + ast = recursive_ast_improver(ast, cfg.sources, cfg.includes, parser) + ast = lower_identifier_names(ast) + assert isinstance(ast, Program) + return ast + + +def create_internal_ast(cfg: ParseConfig) -> Tuple[ast_components.InternalFortranAst, FNode]: + ast = create_fparser_ast(cfg) + + ast = deconstruct_enums(ast) + ast = deconstruct_associations(ast) + ast = correct_for_function_calls(ast) + ast = deconstruct_procedure_calls(ast) + ast = deconstruct_interface_calls(ast) + + if not cfg.entry_points: + # Keep all the possible entry points. 
+ entry_points = [ident_spec(ast_utils.singular(children_of_type(c, NAMED_STMTS_OF_INTEREST_TYPES))) + for c in ast.children if isinstance(c, ENTRY_POINT_OBJECT_TYPES)] + else: + eps = cfg.entry_points + if isinstance(eps, tuple): + eps = [eps] + ident_map = identifier_specs(ast) + entry_points = [ep for ep in eps if ep in ident_map] + ast = prune_unused_objects(ast, entry_points) + assert isinstance(ast, Program) + + iast = ast_components.InternalFortranAst() + prog = iast.create_ast(ast) + assert isinstance(prog, FNode) + iast.finalize_ast(prog) + return iast, prog + + +class SDFGConfig: + def __init__(self, + entry_points: Dict[str, Union[str, List[str]]], + normalize_offsets: bool = True, + multiple_sdfgs: bool = False): + for k in entry_points: + if isinstance(entry_points[k], str): + entry_points[k] = [entry_points[k]] + self.entry_points = entry_points + self.normalize_offsets = normalize_offsets + self.multiple_sdfgs = multiple_sdfgs + + +def create_sdfg_from_internal_ast(own_ast: ast_components.InternalFortranAst, program: FNode, cfg: SDFGConfig): + # Repeated! + # We need that to know in transformations what structures are used. + # The actual structure listing is repeated later to resolve cycles. + # Not sure if we can actually do it earlier. 
+ + program = ast_transforms.functionStatementEliminator(program) + program = ast_transforms.StructConstructorToFunctionCall( + ast_transforms.FindFunctionAndSubroutines.from_node(program).names).visit(program) + program = ast_transforms.CallToArray(ast_transforms.FindFunctionAndSubroutines.from_node(program)).visit(program) + program = ast_transforms.CallExtractor().visit(program) + + program = ast_transforms.FunctionCallTransformer().visit(program) + program = ast_transforms.FunctionToSubroutineDefiner().visit(program) + program = ast_transforms.PointerRemoval().visit(program) + program = ast_transforms.ElementalFunctionExpander( + ast_transforms.FindFunctionAndSubroutines.from_node(program).names).visit(program) + for i in program.modules: + count = 0 + for j in i.function_definitions: + if isinstance(j, ast_internal_classes.Subroutine_Subprogram_Node): + i.subroutine_definitions.append(j) + count += 1 + if count != len(i.function_definitions): + raise NameError("Not all functions were transformed to subroutines") + i.function_definitions = [] + program.function_definitions = [] + count = 0 + for i in program.function_definitions: + if isinstance(i, ast_internal_classes.Subroutine_Subprogram_Node): + program.subroutine_definitions.append(i) + count += 1 + if count != len(program.function_definitions): + raise NameError("Not all functions were transformed to subroutines") + program.function_definitions = [] + program = ast_transforms.SignToIf().visit(program) + program = ast_transforms.ReplaceStructArgsLibraryNodes(program).visit(program) + program = ast_transforms.ArrayToLoop(program).visit(program) + + for transformation in own_ast.fortran_intrinsics().transformations(): + transformation.initialize(program) + program = transformation.visit(program) + + program = ast_transforms.ArgumentExtractor(program).visit(program) + + program = ast_transforms.ForDeclarer().visit(program) + program = ast_transforms.IndexExtractor(program, 
cfg.normalize_offsets).visit(program) + program = ast_transforms.optionalArgsExpander(program) + program = ast_transforms.allocatableReplacer(program) + + structs_lister = ast_transforms.StructLister() + structs_lister.visit(program) + struct_dep_graph = nx.DiGraph() + for i, name in zip(structs_lister.structs, structs_lister.names): + if name not in struct_dep_graph.nodes: + struct_dep_graph.add_node(name) + struct_deps_finder = ast_transforms.StructDependencyLister(structs_lister.names) + struct_deps_finder.visit(i) + struct_deps = struct_deps_finder.structs_used + # print(struct_deps) + for j, pointing, point_name in zip(struct_deps, struct_deps_finder.is_pointer, + struct_deps_finder.pointer_names): + if j not in struct_dep_graph.nodes: + struct_dep_graph.add_node(j) + struct_dep_graph.add_edge(name, j, pointing=pointing, point_name=point_name) + cycles = nx.algorithms.cycles.simple_cycles(struct_dep_graph) + has_cycles = list(cycles) + cycles_we_cannot_ignore = [] + for cycle in has_cycles: + print(cycle) + for i in cycle: + is_pointer = struct_dep_graph.get_edge_data(i, cycle[(cycle.index(i) + 1) % len(cycle)])["pointing"] + point_name = struct_dep_graph.get_edge_data(i, cycle[(cycle.index(i) + 1) % len(cycle)])["point_name"] + # print(i,is_pointer) + if is_pointer: + actually_used_pointer_node_finder = ast_transforms.StructPointerChecker(i, cycle[ + (cycle.index(i) + 1) % len(cycle)], point_name) + actually_used_pointer_node_finder.visit(program) + # print(actually_used_pointer_node_finder.nodes) + if len(actually_used_pointer_node_finder.nodes) == 0: + print("We can ignore this cycle") + program = ast_transforms.StructPointerEliminator(i, cycle[(cycle.index(i) + 1) % len(cycle)], + point_name).visit(program) + else: + cycles_we_cannot_ignore.append(cycle) + if len(cycles_we_cannot_ignore) > 0: + raise NameError("Structs have cyclic dependencies") + + # TODO: `ArgumentPruner` does not cleanly remove arguments (and it's not entirely clear that arguments must 
be + # pruned on the frontend in the first place), so disable until it is fixed. + # ast_transforms.ArgumentPruner(functions_and_subroutines_builder.nodes).visit(program) + + gmap = {} + for ep, ep_spec in cfg.entry_points.items(): + # Find where to look for the entry point. + assert ep_spec + mod, pt = ep_spec[:-1], ep_spec[-1] + assert len(mod) <= 1, f"currently only one level of entry point search is supported, got: {ep_spec}" + ep_box = program # This is where we will search for our entry point. + if mod: + mod = mod[0] + mod = [m for m in program.modules if m.name.name == mod] + assert len(mod) <= 1, f"found multiple modules with the same name: {mod}" + if not mod: + # Could not even find the module, so skip. + continue + ep_box = mod[0] + + # Find the actual entry point. + fn = [f for f in ep_box.subroutine_definitions if f.name.name == pt] + if not mod and program.main_program and program.main_program.name.name.name == pt: + # The main program can be a valid entry point, so include that when appropriate. + fn.append(program.main_program) + assert len(fn) <= 1, f"found multiple subroutines with the same name {ep}" + if not fn: + continue + fn = fn[0] + + # Do the actual translation. 
+ ast2sdfg = AST_translator(__file__, multiple_sdfgs=cfg.multiple_sdfgs, startpoint=fn, toplevel_subroutine=None, + normalize_offsets=cfg.normalize_offsets, do_not_make_internal_variables_argument=True) + g = SDFG(ep) + ast2sdfg.functions_and_subroutines = ast_transforms.FindFunctionAndSubroutines.from_node(program).names + ast2sdfg.structures = program.structures + ast2sdfg.placeholders = program.placeholders + ast2sdfg.placeholders_offsets = program.placeholders_offsets + ast2sdfg.actual_offsets_per_sdfg[g] = {} + ast2sdfg.top_level = program + ast2sdfg.globalsdfg = g + ast2sdfg.translate(program, g) + g.apply_transformations(IntrinsicSDFGTransformation) + g.expand_library_nodes() + gmap[ep] = g + + return gmap + def create_sdfg_from_string( - source_string: str, - sdfg_name: str, - normalize_offsets: bool = False, - use_explicit_cf: bool = False + source_string: str, + sdfg_name: str, + normalize_offsets: bool = True, + multiple_sdfgs: bool = False, + sources: List[str] = None, ): """ Creates an SDFG from a fortran file in a string :param source_string: The fortran file as a string :param sdfg_name: The name to be given to the resulting SDFG :return: The resulting SDFG - + """ - parser = pf().create(std="f2008") - reader = fsr(source_string) - ast = parser(reader) - tables = SymbolTable - own_ast = ast_components.InternalFortranAst(ast, tables) - program = own_ast.create_ast(ast) + cfg = ParseConfig(main=source_string, sources=sources) + own_ast, program = create_internal_ast(cfg) + + # Repeated! + # We need that to know in transformations what structures are used. + # The actual structure listing is repeated later to resolve cycles. + # Not sure if we can actually do it earlier. 
+ functions_and_subroutines_builder = ast_transforms.FindFunctionAndSubroutines() functions_and_subroutines_builder.visit(program) - own_ast.functions_and_subroutines = functions_and_subroutines_builder.nodes program = ast_transforms.functionStatementEliminator(program) - program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program) + program = ast_transforms.StructConstructorToFunctionCall(functions_and_subroutines_builder.names).visit(program) + program = ast_transforms.CallToArray(functions_and_subroutines_builder).visit(program) + program = ast_transforms.IfConditionExtractor().visit(program) program = ast_transforms.CallExtractor().visit(program) + + + program = ast_transforms.FunctionCallTransformer().visit(program) + program = ast_transforms.FunctionToSubroutineDefiner().visit(program) + program = ast_transforms.PointerRemoval().visit(program) + program = ast_transforms.ElementalFunctionExpander(functions_and_subroutines_builder.names).visit(program) + for i in program.modules: + count = 0 + for j in i.function_definitions: + if isinstance(j, ast_internal_classes.Subroutine_Subprogram_Node): + i.subroutine_definitions.append(j) + count += 1 + if count != len(i.function_definitions): + raise NameError("Not all functions were transformed to subroutines") + i.function_definitions = [] + program.function_definitions = [] + count = 0 + for i in program.function_definitions: + if isinstance(i, ast_internal_classes.Subroutine_Subprogram_Node): + program.subroutine_definitions.append(i) + count += 1 + if count != len(program.function_definitions): + raise NameError("Not all functions were transformed to subroutines") + program.function_definitions = [] program = ast_transforms.SignToIf().visit(program) program = ast_transforms.ArrayToLoop(program).visit(program) @@ -1131,28 +2886,82 @@ def create_sdfg_from_string( transformation.initialize(program) program = transformation.visit(program) + program = 
ast_transforms.ArgumentExtractor(program).visit(program) + program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) - ast2sdfg = AST_translator(own_ast, __file__, use_explicit_cf) + program = ast_transforms.optionalArgsExpander(program) + structs_lister = ast_transforms.StructLister() + structs_lister.visit(program) + struct_dep_graph = nx.DiGraph() + for i, name in zip(structs_lister.structs, structs_lister.names): + if name not in struct_dep_graph.nodes: + struct_dep_graph.add_node(name) + struct_deps_finder = ast_transforms.StructDependencyLister(structs_lister.names) + struct_deps_finder.visit(i) + struct_deps = struct_deps_finder.structs_used + # print(struct_deps) + for j, pointing, point_name in zip(struct_deps, struct_deps_finder.is_pointer, + struct_deps_finder.pointer_names): + if j not in struct_dep_graph.nodes: + struct_dep_graph.add_node(j) + struct_dep_graph.add_edge(name, j, pointing=pointing, point_name=point_name) + cycles = nx.algorithms.cycles.simple_cycles(struct_dep_graph) + has_cycles = list(cycles) + cycles_we_cannot_ignore = [] + for cycle in has_cycles: + print(cycle) + for i in cycle: + is_pointer = struct_dep_graph.get_edge_data(i, cycle[(cycle.index(i) + 1) % len(cycle)])["pointing"] + point_name = struct_dep_graph.get_edge_data(i, cycle[(cycle.index(i) + 1) % len(cycle)])["point_name"] + # print(i,is_pointer) + if is_pointer: + actually_used_pointer_node_finder = ast_transforms.StructPointerChecker(i, cycle[ + (cycle.index(i) + 1) % len(cycle)], point_name) + actually_used_pointer_node_finder.visit(program) + # print(actually_used_pointer_node_finder.nodes) + if len(actually_used_pointer_node_finder.nodes) == 0: + print("We can ignore this cycle") + program = ast_transforms.StructPointerEliminator(i, cycle[(cycle.index(i) + 1) % len(cycle)], + point_name).visit(program) + else: + cycles_we_cannot_ignore.append(cycle) + if len(cycles_we_cannot_ignore) > 0: + 
raise NameError("Structs have cyclic dependencies") + + # program = + # ast_transforms.ArgumentPruner(functions_and_subroutines_builder.nodes).visit(program) + + ast2sdfg = AST_translator(__file__, multiple_sdfgs=multiple_sdfgs, toplevel_subroutine=sdfg_name, + normalize_offsets=normalize_offsets) sdfg = SDFG(sdfg_name) + ast2sdfg.functions_and_subroutines = functions_and_subroutines_builder.names + ast2sdfg.structures = program.structures + ast2sdfg.placeholders = program.placeholders + ast2sdfg.placeholders_offsets = program.placeholders_offsets + ast2sdfg.actual_offsets_per_sdfg[sdfg] = {} ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg ast2sdfg.translate(program, sdfg) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): - if 'test_function' in node.sdfg.name: - sdfg = node.sdfg - break + if node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None sdfg.reset_cfg_list() - sdfg.using_explicit_control_flow = use_explicit_cf + + sdfg.apply_transformations(IntrinsicSDFGTransformation) + sdfg.expand_library_nodes() + return sdfg -def create_sdfg_from_fortran_file(source_string: str, use_explicit_cf: bool = False): +def create_sdfg_from_fortran_file(source_string: str): """ Creates an SDFG from a fortran file :param source_string: The fortran file name @@ -1163,28 +2972,665 @@ def create_sdfg_from_fortran_file(source_string: str, use_explicit_cf: bool = Fa reader = ffr(source_string) ast = parser(reader) tables = SymbolTable - own_ast = ast_components.InternalFortranAst(ast, tables) + own_ast = ast_components.InternalFortranAst() program = own_ast.create_ast(ast) functions_and_subroutines_builder = ast_transforms.FindFunctionAndSubroutines() functions_and_subroutines_builder.visit(program) - own_ast.functions_and_subroutines = functions_and_subroutines_builder.nodes + own_ast.functions_and_subroutines = 
functions_and_subroutines_builder.names program = ast_transforms.functionStatementEliminator(program) - program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program) + program = ast_transforms.CallToArray(functions_and_subroutines_builder).visit(program) program = ast_transforms.CallExtractor().visit(program) program = ast_transforms.SignToIf().visit(program) - program = ast_transforms.ArrayToLoop(program).visit(program) - - for transformation in own_ast.fortran_intrinsics(): - transformation.initialize(program) - program = transformation.visit(program) - + program = ast_transforms.ArrayToLoop().visit(program) + program = ast_transforms.SumToLoop().visit(program) program = ast_transforms.ForDeclarer().visit(program) - program = ast_transforms.IndexExtractor(program).visit(program) - ast2sdfg = AST_translator(own_ast, __file__, use_explicit_cf) + program = ast_transforms.IndexExtractor().visit(program) + program = ast_transforms.optionalArgsExpander(program) + ast2sdfg = AST_translator(__file__) sdfg = SDFG(source_string) ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg ast2sdfg.translate(program, sdfg) + sdfg.apply_transformations(IntrinsicSDFGTransformation) + sdfg.expand_library_nodes() - sdfg.using_explicit_control_flow = use_explicit_cf return sdfg + + +def compute_dep_graph(ast: Program, start_point: Union[str, List[str]]) -> nx.DiGraph: + """ + Compute a dependency graph among all the top level objects in the program. 
+ """ + if isinstance(start_point, str): + start_point = [start_point] + + dep_graph = nx.DiGraph() + exclude = set() + to_process = start_point + while to_process: + item_name, to_process = to_process[0], to_process[1:] + item = ast_utils.atmost_one(c for c in ast.children if find_name_of_node(c) == item_name) + if not item: + print(f"Could not find: {item}") + continue + + fandsl = ast_utils.FunctionSubroutineLister() + fandsl.get_functions_and_subroutines(item) + dep_graph.add_node(item_name, info_list=fandsl) + + used_modules, objects_in_modules = ast_utils.get_used_modules(item) + for mod in used_modules: + if mod not in dep_graph.nodes: + dep_graph.add_node(mod) + obj_list = [] + if dep_graph.has_edge(item_name, mod): + edge = dep_graph.get_edge_data(item_name, mod) + if 'obj_list' in edge: + obj_list = edge.get('obj_list') + assert isinstance(obj_list, list) + if mod in objects_in_modules: + ast_utils.extend_with_new_items_from(obj_list, objects_in_modules[mod]) + dep_graph.add_edge(item_name, mod, obj_list=obj_list) + if mod not in exclude: + to_process.append(mod) + exclude.add(mod) + + return dep_graph + + +def recursive_ast_improver(ast: Program, source_list: Dict[str, str], include_list, parser): + exclude = set() + + NAME_REPLACEMENTS = { + 'mo_restart_nml_and_att': 'mo_restart_nmls_and_atts', + 'yomhook': 'yomhook_dummy', + } + + def _recursive_ast_improver(_ast: Base): + defined_modules = ast_utils.get_defined_modules(_ast) + used_modules, objects_in_modules = ast_utils.get_used_modules(_ast) + + modules_to_parse = [mod for mod in used_modules if mod not in chain(defined_modules, exclude)] + added_modules = [] + for mod in modules_to_parse: + name = mod.lower() + if name in NAME_REPLACEMENTS: + name = NAME_REPLACEMENTS[name] + + mod_file = [srcf for srcf in source_list if os.path.basename(srcf).lower() == f"{name}.f90"] + assert len(mod_file) <= 1, f"Found multiple files for the same module `{mod}`: {mod_file}" + if not mod_file: + print(f"Ignoring 
error: cannot find a file for `{mod}`") + continue + mod_file = mod_file[0] + + reader = fsr(source_list[mod_file], include_dirs=include_list) + try: + next_ast = parser(reader) + except Exception as e: + raise RuntimeError(f"{mod_file} could not be parsed: {e}") from e + + _recursive_ast_improver(next_ast) + + for c in reversed(next_ast.children): + if c in added_modules: + added_modules.remove(c) + added_modules.insert(0, c) + c_stmt = c.children[0] + c_name = ast_utils.singular(ast_utils.children_of_type(c_stmt, Name)).string + exclude.add(c_name) + + for mod in reversed(added_modules): + if mod not in _ast.children: + _ast.children.append(mod) + + _recursive_ast_improver(ast) + + # Add all the free-floating subprograms from other source files in case we missed them. + ast = collect_floating_subprograms(ast, source_list, include_list, parser) + # Sort the modules in the order of their dependency. + ast = sort_modules(ast) + + return ast + + +def collect_floating_subprograms(ast: Program, source_list: Dict[str, str], include_list, parser) -> Program: + known_names: Set[str] = {nm.string for nm in walk(ast, Name)} + + known_floaters: Set[str] = set() + for esp in ast.children: + name = find_name_of_node(esp) + if name: + known_floaters.add(name) + + known_sub_asts: Dict[str, Program] = {} + for src, content in source_list.items(): + + # TODO: Should be fixed in FParser. + # FParser cannot handle `convert=...` argument in the `open()` statement. + content = content.replace(',convert="big_endian"', '') + + reader = fsr(content, include_dirs=include_list) + try: + sub_ast = parser(reader) + except Exception as e: + print(f"Ignoring {src} due to error: {e}") + continue + known_sub_asts[src] = sub_ast + + # Since the order is not topological, we need to incrementally find more connected floating subprograms. 
+ changed = True + while changed: + changed = False + new_floaters = [] + for src, sub_ast in known_sub_asts.items(): + # Find all the new floating subprograms that are known to be needed so far. + for esp in sub_ast.children: + name = find_name_of_node(esp) + if name and name in known_names and name not in known_floaters: + # We have found a new floating subprogram that's needed. + known_floaters.add(name) + known_names.update({nm.string for nm in walk(esp, Name)}) + new_floaters.append(esp) + if new_floaters: + # Append the new floating subprograms to our main AST. + append_children(ast, new_floaters) + changed = True + return ast + + +def name_and_rename_dict_creator(parse_order: list,dep_graph:nx.DiGraph)->Tuple[Dict[str, List[str]], Dict[str, Dict[str, str]]]: + name_dict = {} + rename_dict = {} + for i in parse_order: + local_rename_dict = {} + edges = list(dep_graph.in_edges(i)) + names = [] + for j in edges: + list_dict = dep_graph.get_edge_data(j[0], j[1]) + if (list_dict['obj_list'] is not None): + for k in list_dict['obj_list']: + if not k.__class__.__name__ == "Name": + if k.__class__.__name__ == "Rename": + if k.children[2].string not in names: + names.append(k.children[2].string) + local_rename_dict[k.children[2].string] = k.children[1].string + # print("Assumption failed: Object list contains non-name node") + else: + if k.string not in names: + names.append(k.string) + rename_dict[i] = local_rename_dict + name_dict[i] = names + return name_dict, rename_dict + + +@dataclass +class FindUsedFunctionsConfig: + root: str + needed_functions: List[str] + skip_functions: List[str] + + +def create_sdfg_from_fortran_file_with_options( + cfg: ParseConfig, + ast: Program, + sdfgs_dir, + subroutine_name: Optional[str] = None, + normalize_offsets: bool = True, + propagation_info=None, + enum_propagator_files: Optional[List[str]] = None, + enum_propagator_ast=None, + used_functions_config: Optional[FindUsedFunctionsConfig] = None, + already_parsed_ast=False +): + 
""" + Creates an SDFG from a fortran file + :param source_string: The fortran file name + :return: The resulting SDFG + + """ + if not already_parsed_ast: + ast = deconstruct_enums(ast) + ast = deconstruct_associations(ast) + ast = remove_access_statements(ast) + ast = correct_for_function_calls(ast) + ast = deconstruct_procedure_calls(ast) + ast = deconstruct_interface_calls(ast) + ast = const_eval_nodes(ast) + ast = prune_branches(ast) + ast = prune_unused_objects(ast, cfg.entry_points) + #ast = assign_globally_unique_subprogram_names(ast, {('radiation_interface', 'radiation')}) + #ast = assign_globally_unique_variable_names(ast, {'config'}) + ast = consolidate_uses(ast) + else: + ast = correct_for_function_calls(ast) + + dep_graph = compute_dep_graph(ast, 'radiation_interface') + parse_order = list(reversed(list(nx.topological_sort(dep_graph)))) + + what_to_parse_list = {} + name_dict, rename_dict = name_and_rename_dict_creator(parse_order, dep_graph) + + tables = SymbolTable + partial_ast = ast_components.InternalFortranAst() + partial_modules = {} + partial_ast.symbols["c_int"] = ast_internal_classes.Int_Literal_Node(value=4) + partial_ast.symbols["c_int8_t"] = ast_internal_classes.Int_Literal_Node(value=1) + partial_ast.symbols["c_int64_t"] = ast_internal_classes.Int_Literal_Node(value=8) + partial_ast.symbols["c_int32_t"] = ast_internal_classes.Int_Literal_Node(value=4) + partial_ast.symbols["c_size_t"] = ast_internal_classes.Int_Literal_Node(value=4) + partial_ast.symbols["c_long"] = ast_internal_classes.Int_Literal_Node(value=8) + partial_ast.symbols["c_signed_char"] = ast_internal_classes.Int_Literal_Node(value=1) + partial_ast.symbols["c_char"] = ast_internal_classes.Int_Literal_Node(value=1) + partial_ast.symbols["c_null_char"] = ast_internal_classes.Int_Literal_Node(value=1) + functions_to_rename = {} + + # Why would you ever name a file differently than the module? Especially just one random file out of thousands??? 
+ # asts["mo_restart_nml_and_att"]=asts["mo_restart_nmls_and_atts"] + partial_ast.to_parse_list = what_to_parse_list + asts = {find_name_of_stmt(m).lower(): m for m in walk(ast, Module_Stmt)} + for i in parse_order: + partial_ast.current_ast = i + + partial_ast.unsupported_fortran_syntax[i] = [] + if i in ["mtime", "ISO_C_BINDING", "iso_c_binding", "mo_cdi", "iso_fortran_env", "netcdf"]: + continue + + # try: + partial_module = partial_ast.create_ast(asts[i.lower()]) + partial_modules[partial_module.name.name] = partial_module + # except Exception as e: + # print("Module " + i + " could not be parsed ", partial_ast.unsupported_fortran_syntax[i]) + # print(e, type(e)) + # print(partial_ast.unsupported_fortran_syntax[i]) + # continue + tmp_rename = rename_dict[i] + for j in tmp_rename: + # print(j) + if partial_ast.symbols.get(j) is None: + # raise NameError("Symbol " + j + " not found in partial ast") + if functions_to_rename.get(i) is None: + functions_to_rename[i] = [j] + else: + functions_to_rename[i].append(j) + else: + partial_ast.symbols[tmp_rename[j]] = partial_ast.symbols[j] + + print("Parsed successfully module: ", i, " ", partial_ast.unsupported_fortran_syntax[i]) + # print(partial_ast.unsupported_fortran_syntax[i]) + # try: + partial_ast.current_ast = "top level" + + program = partial_ast.create_ast(ast) + program.module_declarations = ast_utils.parse_module_declarations(program) + # except: + + # print(" top level module could not be parsed ", partial_ast.unsupported_fortran_syntax["top level"]) + # print(partial_ast.unsupported_fortran_syntax["top level"]) + # return + + structs_lister = ast_transforms.StructLister() + structs_lister.visit(program) + struct_dep_graph = nx.DiGraph() + for i, name in zip(structs_lister.structs, structs_lister.names): + if name not in struct_dep_graph.nodes: + struct_dep_graph.add_node(name) + struct_deps_finder = ast_transforms.StructDependencyLister(structs_lister.names) + struct_deps_finder.visit(i) + struct_deps = 
struct_deps_finder.structs_used + # print(struct_deps) + for j, pointing, point_name in zip(struct_deps, struct_deps_finder.is_pointer, + struct_deps_finder.pointer_names): + if j not in struct_dep_graph.nodes: + struct_dep_graph.add_node(j) + struct_dep_graph.add_edge(name, j, pointing=pointing, point_name=point_name) + program = ast_transforms.PropagateEnums().visit(program) + program = ast_transforms.Flatten_Classes(structs_lister.structs).visit(program) + program.structures = ast_transforms.Structures(structs_lister.structs) + + functions_and_subroutines_builder = ast_transforms.FindFunctionAndSubroutines() + functions_and_subroutines_builder.visit(program) + listnames = [i.name for i in functions_and_subroutines_builder.names] + for i in functions_and_subroutines_builder.iblocks: + if i not in listnames: + functions_and_subroutines_builder.names.append(ast_internal_classes.Name_Node(name=i, type="VOID")) + program.iblocks = functions_and_subroutines_builder.iblocks + partial_ast.functions_and_subroutines = functions_and_subroutines_builder.names + program = ast_transforms.functionStatementEliminator(program) + # program = ast_transforms.StructConstructorToFunctionCall(functions_and_subroutines_builder.names).visit(program) + # program = ast_transforms.CallToArray(functions_and_subroutines_builder, rename_dict).visit(program) + # program = ast_transforms.TypeInterference(program).visit(program) + # program = ast_transforms.ReplaceInterfaceBlocks(program, functions_and_subroutines_builder).visit(program) + + program = ast_transforms.IfConditionExtractor().visit(program) + + program = ast_transforms.TypeInference(program, assert_voids=False).visit(program) + program = ast_transforms.CallExtractor().visit(program) + program = ast_transforms.ArgumentExtractor(program).visit(program) + program = ast_transforms.FunctionCallTransformer().visit(program) + program = ast_transforms.FunctionToSubroutineDefiner().visit(program) + + program = 
ast_transforms.optionalArgsExpander(program) + # program = ast_transforms.ArgumentExtractor(program).visit(program) + + count = 0 + for i in program.function_definitions: + if isinstance(i, ast_internal_classes.Subroutine_Subprogram_Node): + program.subroutine_definitions.append(i) + partial_ast.functions_and_subroutines.append(i.name) + count += 1 + if count != len(program.function_definitions): + raise NameError("Not all functions were transformed to subroutines") + for i in program.modules: + count = 0 + for j in i.function_definitions: + if isinstance(j, ast_internal_classes.Subroutine_Subprogram_Node): + i.subroutine_definitions.append(j) + partial_ast.functions_and_subroutines.append(j.name) + count += 1 + if count != len(i.function_definitions): + raise NameError("Not all functions were transformed to subroutines") + i.function_definitions = [] + program.function_definitions = [] + + # time to trim the ast using the propagation info + # adding enums from radiotion config + if enum_propagator_files is not None: + parser = ParserFactory().create(std="f2008") + if enum_propagator_files is not None: + for file in enum_propagator_files: + config_ast = parser(ffr(file_candidate=file)) + partial_ast.create_ast(config_ast) + + radiation_config_internal_ast = partial_ast.create_ast(enum_propagator_ast) + enum_propagator = ast_transforms.PropagateEnums() + enum_propagator.visit(radiation_config_internal_ast) + + program = enum_propagator.generic_visit(program) + replacements = 1 + step = 1 + while replacements > 0: + program = enum_propagator.generic_visit(program) + prop = ast_transforms.AssignmentPropagator(propagation_info) + program = prop.visit(program) + replacements = prop.replacements + if_eval = ast_transforms.IfEvaluator() + program = if_eval.visit(program) + replacements += if_eval.replacements + print("Made " + str(replacements) + " replacements in step " + str(step) + " Prop: " + str( + prop.replacements) + " If: " + str(if_eval.replacements)) + step += 1 
+ + if used_functions_config is not None: + + unusedFunctionFinder = ast_transforms.FindUnusedFunctions(used_functions_config.root, parse_order) + unusedFunctionFinder.visit(program) + used_funcs = unusedFunctionFinder.used_names + current_list = used_funcs[used_functions_config.root] + current_list += used_functions_config.root + + needed = used_functions_config.needed_functions + + for i in reversed(parse_order): + for j in program.modules: + if j.name.name in used_functions_config.skip_functions: + continue + if j.name.name == i: + + for k in j.subroutine_definitions: + if k.name.name in current_list: + current_list += used_funcs[k.name.name] + needed.append([j.name.name, k.name.name]) + + for i in program.modules: + subroutines = [] + for j in needed: + if i.name.name == j[0]: + + for k in i.subroutine_definitions: + if k.name.name == j[1]: + subroutines.append(k) + i.subroutine_definitions = subroutines + + program = ast_transforms.SignToIf().visit(program) + program = ast_transforms.ReplaceStructArgsLibraryNodes(program).visit(program) + program = ast_transforms.ArrayToLoop(program).visit(program) + program = ast_transforms.optionalArgsExpander(program) + program = ast_transforms.TypeInference(program, assert_voids=False).visit(program) + program = ast_transforms.ArgumentExtractor(program).visit(program) + program = ast_transforms.ReplaceStructArgsLibraryNodes(program).visit(program) + program = ast_transforms.ArrayToLoop(program).visit(program) + print("Before intrinsics") + + prior_exception: Optional[NeedsTypeInferenceException] = None + for transformation in partial_ast.fortran_intrinsics().transformations(): + while True: + try: + transformation.initialize(program) + program = transformation.visit(program) + break + except NeedsTypeInferenceException as e: + + if prior_exception is not None: + if e.line_number == prior_exception.line_number and e.func_name == prior_exception.func_name: + print("Running additional type inference didn't help! 
VOID type in the same place.") + raise RuntimeError() + else: + prior_exception = e + print("Running additional type inference") + # FIXME: optimize func + program = ast_transforms.TypeInference(program, assert_voids=False).visit(program) + + print("After intrinsics") + + program = ast_transforms.TypeInference(program).visit(program) + program = ast_transforms.ReplaceInterfaceBlocks(program, functions_and_subroutines_builder).visit(program) + program = ast_transforms.optionalArgsExpander(program) + program = ast_transforms.ArgumentExtractor(program).visit(program) + program = ast_transforms.ElementalFunctionExpander(functions_and_subroutines_builder.names).visit(program) + # print("Before intrinsics") + # for transformation in partial_ast.fortran_intrinsics().transformations(): + # transformation.initialize(program) + # program = transformation.visit(program) + # print("After intrinsics") + program = ast_transforms.ForDeclarer().visit(program) + program = ast_transforms.PointerRemoval().visit(program) + program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) + structs_lister = ast_transforms.StructLister() + structs_lister.visit(program) + struct_dep_graph = nx.DiGraph() + for i, name in zip(structs_lister.structs, structs_lister.names): + if name not in struct_dep_graph.nodes: + struct_dep_graph.add_node(name) + struct_deps_finder = ast_transforms.StructDependencyLister(structs_lister.names) + struct_deps_finder.visit(i) + struct_deps = struct_deps_finder.structs_used + for j, pointing, point_name in zip(struct_deps, struct_deps_finder.is_pointer, + struct_deps_finder.pointer_names): + if j not in struct_dep_graph.nodes: + struct_dep_graph.add_node(j) + struct_dep_graph.add_edge(name, j, pointing=pointing, point_name=point_name) + cycles = nx.algorithms.cycles.simple_cycles(struct_dep_graph) + has_cycles = list(cycles) + cycles_we_cannot_ignore = [] + for cycle in has_cycles: + print(cycle) + for i in cycle: + is_pointer = 
struct_dep_graph.get_edge_data(i, cycle[(cycle.index(i) + 1) % len(cycle)])["pointing"] + point_name = struct_dep_graph.get_edge_data(i, cycle[(cycle.index(i) + 1) % len(cycle)])["point_name"] + # print(i,is_pointer) + if is_pointer: + actually_used_pointer_node_finder = ast_transforms.StructPointerChecker(i, cycle[ + (cycle.index(i) + 1) % len(cycle)], point_name, structs_lister, struct_dep_graph, "simple") + actually_used_pointer_node_finder.visit(program) + # print(actually_used_pointer_node_finder.nodes) + if len(actually_used_pointer_node_finder.nodes) == 0: + print("We can ignore this cycle") + program = ast_transforms.StructPointerEliminator(i, cycle[(cycle.index(i) + 1) % len(cycle)], + point_name).visit(program) + else: + cycles_we_cannot_ignore.append(cycle) + if len(cycles_we_cannot_ignore) > 0: + raise NameError("Structs have cyclic dependencies") + # print("Deleting struct members...") + # struct_members_deleted = 0 + # for struct, name in zip(structs_lister.structs, structs_lister.names): + # struct_member_finder = ast_transforms.StructMemberLister() + # struct_member_finder.visit(struct) + # for member, is_pointer, point_name in zip(struct_member_finder.members, struct_member_finder.is_pointer, + # struct_member_finder.pointer_names): + # if is_pointer: + # actually_used_pointer_node_finder = ast_transforms.StructPointerChecker(name, member, point_name, + # structs_lister, + # struct_dep_graph, "full") + # actually_used_pointer_node_finder.visit(program) + # found = False + # for i in actually_used_pointer_node_finder.nodes: + # nl = ast_transforms.FindNames() + # nl.visit(i) + # if point_name in nl.names: + # found = True + # break + # # print("Struct Name: ",name," Member Name: ",point_name, " Found: ", found) + # if not found: + # # print("We can delete this member") + # struct_members_deleted += 1 + # program = ast_transforms.StructPointerEliminator(name, member, point_name).visit(program) + # print("Deleted " + str(struct_members_deleted) + " 
struct members.") + # structs_lister = ast_transforms.StructLister() + # structs_lister.visit(program) + # struct_dep_graph = nx.DiGraph() + # for i, name in zip(structs_lister.structs, structs_lister.names): + # if name not in struct_dep_graph.nodes: + # struct_dep_graph.add_node(name) + # struct_deps_finder = ast_transforms.StructDependencyLister(structs_lister.names) + # struct_deps_finder.visit(i) + # struct_deps = struct_deps_finder.structs_used + # for j, pointing, point_name in zip(struct_deps, struct_deps_finder.is_pointer, + # struct_deps_finder.pointer_names): + # if j not in struct_dep_graph.nodes: + # struct_dep_graph.add_node(j) + # struct_dep_graph.add_edge(name, j, pointing=pointing, point_name=point_name) + + program.structures = ast_transforms.Structures(structs_lister.structs) + program.tables = partial_ast.symbols + program.placeholders = partial_ast.placeholders + program.placeholders_offsets = partial_ast.placeholders_offsets + program.functions_and_subroutines = partial_ast.functions_and_subroutines + unordered_modules = program.modules + + # arg_pruner = ast_transforms.ArgumentPruner(functions_and_subroutines_builder.nodes) + # arg_pruner.visit(program) + + for j in program.subroutine_definitions: + + if subroutine_name is not None: + if not subroutine_name+"_decon" in j.name.name : + print("Skipping 1 ", j.name.name) + continue + + if j.execution_part is None: + continue + + print(f"Building SDFG {j.name.name}") + startpoint = j + ast2sdfg = AST_translator(__file__, multiple_sdfgs=False, startpoint=startpoint, sdfg_path=sdfgs_dir, + normalize_offsets=normalize_offsets) + sdfg = SDFG(j.name.name) + ast2sdfg.functions_and_subroutines = functions_and_subroutines_builder.names + ast2sdfg.structures = program.structures + ast2sdfg.placeholders = program.placeholders + ast2sdfg.placeholders_offsets = program.placeholders_offsets + ast2sdfg.actual_offsets_per_sdfg[sdfg] = {} + ast2sdfg.top_level = program + ast2sdfg.globalsdfg = sdfg + + 
ast2sdfg.translate(program, sdfg) + + print(f'Saving SDFG {os.path.join(sdfgs_dir, sdfg.name + "_raw_before_intrinsics_full.sdfgz")}') + sdfg.save(os.path.join(sdfgs_dir, sdfg.name + "_raw_before_intrinsics_full.sdfgz"), compress=True) + + sdfg.apply_transformations(IntrinsicSDFGTransformation) + + try: + sdfg.expand_library_nodes() + except: + print("Expansion failed for ", sdfg.name) + continue + + sdfg.validate() + print(f'Saving SDFG {os.path.join(sdfgs_dir, sdfg.name + "_validated_f.sdfgz")}') + sdfg.save(os.path.join(sdfgs_dir, sdfg.name + "_validated_f.sdfgz"), compress=True) + + sdfg.simplify(verbose=True) + print(f'Saving SDFG {os.path.join(sdfgs_dir, sdfg.name + "_simplified_tr.sdfgz")}') + sdfg.save(os.path.join(sdfgs_dir, sdfg.name + "_simplified_f.sdfgz"), compress=True) + + print(f'Compiling SDFG {os.path.join(sdfgs_dir, sdfg.name + "_simplifiedf.sdfgz")}') + sdfg.compile() + + for i in program.modules: + + # for path in source_list: + + # if path.lower().find(i.name.name.lower()) != -1: + # mypath = path + # break + + for j in i.subroutine_definitions: + + if subroutine_name is not None: + #special for radiation + if subroutine_name=='radiation': + if not 'radiation' == j.name.name : + print("Skipping ", j.name.name) + continue + + elif not subroutine_name in j.name.name : + print("Skipping ", j.name.name) + continue + + if j.execution_part is None: + continue + print(f"Building SDFG {j.name.name}") + startpoint = j + ast2sdfg = AST_translator( + __file__, + multiple_sdfgs=False, + startpoint=startpoint, + sdfg_path=sdfgs_dir, + # toplevel_subroutine_arg_names=arg_pruner.visited_funcs[toplevel_subroutine], + # subroutine_used_names=arg_pruner.used_in_all_functions, + normalize_offsets=normalize_offsets + ) + sdfg = SDFG(j.name.name) + ast2sdfg.functions_and_subroutines = functions_and_subroutines_builder.names + ast2sdfg.structures = program.structures + ast2sdfg.placeholders = program.placeholders + ast2sdfg.placeholders_offsets = 
program.placeholders_offsets + ast2sdfg.actual_offsets_per_sdfg[sdfg] = {} + ast2sdfg.top_level = program + ast2sdfg.globalsdfg = sdfg + ast2sdfg.translate(program, sdfg) + + sdfg.save(os.path.join(sdfgs_dir, sdfg.name + "_raw_before_intrinsics_full.sdfgz"), compress=True) + + sdfg.apply_transformations(IntrinsicSDFGTransformation) + + try: + sdfg.expand_library_nodes() + except: + print("Expansion failed for ", sdfg.name) + continue + + sdfg.validate() + sdfg.save(os.path.join(sdfgs_dir, sdfg.name + "_validated_f.sdfgz"), compress=True) + + sdfg.simplify(verbose=True) + print(f'Saving SDFG {os.path.join(sdfgs_dir, sdfg.name + "_simplified_tr.sdfgz")}') + sdfg.save(os.path.join(sdfgs_dir, sdfg.name + "_simplified_f.sdfgz"), compress=True) + + print(f'Compiling SDFG {os.path.join(sdfgs_dir, sdfg.name + "_simplifiedf.sdfgz")}') + sdfg.compile() + + # return sdfg diff --git a/dace/frontend/fortran/icon_config_propagation.py b/dace/frontend/fortran/icon_config_propagation.py new file mode 100644 index 0000000000..b6cfa48383 --- /dev/null +++ b/dace/frontend/fortran/icon_config_propagation.py @@ -0,0 +1,230 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import os +import sys +from pathlib import Path + +from fparser.common.readfortran import FortranFileReader as ffr +from fparser.two.parser import ParserFactory as pf + +from dace.frontend.fortran.fortran_parser import ParseConfig, create_fparser_ast + +current = os.path.dirname(os.path.realpath(__file__)) +parent = os.path.dirname(current) +sys.path.append(parent) + +from dace.frontend.fortran import fortran_parser + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_internal_classes as ast_internal + + +def find_path_recursive(base_dir): + dirs = os.listdir(base_dir) + fortran_files = [] + for path in dirs: + if os.path.isdir(os.path.join(base_dir, path)): + fortran_files.extend(find_path_recursive(os.path.join(base_dir, path))) + if os.path.isfile(os.path.join(base_dir, path)) and (path.endswith(".F90") or path.endswith(".f90")): + fortran_files.append(os.path.join(base_dir, path)) + return fortran_files + + +def read_lines_between(file_path: str, start_str: str, end_str: str) -> list[str]: + lines_between = [] + with open(file_path, 'r') as file: + capture = False + for line in file: + if start_str in line: + capture = True + continue + if end_str in line: + if capture: + capture = False + break + if capture: + lines_between.append(line.strip()) + return lines_between[1:] + + +def parse_assignments(assignments: list[str]) -> list[tuple[str, str]]: + parsed_assignments = [] + for assignment in assignments: + # Remove comments + assignment = assignment.split('!')[0].strip() + if '=' in assignment: + a, b = assignment.split('=', 1) + parsed_assignments.append((a.strip(), b.strip())) + return parsed_assignments + + +if __name__ == "__main__": + + base_icon_path = sys.argv[1] + icon_file = sys.argv[2] + sdfgs_dir = sys.argv[3] + if len(sys.argv) > 4: + already_parsed_ast = sys.argv[4] + else: + already_parsed_ast = None + + base_dir_ecrad = 
f"{base_icon_path}/externals/ecrad" + base_dir_icon = f"{base_icon_path}/src" + fortran_files = find_path_recursive(base_dir_ecrad) + inc_list = [f"{base_icon_path}/externals/ecrad/include"] + + # Construct the primary ECRad AST. + parse_cfg = ParseConfig( + main=Path(f"{base_icon_path}/{icon_file}"), + sources=[Path(f) for f in fortran_files], + entry_points=[('radiation_interface', 'radiation')], + ) + #already_parsed_ast=None + if already_parsed_ast is None: + ecrad_ast = create_fparser_ast(parse_cfg) + already_parsed_ast_bool = False + else: + mini_parser=pf().create(std="f2008") + ecrad_ast = mini_parser(ffr(file_candidate=already_parsed_ast)) + already_parsed_ast_bool = True + + # ast_builder = ast_components.InternalFortranAst() + # parser = pf().create(std="f2008") + + # # Update configuration with user changes + # strings = read_lines_between( + # f"{base_icon_path}/run/exp.exclaim_ape_R2B09", + # "! radiation_nml: radiation scheme", + # "/" + # ) + # parsed_strings = parse_assignments(strings) + + # parkind_ast = parser(ffr(file_candidate=f"{base_icon_path}/src/shared/mo_kind.f90")) + # parkinds = ast_builder.create_ast(parkind_ast) + + # reader = ffr(file_candidate=f"{base_icon_path}/src/namelists/mo_radiation_nml.f90") + # namelist_ast = parser(reader) + # namelist_internal_ast = ast_builder.create_ast(namelist_ast) + # # this creates the initial list of assignments + # # this does not consider conditional control-flow + # # it excludes condiditions like "if" + # # + # # assignments are in form of (variable, value) + # # we try to replace instances of variable with "value" + # # value can be almost anything + # # + # lister = ast_transforms.AssignmentLister(parsed_strings) + + # replacements = 1 + # step = 1 + # # we replace if conditions iteratively until no more changes + # # are done to the source code + # while replacements > 0: + # lister.reset() + # lister.visit(namelist_internal_ast) + + # # Propagate assignments + # prop = 
ast_transforms.AssignmentPropagator(lister.simple_assignments) + # namelist_internal_ast = prop.visit(namelist_internal_ast) + # replacements = prop.replacements + + # # We try to evaluate if conditions. If we can evaluate to true/false, + # # then we replace the if condition with the exact path + # if_eval = ast_transforms.IfEvaluator() + # namelist_internal_ast = if_eval.visit(namelist_internal_ast) + # replacements += if_eval.replacements + # print("Made " + str(replacements) + " replacements in step " + str(step) + " Prop: " + str( + # prop.replacements) + " If: " + str(if_eval.replacements)) + # step += 1 + + # # adding enums from radiation config + # adiation_config_ast = parser( + # ffr(file_candidate=f"{base_icon_path}/src/configure_model/mo_radiation_config.f90") + # ) + # radiation_config_internal_ast = ast_builder.create_ast(adiation_config_ast) + # # replace long complex enum names with integers + # enum_propagator = ast_transforms.PropagateEnums() + # enum_propagator.visit(radiation_config_internal_ast) + + # # namelist_assignments.insert(0,("amd", "28.970")) + + # # Repeat the + # ecrad_init_ast = parser(ffr(file_candidate=f"{base_icon_path}/src/atm_phy_nwp/mo_nwp_ecrad_init.f90")) + # ecrad_internal_ast = ast_builder.create_ast(ecrad_init_ast) + # # clearing acc check + # ecrad_internal_ast.modules[0].subroutine_definitions.pop(1) + # ecrad_internal_ast = enum_propagator.generic_visit(ecrad_internal_ast) + # lister2 = ast_transforms.AssignmentLister(parsed_strings) + # replacements = 1 + # step = 1 + # while replacements > 0: + # lister2.reset() + # ecrad_internal_ast = enum_propagator.generic_visit(ecrad_internal_ast) + # lister2.visit(ecrad_internal_ast) + # prop = ast_transforms.AssignmentPropagator(lister2.simple_assignments + lister.simple_assignments) + # ecrad_internal_ast = prop.visit(ecrad_internal_ast) + # replacements = prop.replacements + # if_eval = ast_transforms.IfEvaluator() + # ecrad_internal_ast = if_eval.visit(ecrad_internal_ast) + 
# replacements += if_eval.replacements + # print("Made " + str(replacements) + " replacements in step " + str(step) + " Prop: " + str( + # prop.replacements) + " If: " + str(if_eval.replacements)) + # step += 1 + + # # TODO: a couple of manual replacements + # lister2.simple_assignments.append( + # (ast_internal.Data_Ref_Node(parent_ref=ast_internal.Name_Node(name="ecrad_conf"), + # part_ref=ast_internal.Name_Node(name="do_save_radiative_properties")), + # ast_internal.Bool_Literal_Node(value='False'))) + + # # this is defined internally in the program as "false" + # # We remove it to simplify the code + # lister2.simple_assignments.append( + # (ast_internal.Name_Node(name="lhook"), + # ast_internal.Bool_Literal_Node(value='False'))) + + # propagation_info = lister.simple_assignments + lister2.simple_assignments + + # # let's fix the propagation info for ECRAD + # for i in propagation_info: + # if isinstance(i[0], ast_internal.Data_Ref_Node): + # i[0].parent_ref.name = i[0].parent_ref.name.replace("ecrad_conf", "config") + + # radiation_config_ast = parser( + # ffr(file_candidate=f"{base_icon_path}/externals/ecrad/radiation/radiation_config.F90")) + + # enum_propagator_files = [ + # f"{base_icon_path}/src/shared/mo_kind.f90", + # f"{base_icon_path}/externals/ecrad/ifsaux/parkind1.F90", + # f"{base_icon_path}/externals/ecrad/ifsaux/ecradhook.F90" + # ] + + cfg = fortran_parser.FindUsedFunctionsConfig( + root='radiation', + needed_functions=[['radiation_interface', 'radiation']], + skip_functions=['radiation_monochromatic', 'radiation_cloudless_sw', + 'radiation_tripleclouds_sw', 'radiation_homogeneous_sw'] + ) + + # generate_propagation_info(propagation_info) + + # previous steps were used to generate the initial list of assignments for ECRAD + # this includes user config and internal enumerations of ICON + # the previous ASTs can be now disregarded + # we only keep the list of assignments and propagate it to ECRAD parsing. 
+ print(f"{base_icon_path}/{icon_file}") + #already_parsed_ast_bool = False + fortran_parser.create_sdfg_from_fortran_file_with_options( + parse_cfg, + ecrad_ast, + sdfgs_dir=sdfgs_dir, + subroutine_name="radiation", + # subroutine_name="cloud_generator", + normalize_offsets=True, + #propagation_info=propagation_info, + #enum_propagator_ast=radiation_config_ast, + #enum_propagator_files=enum_propagator_files, + used_functions_config=cfg, + already_parsed_ast=already_parsed_ast_bool + ) diff --git a/dace/frontend/fortran/intrinsics.py b/dace/frontend/fortran/intrinsics.py index af44a8dfb5..d5023e71fc 100644 --- a/dace/frontend/fortran/intrinsics.py +++ b/dace/frontend/fortran/intrinsics.py @@ -1,16 +1,30 @@ - -from abc import abstractmethod import copy import math +import sys +from abc import abstractmethod from collections import namedtuple -from typing import Any, List, Optional, Set, Tuple, Type +from typing import Any, List, Optional, Tuple, Union from dace.frontend.fortran import ast_internal_classes +from dace.frontend.fortran.ast_transforms import NodeVisitor, NodeTransformer, ParentScopeAssigner, \ + ScopeVarsDeclarations, TypeInference, par_Decl_Range_Finder, mywalk from dace.frontend.fortran.ast_utils import fortrantypes2dacetypes -from dace.frontend.fortran.ast_transforms import NodeVisitor, NodeTransformer, ParentScopeAssigner, ScopeVarsDeclarations, par_Decl_Range_Finder, mywalk +from dace.libraries.blas.nodes.dot import dot_libnode +from dace.libraries.blas.nodes.gemm import gemm_libnode +from dace.libraries.standard.nodes import Transpose +from dace.sdfg import SDFGState, SDFG, nodes +from dace.sdfg.graph import OrderedDiGraph +from dace.transformation import transformation as xf FASTNode = Any +class NeedsTypeInferenceException(BaseException): + + def __init__(self, func_name, line_number): + + self.line_number = line_number + self.func_name = func_name + class IntrinsicTransformation: @staticmethod @@ -20,34 +34,96 @@ def replaced_name(func_name: 
str) -> str: @staticmethod @abstractmethod - def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line) -> ast_internal_classes.FNode: + def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line, + symbols: list) -> ast_internal_classes.FNode: pass @staticmethod def has_transformation() -> bool: return False + class IntrinsicNodeTransformer(NodeTransformer): def initialize(self, ast): # We need to rerun the assignment because transformations could have created # new AST nodes ParentScopeAssigner().visit(ast) - self.scope_vars = ScopeVarsDeclarations() + self.scope_vars = ScopeVarsDeclarations(ast) self.scope_vars.visit(ast) + self.ast = ast + + def _parse_struct_ref(self, node: ast_internal_classes.Data_Ref_Node) -> ast_internal_classes.FNode: + + # we assume starting from the top (left-most) data_ref_node + # for struct1 % struct2 % struct3 % var + # we find definition of struct1, then we iterate until we find the var + + struct_type = self.scope_vars.get_var(node.parent, node.parent_ref.name).type + struct_def = self.ast.structures.structures[struct_type] + cur_node = node + + while True: + cur_node = cur_node.part_ref + + if isinstance(cur_node, ast_internal_classes.Array_Subscript_Node): + struct_def = self.ast.structures.structures[struct_type] + return struct_def.vars[cur_node.name.name] + + elif isinstance(cur_node, ast_internal_classes.Name_Node): + struct_def = self.ast.structures.structures[struct_type] + return struct_def.vars[cur_node.name] + + elif isinstance(cur_node, ast_internal_classes.Data_Ref_Node): + struct_type = struct_def.vars[cur_node.parent_ref.name].type + struct_def = self.ast.structures.structures[struct_type] + + else: + raise NotImplementedError() + + def get_var_declaration(self, + parent: ast_internal_classes.FNode, + variable: Union[ + ast_internal_classes.Data_Ref_Node, ast_internal_classes.Name_Node, + 
ast_internal_classes.Array_Subscript_Node]): + if isinstance(variable, ast_internal_classes.Data_Ref_Node): + variable = self._parse_struct_ref(variable) + return variable + + assert isinstance(variable, (ast_internal_classes.Name_Node, ast_internal_classes.Array_Subscript_Node)) + if isinstance(variable, ast_internal_classes.Name_Node): + name = variable.name + elif isinstance(variable, ast_internal_classes.Array_Subscript_Node): + name = variable.name.name + + if self.scope_vars.contains_var(parent, name): + return self.scope_vars.get_var(parent, name) + elif name in self.ast.module_declarations: + return self.ast.module_declarations[name] + else: + raise RuntimeError(f"Couldn't find the declaration of variable {name} in function {parent.name.name}!") @staticmethod @abstractmethod - def func_name(self) -> str: + def func_name() -> str: pass -class DirectReplacement(IntrinsicTransformation): + # @staticmethod + # @abstractmethod + # def transformation_name(self) -> str: + # pass + +class DirectReplacement(IntrinsicTransformation): Replacement = namedtuple("Replacement", "function") Transformation = namedtuple("Transformation", "function") class ASTTransformation(IntrinsicNodeTransformer): + @staticmethod + def func_name() -> str: + return "direct_replacement" + def visit_BinOp_Node(self, binop_node: ast_internal_classes.BinOp_Node): if not isinstance(binop_node.rval, ast_internal_classes.Call_Expr_Node): @@ -62,33 +138,32 @@ def visit_BinOp_Node(self, binop_node: ast_internal_classes.BinOp_Node): replacement_rule = DirectReplacement.FUNCTIONS[func_name] if isinstance(replacement_rule, DirectReplacement.Transformation): - # FIXME: we do not have line number in binop? 
- binop_node.rval, input_type = replacement_rule.function(node, self.scope_vars, 0) #binop_node.line) - print(binop_node, binop_node.lval, binop_node.rval) + binop_node.rval, input_type = replacement_rule.function(self, node, 0) # binop_node.line) - # replace types of return variable - LHS of the binary operator var = binop_node.lval - if isinstance(var.name, ast_internal_classes.Name_Node): - name = var.name.name - else: - name = var.name - var_decl = self.scope_vars.get_var(var.parent, name) - var.type = input_type - var_decl.type = input_type - return binop_node + # replace types of return variable - LHS of the binary operator + # we only propagate that for the assignment + # we handle extracted call variables this way + # but we can also have different shapes, e.g., `maxval(something) > something_else` + # hence the check + if isinstance(var, (ast_internal_classes.Name_Node, ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node)): + + var_decl = self.get_var_declaration(var.parent, var) + var_decl.type = input_type + var.type = input_type - #self.scope_vars.get_var(node.parent, arg.name). + return binop_node - def replace_size(var: ast_internal_classes.Call_Expr_Node, scope_vars: ScopeVarsDeclarations, line): + def replace_size(transformer: IntrinsicNodeTransformer, var: ast_internal_classes.Call_Expr_Node, line): if len(var.args) not in [1, 2]: - raise RuntimeError() + assert False, "Incorrect arguments to size!" # get variable declaration for the first argument - var_decl = scope_vars.get_var(var.parent, var.args[0].name) + var_decl = transformer.get_var_declaration(var.parent, var.args[0]) # one arg to SIZE? 
compute the total number of elements if len(var.args) == 1: @@ -98,15 +173,14 @@ def replace_size(var: ast_internal_classes.Call_Expr_Node, scope_vars: ScopeVars ret = ast_internal_classes.BinOp_Node( lval=var_decl.sizes[0], - rval=None, + rval=ast_internal_classes.Name_Node(name="INTRINSIC_TEMPORARY"), op="*" ) cur_node = ret for i in range(1, len(var_decl.sizes) - 1): - cur_node.rval = ast_internal_classes.BinOp_Node( lval=var_decl.sizes[i], - rval=None, + rval=ast_internal_classes.Name_Node(name="INTRINSIC_TEMPORARY"), op="*" ) cur_node = cur_node.rval @@ -120,42 +194,171 @@ def replace_size(var: ast_internal_classes.Call_Expr_Node, scope_vars: ScopeVars if not isinstance(rank, ast_internal_classes.Int_Literal_Node): raise NotImplementedError() value = int(rank.value) - return (var_decl.sizes[value-1], "INTEGER") + return (var_decl.sizes[value - 1], "INTEGER") + def _replace_lbound_ubound(func: str, transformer: IntrinsicNodeTransformer, + var: ast_internal_classes.Call_Expr_Node, line): - def replace_bit_size(var: ast_internal_classes.Call_Expr_Node, scope_vars: ScopeVarsDeclarations, line): + if len(var.args) not in [1, 2]: + assert False, "Incorrect arguments to lbound/ubound" + + # get variable declaration for the first argument + var_decl = transformer.get_var_declaration(var.parent, var.args[0]) + + # one arg to LBOUND/UBOUND? not needed currently + if len(var.args) == 1: + raise NotImplementedError() + + # two arguments? 
We return number of elements in a given rank + rank = var.args[1] + # we do not support symbolic argument to DIM - it must be a literal + if not isinstance(rank, ast_internal_classes.Int_Literal_Node): + raise NotImplementedError() + + rank_value = int(rank.value) + + is_assumed = isinstance(var_decl.offsets[rank_value - 1], ast_internal_classes.Name_Node) and var_decl.offsets[ + rank_value - 1].name.startswith("__f2dace_") + + if func == 'lbound': + + if is_assumed and not var_decl.alloc: + value = ast_internal_classes.Int_Literal_Node(value="1") + elif isinstance(var_decl.offsets[rank_value - 1], int): + value = ast_internal_classes.Int_Literal_Node(value=str(var_decl.offsets[rank_value - 1])) + else: + value = var_decl.offsets[rank_value - 1] + + else: + if isinstance(var_decl.sizes[rank_value - 1], ast_internal_classes.FNode): + size = var_decl.sizes[rank_value - 1] + else: + size = ast_internal_classes.Int_Literal_Node(value=var_decl.sizes[rank_value - 1]) + + if is_assumed and not var_decl.alloc: + value = size + else: + if isinstance(var_decl.offsets[rank_value - 1], ast_internal_classes.FNode): + offset = var_decl.offsets[rank_value - 1] + elif isinstance(var_decl.offsets[rank_value - 1], int): + offset = ast_internal_classes.Int_Literal_Node(value=str(var_decl.offsets[rank_value - 1])) + else: + offset = ast_internal_classes.Int_Literal_Node(value=var_decl.offsets[rank_value - 1]) + + value = ast_internal_classes.BinOp_Node( + op="+", + lval=size, + rval=ast_internal_classes.BinOp_Node( + op="-", + lval=offset, + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=line + ), + line_number=line + ) + + return (value, "INTEGER") + + def replace_lbound(transformer: IntrinsicNodeTransformer, var: ast_internal_classes.Call_Expr_Node, line): + return DirectReplacement._replace_lbound_ubound("lbound", transformer, var, line) + + def replace_ubound(transformer: IntrinsicNodeTransformer, var: ast_internal_classes.Call_Expr_Node, line): + return 
DirectReplacement._replace_lbound_ubound("ubound", transformer, var, line) + + def replace_bit_size(transformer: IntrinsicNodeTransformer, var: ast_internal_classes.Call_Expr_Node, line): if len(var.args) != 1: - raise RuntimeError() + assert False, "Incorrect arguments to bit_size" # get variable declaration for the first argument - var_decl = scope_vars.get_var(var.parent, var.args[0].name) + var_decl = transformer.get_var_declaration(var.parent, var.args[0]) dace_type = fortrantypes2dacetypes[var_decl.type] type_size = dace_type().itemsize * 8 return (ast_internal_classes.Int_Literal_Node(value=str(type_size)), "INTEGER") - - def replace_int_kind(args: ast_internal_classes.Arg_List_Node, line): + def replace_int_kind(args: ast_internal_classes.Arg_List_Node, line, symbols: list): + if isinstance(args.args[0], ast_internal_classes.Int_Literal_Node): + arg0 = args.args[0].value + elif isinstance(args.args[0], ast_internal_classes.Name_Node): + if args.args[0].name in symbols: + arg0 = symbols[args.args[0].name].value + else: + raise ValueError("Only symbols can be names in selector") + else: + raise ValueError("Only literals or symbols can be arguments in selector") return ast_internal_classes.Int_Literal_Node(value=str( - math.ceil((math.log2(math.pow(10, int(args.args[0].value))) + 1) / 8)), - line_number=line) - - def replace_real_kind(args: ast_internal_classes.Arg_List_Node, line): - if int(args.args[0].value) >= 9 or int(args.args[1].value) > 126: + math.ceil((math.log2(math.pow(10, int(arg0))) + 1) / 8)), + line_number=line) + + def replace_real_kind(args: ast_internal_classes.Arg_List_Node, line, symbols: list): + if isinstance(args.args[0], ast_internal_classes.Int_Literal_Node): + arg0 = args.args[0].value + elif isinstance(args.args[0], ast_internal_classes.Name_Node): + if args.args[0].name in symbols: + arg0 = symbols[args.args[0].name].value + else: + raise ValueError("Only symbols can be names in selector") + else: + raise ValueError("Only literals 
or symbols can be arguments in selector") + if len(args.args) == 2: + if isinstance(args.args[1], ast_internal_classes.Int_Literal_Node): + arg1 = args.args[1].value + elif isinstance(args.args[1], ast_internal_classes.Name_Node): + if args.args[1].name in symbols: + arg1 = symbols[args.args[1].name].value + else: + raise ValueError("Only symbols can be names in selector") + else: + raise ValueError("Only literals or symbols can be arguments in selector") + else: + arg1 = 0 + if int(arg0) >= 9 or int(arg1) > 126: return ast_internal_classes.Int_Literal_Node(value="8", line_number=line) - elif int(args.args[0].value) >= 3 or int(args.args[1].value) > 14: + elif int(arg0) >= 3 or int(arg1) > 14: return ast_internal_classes.Int_Literal_Node(value="4", line_number=line) else: return ast_internal_classes.Int_Literal_Node(value="2", line_number=line) + def replace_present(transformer: IntrinsicNodeTransformer, call: ast_internal_classes.Call_Expr_Node, line): + + assert len(call.args) == 1 + assert isinstance(call.args[0], ast_internal_classes.Name_Node) + + var_name = call.args[0].name + test_var_name = f'__f2dace_OPTIONAL_{var_name}' + + return (ast_internal_classes.Name_Node(name=test_var_name), "LOGICAL") + + def replace_allocated(transformer: IntrinsicNodeTransformer, call: ast_internal_classes.Call_Expr_Node, line): + + assert len(call.args) == 1 + assert isinstance(call.args[0], ast_internal_classes.Name_Node) + + var_name = call.args[0].name + test_var_name = f'__f2dace_ALLOCATED_{var_name}' + + return (ast_internal_classes.Name_Node(name=test_var_name), "LOGICAL") + + def replacement_epsilon(args: ast_internal_classes.Arg_List_Node, line, symbols: list): + + # assert len(args) == 1 + # assert isinstance(args[0], ast_internal_classes.Name_Node) + + ret_val = sys.float_info.epsilon + return ast_internal_classes.Real_Literal_Node(value=str(ret_val)) FUNCTIONS = { "SELECTED_INT_KIND": Replacement(replace_int_kind), "SELECTED_REAL_KIND": 
Replacement(replace_real_kind), + "EPSILON": Replacement(replacement_epsilon), "BIT_SIZE": Transformation(replace_bit_size), - "SIZE": Transformation(replace_size) + "SIZE": Transformation(replace_size), + "LBOUND": Transformation(replace_lbound), + "UBOUND": Transformation(replace_ubound), + "PRESENT": Transformation(replace_present), + "ALLOCATED": Transformation(replace_allocated) } @staticmethod @@ -173,7 +376,7 @@ def replacable_name(func_name: str) -> bool: @staticmethod def replace_name(func_name: str) -> str: - #return ast_internal_classes.Name_Node(name=DirectReplacement.FUNCTIONS[func_name][0]) + # return ast_internal_classes.Name_Node(name=DirectReplacement.FUNCTIONS[func_name][0]) return ast_internal_classes.Name_Node(name=f'__dace_{func_name}') @staticmethod @@ -184,11 +387,11 @@ def replacable(func_name: str) -> bool: return False @staticmethod - def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line) -> ast_internal_classes.FNode: - + def replace(func_name: str, args: ast_internal_classes.Arg_List_Node, line, symbols: list) \ + -> ast_internal_classes.FNode: # Here we already have __dace_func fname = func_name.split('__dace_')[1] - return DirectReplacement.FUNCTIONS[fname].function(args, line) + return DirectReplacement.FUNCTIONS[fname].function(args, line, symbols) def has_transformation(fname: str) -> bool: return isinstance(DirectReplacement.FUNCTIONS[fname], DirectReplacement.Transformation) @@ -197,8 +400,8 @@ def has_transformation(fname: str) -> bool: def get_transformation() -> IntrinsicNodeTransformer: return DirectReplacement.ASTTransformation() -class LoopBasedReplacement: +class LoopBasedReplacement: INTRINSIC_TO_DACE = { "SUM": "__dace_sum", "PRODUCT": "__dace_product", @@ -218,11 +421,12 @@ def replaced_name(func_name: str) -> str: def has_transformation() -> bool: return True -class LoopBasedReplacementVisitor(NodeVisitor): +class LoopBasedReplacementVisitor(NodeVisitor): """ Finds all 
intrinsic operations that have to be transformed to loops in the AST """ + def __init__(self, func_name: str): self._func_name = func_name self.nodes: List[ast_internal_classes.FNode] = [] @@ -245,11 +449,12 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): return -class LoopBasedReplacementTransformation(IntrinsicNodeTransformer): +class LoopBasedReplacementTransformation(IntrinsicNodeTransformer): """ Transforms the AST by removing intrinsic call and replacing it with loops """ + def __init__(self): self.count = 0 self.rvals = [] @@ -263,7 +468,8 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): pass @abstractmethod - def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, + new_func_body: List[ast_internal_classes.FNode]): pass @abstractmethod @@ -292,17 +498,30 @@ def _skip_result_assignment(self): def _update_result_type(self, var: ast_internal_classes.Name_Node): pass - def _parse_array(self, node: ast_internal_classes.Execution_Part_Node, arg: ast_internal_classes.FNode) -> ast_internal_classes.Array_Subscript_Node: + def _parse_array(self, node: ast_internal_classes.Execution_Part_Node, + arg: ast_internal_classes.FNode, dims_count: Optional[int] = -1 + ) -> ast_internal_classes.Array_Subscript_Node: # supports syntax func(arr) - if isinstance(arg, ast_internal_classes.Name_Node): - array_node = ast_internal_classes.Array_Subscript_Node(parent=arg.parent) - array_node.name = arg - + if isinstance(arg, (ast_internal_classes.Name_Node, ast_internal_classes.Data_Ref_Node)): # If we access SUM(arr) where arr has many dimensions, # We need to create a ParDecl_Node for each dimension - dims = 
len(self.scope_vars.get_var(node.parent, arg.name).sizes) - array_node.indices = [ast_internal_classes.ParDecl_Node(type='ALL')] * dims + # array_sizes = self.scope_vars.get_var(node.parent, arg.name).sizes + array_sizes = self.get_var_declaration(node.parent, arg).sizes + if array_sizes is None: + + if dims_count != -1: + # for destination array, sizes might be unknown when we use arg extractor + # in that situation, we look at the size of the first argument + dims = dims_count + else: + return None + else: + dims = len(array_sizes) + + array_node = ast_internal_classes.Array_Subscript_Node( + name=arg, parent=arg.parent, type='VOID', + indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * dims) return array_node @@ -310,11 +529,14 @@ def _parse_array(self, node: ast_internal_classes.Execution_Part_Node, arg: ast_ if isinstance(arg, ast_internal_classes.Array_Subscript_Node): return arg - def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_internal_classes.BinOp_Node) -> Tuple[ - ast_internal_classes.Array_Subscript_Node, - Optional[ast_internal_classes.Array_Subscript_Node], - ast_internal_classes.BinOp_Node - ]: + return None + + def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_internal_classes.BinOp_Node) -> \ + Tuple[ + ast_internal_classes.Array_Subscript_Node, + Optional[ast_internal_classes.Array_Subscript_Node], + ast_internal_classes.BinOp_Node + ]: """ Supports passing binary operations as an input to function. 
@@ -333,7 +555,7 @@ def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_i """ if not isinstance(arg, ast_internal_classes.BinOp_Node): - return False + return (None, None, None) first_array = self._parse_array(node, arg.lval) second_array = self._parse_array(node, arg.rval) @@ -372,7 +594,8 @@ def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_i return (first_array, second_array, cond) - def _adjust_array_ranges(self, node: ast_internal_classes.FNode, array: ast_internal_classes.Array_Subscript_Node, loop_ranges_main: list, loop_ranges_array: list): + def _adjust_array_ranges(self, node: ast_internal_classes.FNode, array: ast_internal_classes.Array_Subscript_Node, + loop_ranges_main: list, loop_ranges_array: list): """ When given a binary operator with arrays as an argument to the intrinsic, @@ -391,16 +614,30 @@ def _adjust_array_ranges(self, node: ast_internal_classes.FNode, array: ast_inte idx_var = array.indices[i] start_loop = loop_ranges_main[i][0] end_loop = loop_ranges_array[i][0] + - difference = int(end_loop.value) - int(start_loop.value) - if difference != 0: - new_index = ast_internal_classes.BinOp_Node( - lval=idx_var, - op="+", - rval=ast_internal_classes.Int_Literal_Node(value=str(difference)), - line_number=node.line_number - ) - array.indices[i] = new_index + difference = ast_internal_classes.BinOp_Node( + lval=end_loop, + op="-", + rval=start_loop, + line_number=node.line_number + ) + new_index = ast_internal_classes.BinOp_Node( + lval=idx_var, + op="+", + rval=difference, + line_number=node.line_number + ) + array.indices[i] = new_index + #difference = int(end_loop.value) - int(start_loop.value) + #if difference != 0: + # new_index = ast_internal_classes.BinOp_Node( + # lval=idx_var, + # op="+", + # rval=ast_internal_classes.Int_Literal_Node(value=str(difference)), + # line_number=node.line_number + # ) + # array.indices[i] = new_index def visit_Execution_Part_Node(self, node: 
ast_internal_classes.Execution_Part_Node): @@ -476,6 +713,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No self.count = self.count + range_index return ast_internal_classes.Execution_Part_Node(execution=newbody) + class SumProduct(LoopBasedReplacementTransformation): def _initialize(self): @@ -487,9 +725,9 @@ def _update_result_type(self, var: ast_internal_classes.Name_Node): """ For both SUM and PRODUCT, the result type depends on the input variable. """ - input_type = self.scope_vars.get_var(var.parent, self.argument_variable.name.name) + input_type = self.get_var_declaration(var.parent, self.argument_variable) - var_decl = self.scope_vars.get_var(var.parent, var.name) + var_decl = self.get_var_declaration(var.parent, var) var.type = input_type.type var_decl.type = input_type.type @@ -504,15 +742,16 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): else: raise NotImplementedError("We do not support non-array arguments for SUM/PRODUCT") - - def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, + new_func_body: List[ast_internal_classes.FNode]): if len(self.rvals) != 1: raise NotImplementedError("Only one array can be summed") self.argument_variable = self.rvals[0] - par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], self.count, new_func_body, + self.scope_vars, self.ast.structures, True) def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: @@ -539,7 +778,6 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_ class Sum(LoopBasedReplacement): - """ In this class, we 
implement the transformation for Fortran intrinsic SUM(:) We support two ways of invoking the function - by providing array name and array subscript. @@ -561,8 +799,8 @@ def _result_init_value(self): def _result_update_op(self): return "+" -class Product(LoopBasedReplacement): +class Product(LoopBasedReplacement): """ In this class, we implement the transformation for Fortran intrinsic PRODUCT(:) We support two ways of invoking the function - by providing array name and array subscript. @@ -584,6 +822,7 @@ def _result_init_value(self): def _result_update_op(self): return "*" + class AnyAllCountTransformation(LoopBasedReplacementTransformation): def _initialize(self): @@ -601,7 +840,7 @@ def _update_result_type(self, var: ast_internal_classes.Name_Node): Theoretically, we should return LOGICAL for ANY and ALL, but we no longer use booleans on DaCe side. """ - var_decl = self.scope_vars.get_var(var.parent, var.name) + var_decl = self.get_var_declaration(var.parent, var) var.type = "INTEGER" var_decl.type = "INTEGER" @@ -623,16 +862,21 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): else: self.first_array, self.second_array, self.cond = self._parse_binary_op(node, arg) - def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + assert self.first_array is not None + + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, + new_func_body: List[ast_internal_classes.FNode]): rangeslen_left = [] - par_Decl_Range_Finder(self.first_array, self.loop_ranges, [], rangeslen_left, self.count, new_func_body, self.scope_vars, True) + par_Decl_Range_Finder(self.first_array, self.loop_ranges, rangeslen_left, self.count, new_func_body, + self.scope_vars, self.ast.structures, True) if self.second_array is None: return loop_ranges_right = [] rangeslen_right = [] - 
par_Decl_Range_Finder(self.second_array, loop_ranges_right, [], rangeslen_right, self.count, new_func_body, self.scope_vars, True) + par_Decl_Range_Finder(self.second_array, loop_ranges_right, rangeslen_right, self.count, new_func_body, + self.scope_vars, self.ast.structures, True) for left_len, right_len in zip(rangeslen_left, rangeslen_right): if left_len != right_len: @@ -642,7 +886,6 @@ def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, n # Thus, we only need to adjust the second array self._adjust_array_ranges(node, self.second_array, self.loop_ranges, loop_ranges_right) - def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: init_value = self._result_init_value() @@ -666,9 +909,9 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_ # TODO: we should make the `break` generation conditional based on the architecture # For parallel maps, we should have no breaks # For sequential loop, we want a break to be faster - #ast_internal_classes.Break_Node( + # ast_internal_classes.Break_Node( # line_number=node.line_number - #) + # ) ]) return ast_internal_classes.If_Stmt_Node( @@ -678,8 +921,8 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_ line_number=node.line_number ) -class Any(LoopBasedReplacement): +class Any(LoopBasedReplacement): """ In this class, we implement the transformation for Fortran intrinsic ANY We support three ways of invoking the function - by providing array name, array subscript, @@ -701,13 +944,13 @@ class Any(LoopBasedReplacement): For (2), we reuse the provided binary operation. When the condition is true, we set the value to true and exit. 
""" + class Transformation(AnyAllCountTransformation): def _result_init_value(self): return "0" def _result_loop_update(self, node: ast_internal_classes.FNode): - return ast_internal_classes.BinOp_Node( lval=copy.deepcopy(node.lval), op="=", @@ -722,21 +965,21 @@ def _loop_condition(self): def func_name() -> str: return "__dace_any" -class All(LoopBasedReplacement): +class All(LoopBasedReplacement): """ In this class, we implement the transformation for Fortran intrinsic ALL. The implementation is very similar to ANY. The main difference is that we initialize the partial result to 1, and set it to 0 if any of the evaluated conditions is false. """ + class Transformation(AnyAllCountTransformation): def _result_init_value(self): return "1" def _result_loop_update(self, node: ast_internal_classes.FNode): - return ast_internal_classes.BinOp_Node( lval=copy.deepcopy(node.lval), op="=", @@ -754,8 +997,8 @@ def _loop_condition(self): def func_name() -> str: return "__dace_all" -class Count(LoopBasedReplacement): +class Count(LoopBasedReplacement): """ In this class, we implement the transformation for Fortran intrinsic COUNT. The implementation is very similar to ANY and ALL. @@ -764,13 +1007,13 @@ class Count(LoopBasedReplacement): We do not support the KIND argument. """ + class Transformation(AnyAllCountTransformation): def _result_init_value(self): return "0" def _result_loop_update(self, node: ast_internal_classes.FNode): - update = ast_internal_classes.BinOp_Node( lval=copy.deepcopy(node.lval), op="+", @@ -804,9 +1047,9 @@ def _update_result_type(self, var: ast_internal_classes.Name_Node): For both MINVAL and MAXVAL, the result type depends on the input variable. 
""" - input_type = self.scope_vars.get_var(var.parent, self.argument_variable.name.name) + input_type = self.get_var_declaration(var.parent, self.argument_variable) - var_decl = self.scope_vars.get_var(var.parent, var.name) + var_decl = self.get_var_declaration(var.parent, var) var.type = input_type.type var_decl.type = input_type.type @@ -814,6 +1057,10 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): for arg in node.args: + if isinstance(arg, ast_internal_classes.Data_Ref_Node): + self.rvals.append(arg) + continue + array_node = self._parse_array(node, arg) if array_node is not None: @@ -821,14 +1068,16 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): else: raise NotImplementedError("We do not support non-array arguments for MINVAL/MAXVAL") - def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, + new_func_body: List[ast_internal_classes.FNode]): if len(self.rvals) != 1: raise NotImplementedError("Only one array can be summed") self.argument_variable = self.rvals[0] - par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], self.count, new_func_body, + self.scope_vars, self.ast.structures, declaration=True) def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: @@ -860,18 +1109,19 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_ line_number=node.line_number ) -class MinVal(LoopBasedReplacement): +class MinVal(LoopBasedReplacement): """ In this class, we implement the transformation for Fortran intrinsic MINVAL. We do not support the MASK and DIM argument. 
""" + class Transformation(MinMaxValTransformation): def _result_init_value(self, array: ast_internal_classes.Array_Subscript_Node): - var_decl = self.scope_vars.get_var(array.parent, array.name.name) + var_decl = self.get_var_declaration(array.parent, array) # TODO: this should be used as a call to HUGE fortran_type = var_decl.type @@ -893,17 +1143,17 @@ def func_name() -> str: class MaxVal(LoopBasedReplacement): - """ In this class, we implement the transformation for Fortran intrinsic MAXVAL. We do not support the MASK and DIM argument. """ + class Transformation(MinMaxValTransformation): def _result_init_value(self, array: ast_internal_classes.Array_Subscript_Node): - var_decl = self.scope_vars.get_var(array.parent, array.name.name) + var_decl = self.get_var_declaration(array.parent, array) # TODO: this should be used as a call to HUGE fortran_type = var_decl.type @@ -923,8 +1173,8 @@ def _condition_op(self): def func_name() -> str: return "__dace_maxval" -class Merge(LoopBasedReplacement): +class Merge(LoopBasedReplacement): class Transformation(LoopBasedReplacementTransformation): def _initialize(self): @@ -957,15 +1207,37 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): # First argument is always an array self.first_array = self._parse_array(node, node.args[0]) - assert self.first_array is not None # Second argument is always an array self.second_array = self._parse_array(node, node.args[1]) - assert self.second_array is not None + + # weird overload of MERGE - passing two scalars + if self.first_array is None or self.second_array is None: + self.uses_scalars = True + self.first_array = node.args[0] + self.second_array = node.args[1] + self.mask_cond = node.args[2] + return + + else: + len_pardecls_first_array = 0 + len_pardecls_second_array = 0 + + for ind in self.first_array.indices: + pardecls = [i for i in mywalk(ind) if isinstance(i, ast_internal_classes.ParDecl_Node)] + len_pardecls_first_array += len(pardecls) + for ind 
in self.second_array.indices: + pardecls = [i for i in mywalk(ind) if isinstance(i, ast_internal_classes.ParDecl_Node)] + len_pardecls_second_array += len(pardecls) + assert len_pardecls_first_array == len_pardecls_second_array + if len_pardecls_first_array == 0: + self.uses_scalars = True + else: + self.uses_scalars = False # Last argument is either an array or a binary op arg = node.args[2] - array_node = self._parse_array(node, node.args[2]) + array_node = self._parse_array(node, node.args[2], dims_count=len(self.first_array.indices)) if array_node is not None: self.mask_first_array = array_node @@ -980,28 +1252,63 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): self.mask_first_array, self.mask_second_array, self.mask_cond = self._parse_binary_op(node, arg) - def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, + new_func_body: List[ast_internal_classes.FNode]): + + if self.uses_scalars: + self.destination_array = node.lval + return - self.destination_array = self._parse_array(exec_node, node.lval) # The first main argument is an array -> this dictates loop boundaries # Other arrays, regardless if they appear as the second array or mask, need to have the same loop boundary. 
- par_Decl_Range_Finder(self.first_array, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + par_Decl_Range_Finder(self.first_array, self.loop_ranges, [], self.count, new_func_body, + self.scope_vars, self.ast.structures, True, allow_scalars=True) loop_ranges = [] - par_Decl_Range_Finder(self.second_array, loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + par_Decl_Range_Finder(self.second_array, loop_ranges, [], self.count, new_func_body, + self.scope_vars, self.ast.structures, True, allow_scalars=True) self._adjust_array_ranges(node, self.second_array, self.loop_ranges, loop_ranges) - par_Decl_Range_Finder(self.destination_array, [], [], [], self.count, new_func_body, self.scope_vars, True) + # parse destination + + assert isinstance(node.lval, ast_internal_classes.Name_Node) + + array_decl = self.get_var_declaration(exec_node.parent, node.lval) + if array_decl.sizes is None: + + # for destination array, sizes might be unknown when we use arg extractor + # in that situation, we look at the size of the first argument + dims = len(self.first_array.indices) + else: + dims = len(array_decl.sizes) + self.destination_array = ast_internal_classes.Array_Subscript_Node( + name=node.lval, parent=node.lval.parent, type='VOID', + indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * dims + ) + + # type inference! this is necessary when the destination array is + # not known exactly, e.g., in recursive calls. 
+ if array_decl.sizes is None: + + first_input = self.get_var_declaration(node.parent, node.rval.args[0]) + array_decl.sizes = copy.deepcopy(first_input.sizes) + array_decl.offsets = [1] * len(array_decl.sizes) + array_decl.type = first_input.type + + par_Decl_Range_Finder(self.destination_array, [], [], self.count, + new_func_body, self.scope_vars, self.ast.structures, True) if self.mask_first_array is not None: loop_ranges = [] - par_Decl_Range_Finder(self.mask_first_array, loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + par_Decl_Range_Finder(self.mask_first_array, loop_ranges, [], self.count, new_func_body, + self.scope_vars, self.ast.structures, True, allow_scalars=True) self._adjust_array_ranges(node, self.mask_first_array, self.loop_ranges, loop_ranges) if self.mask_second_array is not None: loop_ranges = [] - par_Decl_Range_Finder(self.mask_second_array, loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + par_Decl_Range_Finder(self.mask_second_array, loop_ranges, [], self.count, new_func_body, + self.scope_vars, self.ast.structures, True, allow_scalars=True) self._adjust_array_ranges(node, self.mask_second_array, self.loop_ranges, loop_ranges) def _initialize_result(self, node: ast_internal_classes.FNode) -> Optional[ast_internal_classes.BinOp_Node]: @@ -1046,8 +1353,99 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_ line_number=node.line_number ) -class MathFunctions(IntrinsicTransformation): +class IntrinsicSDFGTransformation(xf.SingleStateTransformation): + array1 = xf.PatternNode(nodes.AccessNode) + array2 = xf.PatternNode(nodes.AccessNode) + tasklet = xf.PatternNode(nodes.Tasklet) + out = xf.PatternNode(nodes.AccessNode) + + def blas_dot(self, state: SDFGState, sdfg: SDFG): + dot_libnode(None, sdfg, state, self.array1.data, self.array2.data, self.out.data) + + def blas_matmul(self, state: SDFGState, sdfg: SDFG): + gemm_libnode( + None, + sdfg, + state, + self.array1.data, 
+ self.array2.data, + self.out.data, + 1.0, + 0.0, + False, + False + ) + + def transpose(self, state: SDFGState, sdfg: SDFG): + + input_arr = state.add_read(self.array1.data) + res = state.add_write(self.out.data) + + libnode = Transpose("transpose", dtype=sdfg.arrays[self.array1.data].dtype) + state.add_node(libnode) + + state.add_edge(input_arr, None, libnode, "_inp", sdfg.make_array_memlet(self.array1.data)) + state.add_edge(libnode, "_out", res, None, sdfg.make_array_memlet(self.out.data)) + + LIBRARY_NODE_TRANSFORMATIONS = { + "__dace_blas_dot": blas_dot, + "__dace_transpose": transpose, + "__dace_matmul": blas_matmul + } + + @classmethod + def expressions(cls): + + graphs = [] + + # Match tasklets with two inputs, like dot + g = OrderedDiGraph() + g.add_node(cls.array1) + g.add_node(cls.array2) + g.add_node(cls.tasklet) + g.add_node(cls.out) + g.add_edge(cls.array1, cls.tasklet, None) + g.add_edge(cls.array2, cls.tasklet, None) + g.add_edge(cls.tasklet, cls.out, None) + graphs.append(g) + + # Match tasklets with one input, like transpose + g = OrderedDiGraph() + g.add_node(cls.array1) + g.add_node(cls.tasklet) + g.add_node(cls.out) + g.add_edge(cls.array1, cls.tasklet, None) + g.add_edge(cls.tasklet, cls.out, None) + graphs.append(g) + + return graphs + + def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + + import ast + for node in ast.walk(self.tasklet.code.code[0]): + if isinstance(node, ast.Call): + if node.func.id in self.LIBRARY_NODE_TRANSFORMATIONS: + self.func = self.LIBRARY_NODE_TRANSFORMATIONS[node.func.id] + return True + + return False + + def apply(self, state: SDFGState, sdfg: SDFG): + + self.func(self, state, sdfg) + + for in_edge in state.in_edges(self.tasklet): + state.remove_memlet_path(in_edge) + + for in_edge in state.out_edges(self.tasklet): + state.remove_memlet_path(in_edge) + + state.remove_node(self.tasklet) + + +class MathFunctions(IntrinsicTransformation): MathTransformation 
= namedtuple("MathTransformation", "function return_type") MathReplacement = namedtuple("MathReplacement", "function replacement_function return_type") @@ -1065,7 +1463,8 @@ def generate_scale(arg: ast_internal_classes.Call_Expr_Node): name=ast_internal_classes.Name_Node(name="pow"), type="INTEGER", args=[const_two, i], - line_number=line + line_number=line, + subroutine=False, ) mult = ast_internal_classes.BinOp_Node( @@ -1078,6 +1477,10 @@ def generate_scale(arg: ast_internal_classes.Call_Expr_Node): # pack it into parentheses, just to be sure return ast_internal_classes.Parenthesis_Expr_Node(expr=mult) + def generate_epsilon(arg: ast_internal_classes.Call_Expr_Node): + ret_val = sys.float_info.epsilon + return ast_internal_classes.Real_Literal_Node(value=str(ret_val)) + def generate_aint(arg: ast_internal_classes.Call_Expr_Node): # The call to AINT can contain a second KIND parameter @@ -1098,12 +1501,21 @@ def generate_aint(arg: ast_internal_classes.Call_Expr_Node): return arg + @staticmethod + def _initialize_transformations(): + # dictionary comprehension cannot access class members + ret = {} + for name, value in IntrinsicSDFGTransformation.INTRINSIC_TRANSFORMATIONS.items(): + ret[name] = MathFunctions.MathTransformation(value, "FIRST_ARG") + return ret + INTRINSIC_TO_DACE = { "MIN": MathTransformation("min", "FIRST_ARG"), "MAX": MathTransformation("max", "FIRST_ARG"), "SQRT": MathTransformation("sqrt", "FIRST_ARG"), "ABS": MathTransformation("abs", "FIRST_ARG"), "EXP": MathTransformation("exp", "FIRST_ARG"), + "EPSILON": MathReplacement(None, generate_epsilon, "FIRST_ARG"), # Documentation states that the return type of LOG is always REAL, # but the kind is the same as of the first argument. # However, we already replaced kind with types used in DaCe. 
@@ -1139,27 +1551,34 @@ def generate_aint(arg: ast_internal_classes.Call_Expr_Node): "ASIN": MathTransformation("asin", "FIRST_ARG"), "ACOS": MathTransformation("acos", "FIRST_ARG"), "ATAN": MathTransformation("atan", "FIRST_ARG"), - "ATAN2": MathTransformation("atan2", "FIRST_ARG") + "ATAN2": MathTransformation("atan2", "FIRST_ARG"), + "DOT_PRODUCT": MathTransformation("__dace_blas_dot", "FIRST_ARG"), + "TRANSPOSE": MathTransformation("__dace_transpose", "FIRST_ARG"), + "MATMUL": MathTransformation("__dace_matmul", "FIRST_ARG"), + "IBSET": MathTransformation("bitwise_set", "INTEGER"), + "IEOR": MathTransformation("bitwise_xor", "INTEGER"), + "ISHFT": MathTransformation("bitwise_shift", "INTEGER"), + "IBCLR": MathTransformation("bitwise_clear", "INTEGER"), + "BTEST": MathTransformation("bitwise_test", "INTEGER"), + "IBITS": MathTransformation("bitwise_extract", "INTEGER"), + "IAND": MathTransformation("bitwise_and", "INTEGER") } class TypeTransformer(IntrinsicNodeTransformer): def func_type(self, node: ast_internal_classes.Call_Expr_Node): - # take the first arg arg = node.args[0] - if isinstance(arg, ast_internal_classes.Real_Literal_Node): - return 'REAL' - elif isinstance(arg, ast_internal_classes.Int_Literal_Node): - return 'INTEGER' - elif isinstance(arg, ast_internal_classes.Call_Expr_Node): + if isinstance(arg, (ast_internal_classes.Real_Literal_Node, ast_internal_classes.Double_Literal_Node, + ast_internal_classes.Int_Literal_Node, ast_internal_classes.Call_Expr_Node, + ast_internal_classes.BinOp_Node, ast_internal_classes.UnOp_Node)): return arg.type - elif isinstance(arg, ast_internal_classes.Name_Node): - input_type = self.scope_vars.get_var(node.parent, arg.name) - return input_type.type + elif isinstance(arg, (ast_internal_classes.Name_Node, ast_internal_classes.Array_Subscript_Node)): + return self.get_var_declaration(node.parent, arg).type + elif isinstance(arg, ast_internal_classes.Data_Ref_Node): + return self._parse_struct_ref(arg).type else: - 
input_type = self.scope_vars.get_var(node.parent, arg.name.name) - return input_type.type + raise NotImplementedError(type(arg)) def replace_call(self, old_call: ast_internal_classes.Call_Expr_Node, new_call: ast_internal_classes.FNode): @@ -1185,7 +1604,7 @@ def replace_call(self, old_call: ast_internal_classes.Call_Expr_Node, new_call: raise NotImplementedError() def visit_BinOp_Node(self, binop_node: ast_internal_classes.BinOp_Node): - + if not isinstance(binop_node.rval, ast_internal_classes.Call_Expr_Node): return binop_node @@ -1198,15 +1617,19 @@ def visit_BinOp_Node(self, binop_node: ast_internal_classes.BinOp_Node): # Visit all children before we expand this call. # We need that to properly get the type. + new_args = [] for arg in node.args: - self.visit(arg) + new_args.append(self.visit(arg)) + node.args = new_args - return_type = None - input_type = None input_type = self.func_type(node) + if input_type == 'VOID': + #assert input_type != 'VOID', f"Unexpected void input at line number: {node.line_number}" + raise NeedsTypeInferenceException(func_name, node.line_number) replacement_rule = MathFunctions.INTRINSIC_TO_DACE[func_name] if isinstance(replacement_rule, dict): + replacement_rule = replacement_rule[input_type] if replacement_rule.return_type == "FIRST_ARG": return_type = input_type @@ -1216,20 +1639,16 @@ def visit_BinOp_Node(self, binop_node: ast_internal_classes.BinOp_Node): if isinstance(replacement_rule, MathFunctions.MathTransformation): node.name = ast_internal_classes.Name_Node(name=replacement_rule.function) node.type = return_type - else: binop_node.rval = replacement_rule.replacement_function(node) # replace types of return variable - LHS of the binary operator var = binop_node.lval - name = None - if isinstance(var.name, ast_internal_classes.Name_Node): - name = var.name.name - else: - name = var.name - var_decl = self.scope_vars.get_var(var.parent, name) - var.type = input_type - var_decl.type = input_type + if isinstance(var, 
(ast_internal_classes.Name_Node, ast_internal_classes.Data_Ref_Node, + ast_internal_classes.Array_Subscript_Node)): + var_decl = self.get_var_declaration(var.parent, var) + var.type = return_type + var_decl.type = return_type return binop_node @@ -1272,8 +1691,8 @@ def has_transformation() -> bool: def get_transformation() -> TypeTransformer: return MathFunctions.TypeTransformer() -class FortranIntrinsics: +class FortranIntrinsics: IMPLEMENTATIONS_AST = { "SUM": Sum, "PRODUCT": Product, @@ -1282,24 +1701,28 @@ class FortranIntrinsics: "ALL": All, "MINVAL": MinVal, "MAXVAL": MaxVal, - "MERGE": Merge + "MERGE": Merge, } + # All functions return an array + # Our call extraction transformation only supports scalars EXEMPTED_FROM_CALL_EXTRACTION = [ - Merge + "__dace_TRANSPOSE", + "__dace_MATMUL", ] def __init__(self): - self._transformations_to_run = set() + self._transformations_to_run = {} - def transformations(self) -> Set[Type[NodeTransformer]]: - return self._transformations_to_run + def transformations(self) -> List[NodeTransformer]: + return list(self._transformations_to_run.values()) @staticmethod def function_names() -> List[str]: # list of all functions that are created by initial transformation, before doing full replacement # this prevents other parser components from replacing our function calls with array subscription nodes - return [*list(LoopBasedReplacement.INTRINSIC_TO_DACE.values()), *MathFunctions.temporary_functions(), *DirectReplacement.temporary_functions()] + return [*list(LoopBasedReplacement.INTRINSIC_TO_DACE.values()), *MathFunctions.temporary_functions(), + *DirectReplacement.temporary_functions()] @staticmethod def retained_function_names() -> List[str]: @@ -1308,37 +1731,64 @@ def retained_function_names() -> List[str]: @staticmethod def call_extraction_exemptions() -> List[str]: - return [ - *[func.Transformation.func_name() for func in FortranIntrinsics.EXEMPTED_FROM_CALL_EXTRACTION] - #*MathFunctions.temporary_functions() - ] + return 
FortranIntrinsics.EXEMPTED_FROM_CALL_EXTRACTION def replace_function_name(self, node: FASTNode) -> ast_internal_classes.Name_Node: func_name = node.string replacements = { "SIGN": "__dace_sign", + # TODO implement and categorize the intrinsic functions below + "SPREAD": "__dace_spread", + "TRIM": "__dace_trim", + "LEN_TRIM": "__dace_len_trim", + "ASSOCIATED": "__dace_associated", + "MAXLOC": "__dace_maxloc", + "FRACTION": "__dace_fraction", + "NEW_LINE": "__dace_new_line", + "PRECISION": "__dace_precision", + "MINLOC": "__dace_minloc", + "LEN": "__dace_len", + "SCAN": "__dace_scan", + "RANDOM_SEED": "__dace_random_seed", + "RANDOM_NUMBER": "__dace_random_number", + "DATE_AND_TIME": "__dace_date_and_time", + "RESHAPE": "__dace_reshape", } + if func_name in replacements: return ast_internal_classes.Name_Node(name=replacements[func_name]) elif DirectReplacement.replacable_name(func_name): + if DirectReplacement.has_transformation(func_name): - self._transformations_to_run.add(DirectReplacement.get_transformation()) + # self._transformations_to_run.add(DirectReplacement.get_transformation()) + transformation = DirectReplacement.get_transformation() + if transformation.func_name() not in self._transformations_to_run: + self._transformations_to_run[transformation.func_name()] = transformation + return DirectReplacement.replace_name(func_name) elif MathFunctions.replacable(func_name): - self._transformations_to_run.add(MathFunctions.get_transformation()) + + transformation = MathFunctions.get_transformation() + if transformation.func_name() not in self._transformations_to_run: + self._transformations_to_run[transformation.func_name()] = transformation + return MathFunctions.replace(func_name) if self.IMPLEMENTATIONS_AST[func_name].has_transformation(): if hasattr(self.IMPLEMENTATIONS_AST[func_name], "Transformation"): - self._transformations_to_run.add(self.IMPLEMENTATIONS_AST[func_name].Transformation()) + transformation = 
self.IMPLEMENTATIONS_AST[func_name].Transformation() else: - self._transformations_to_run.add(self.IMPLEMENTATIONS_AST[func_name].get_transformation(func_name)) + transformation = self.IMPLEMENTATIONS_AST[func_name].get_transformation(func_name) + + if transformation.func_name() not in self._transformations_to_run: + self._transformations_to_run[transformation.func_name()] = transformation return ast_internal_classes.Name_Node(name=self.IMPLEMENTATIONS_AST[func_name].replaced_name(func_name)) - def replace_function_reference(self, name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line): + def replace_function_reference(self, name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, + line, symbols: dict): func_types = { "__dace_sign": "DOUBLE", @@ -1346,13 +1796,13 @@ def replace_function_reference(self, name: ast_internal_classes.Name_Node, args: if name.name in func_types: # FIXME: this will be progressively removed call_type = func_types[name.name] - return ast_internal_classes.Call_Expr_Node(name=name, type=call_type, args=args.args, line_number=line) + return ast_internal_classes.Call_Expr_Node(name=name, type=call_type, args=args.args, line_number=line,subroutine=False) elif DirectReplacement.replacable(name.name): - return DirectReplacement.replace(name.name, args, line) + return DirectReplacement.replace(name.name, args, line, symbols) else: # We will do the actual type replacement later # To that end, we need to know the input types - but these we do not know at the moment. 
return ast_internal_classes.Call_Expr_Node( - name=name, type="VOID", + name=name, type="VOID", subroutine=False, args=args.args, line_number=line ) From 2e4d09f61528ab08a55a8874ac097541791c177e Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 20 Dec 2024 08:18:13 +0100 Subject: [PATCH 02/12] Import fortran tests --- tests/fortran/advanced_optional_args_test.py | 92 ++ tests/fortran/allocate_test.py | 22 +- tests/fortran/arg_extract_test.py | 165 ++ tests/fortran/array_attributes_test.py | 319 +++- tests/fortran/array_test.py | 213 ++- ...offset.py => array_to_loop_offset_test.py} | 2 + tests/fortran/ast_desugaring_test.py | 1391 +++++++++++++++++ tests/fortran/call_extract_test.py | 6 +- tests/fortran/cond_type_test.py | 67 + tests/fortran/create_internal_ast_test.py | 282 ++++ tests/fortran/empty_test.py | 46 + tests/fortran/fortran_language_test.py | 234 ++- tests/fortran/fortran_loops_test.py | 45 - tests/fortran/fortran_test_helper.py | 293 ++++ tests/fortran/future/fortran_class_test.py | 117 ++ tests/fortran/global_test.py | 124 ++ tests/fortran/ifcycle_test.py | 107 ++ tests/fortran/init_test.py | 113 ++ tests/fortran/intrinsic_all_test.py | 18 +- tests/fortran/intrinsic_any_test.py | 8 +- tests/fortran/intrinsic_basic_test.py | 292 +++- tests/fortran/intrinsic_blas_test.py | 174 +++ tests/fortran/intrinsic_bound_test.py | 429 +++++ tests/fortran/intrinsic_count_test.py | 16 +- tests/fortran/intrinsic_math_test.py | 116 +- tests/fortran/intrinsic_merge_test.py | 240 +++ tests/fortran/intrinsic_minmaxval_test.py | 59 +- tests/fortran/long_tasklet_test.py | 54 + tests/fortran/missing_func_test.py | 146 ++ tests/fortran/multisdfg_construction_test.py | 161 ++ tests/fortran/nested_array_test.py | 101 ++ .../non-interactive/fortran_int_init_test.py | 66 + .../fortran/non-interactive/function_test.py | 409 +++++ .../fortran/non-interactive/pointers_test.py | 81 + .../{ => non-interactive}/view_test.py | 55 +- tests/fortran/offset_normalizer_test.py | 273 
+++- tests/fortran/optional_args_test.py | 116 ++ tests/fortran/parent_test.py | 107 +- tests/fortran/pointer_removal_test.py | 208 +++ tests/fortran/prune_test.py | 147 ++ tests/fortran/prune_unused_children_test.py | 769 +++++++++ tests/fortran/ranges_test.py | 535 +++++++ tests/fortran/recursive_ast_improver_test.py | 731 +++++++++ tests/fortran/rename_test.py | 70 + tests/fortran/scope_arrays_test.py | 2 +- tests/fortran/struct_test.py | 115 ++ tests/fortran/tasklet_test.py | 47 + tests/fortran/type_array_test.py | 224 +++ tests/fortran/type_test.py | 658 ++++++++ tests/fortran/while_test.py | 45 + 50 files changed, 9587 insertions(+), 523 deletions(-) create mode 100644 tests/fortran/advanced_optional_args_test.py create mode 100644 tests/fortran/arg_extract_test.py rename tests/fortran/{array_to_loop_offset.py => array_to_loop_offset_test.py} (99%) create mode 100644 tests/fortran/ast_desugaring_test.py create mode 100644 tests/fortran/cond_type_test.py create mode 100644 tests/fortran/create_internal_ast_test.py create mode 100644 tests/fortran/empty_test.py delete mode 100644 tests/fortran/fortran_loops_test.py create mode 100644 tests/fortran/fortran_test_helper.py create mode 100644 tests/fortran/future/fortran_class_test.py create mode 100644 tests/fortran/global_test.py create mode 100644 tests/fortran/ifcycle_test.py create mode 100644 tests/fortran/init_test.py create mode 100644 tests/fortran/intrinsic_blas_test.py create mode 100644 tests/fortran/intrinsic_bound_test.py create mode 100644 tests/fortran/long_tasklet_test.py create mode 100644 tests/fortran/missing_func_test.py create mode 100644 tests/fortran/multisdfg_construction_test.py create mode 100644 tests/fortran/nested_array_test.py create mode 100644 tests/fortran/non-interactive/fortran_int_init_test.py create mode 100644 tests/fortran/non-interactive/function_test.py create mode 100644 tests/fortran/non-interactive/pointers_test.py rename tests/fortran/{ => non-interactive}/view_test.py 
(69%) create mode 100644 tests/fortran/optional_args_test.py create mode 100644 tests/fortran/pointer_removal_test.py create mode 100644 tests/fortran/prune_test.py create mode 100644 tests/fortran/prune_unused_children_test.py create mode 100644 tests/fortran/ranges_test.py create mode 100644 tests/fortran/recursive_ast_improver_test.py create mode 100644 tests/fortran/rename_test.py create mode 100644 tests/fortran/struct_test.py create mode 100644 tests/fortran/tasklet_test.py create mode 100644 tests/fortran/type_array_test.py create mode 100644 tests/fortran/type_test.py create mode 100644 tests/fortran/while_test.py diff --git a/tests/fortran/advanced_optional_args_test.py b/tests/fortran/advanced_optional_args_test.py new file mode 100644 index 0000000000..860eeb74dd --- /dev/null +++ b/tests/fortran/advanced_optional_args_test.py @@ -0,0 +1,92 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + +def test_fortran_frontend_optional_adv(): + test_string = """ + PROGRAM adv_intrinsic_optional_test_function + implicit none + integer, dimension(4) :: res + integer, dimension(4) :: res2 + integer :: a + CALL intrinsic_optional_test_function(res, res2, a) + end + + SUBROUTINE intrinsic_optional_test_function(res, res2, a) + integer, dimension(4) :: res + integer, dimension(4) :: res2 + integer :: a + integer,dimension(2) :: ret + + CALL intrinsic_optional_test_function2(res, a) + CALL intrinsic_optional_test_function2(res2) + CALL get_indices_c(1, 1, 1, ret(1), ret(2), 1, 2) + + END SUBROUTINE intrinsic_optional_test_function + + SUBROUTINE intrinsic_optional_test_function2(res, a) + integer, dimension(2) :: res + integer, optional :: a + + res(1) = a + + END SUBROUTINE intrinsic_optional_test_function2 + + SUBROUTINE get_indices_c(i_blk, i_startblk, i_endblk, i_startidx, & + i_endidx, irl_start, opt_rl_end) + + + INTEGER, INTENT(IN) :: i_blk ! 
Current block (variable jb in do loops) + INTEGER, INTENT(IN) :: i_startblk ! Start block of do loop + INTEGER, INTENT(IN) :: i_endblk ! End block of do loop + INTEGER, INTENT(IN) :: irl_start ! refin_ctrl level where do loop starts + + INTEGER, OPTIONAL, INTENT(IN) :: opt_rl_end ! refin_ctrl level where do loop ends + + INTEGER, INTENT(OUT) :: i_startidx, i_endidx ! Start and end indices (jc loop) + + ! Local variables + + INTEGER :: irl_end + + IF (PRESENT(opt_rl_end)) THEN + irl_end = opt_rl_end + ELSE + irl_end = 42 + ENDIF + + IF (i_blk == i_startblk) THEN + i_startidx = 1 + i_endidx = 42 + IF (i_blk == i_endblk) i_endidx = irl_end + ELSE IF (i_blk == i_endblk) THEN + i_startidx = 1 + i_endidx = irl_end + ELSE + i_startidx = 1 + i_endidx = 42 + ENDIF + +END SUBROUTINE get_indices_c + + """ + sources={} + sources["adv_intrinsic_optional_test_function"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_optional_test_function", True,sources=sources) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + res2 = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res, res2=res2, a=5) + + assert res[0] == 5 + assert res2[0] == 0 + +if __name__ == "__main__": + + test_fortran_frontend_optional_adv() diff --git a/tests/fortran/allocate_test.py b/tests/fortran/allocate_test.py index 498c97d932..aecea60269 100644 --- a/tests/fortran/allocate_test.py +++ b/tests/fortran/allocate_test.py @@ -1,23 +1,12 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
-from fparser.common.readfortran import FortranStringReader -from fparser.common.readfortran import FortranFileReader -from fparser.two.parser import ParserFactory -import sys, os import numpy as np import pytest -from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic from dace.frontend.fortran import fortran_parser -from fparser.two.symbol_table import SymbolTable -from dace.sdfg import utils as sdutil - -import dace.frontend.fortran.ast_components as ast_components -import dace.frontend.fortran.ast_transforms as ast_transforms -import dace.frontend.fortran.ast_utils as ast_utils -import dace.frontend.fortran.ast_internal_classes as ast_internal_classes +@pytest.mark.skip(reason="This requires Deferred Allocation support on DaCe, which we do not have yet.") def test_fortran_frontend_basic_allocate(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. @@ -39,13 +28,12 @@ def test_fortran_frontend_basic_allocate(): """ sdfg = fortran_parser.create_sdfg_from_string(test_string, "allocate_test") sdfg.simplify(verbose=True) - a = np.full([4,5], 42, order="F", dtype=np.float64) + a = np.full([4, 5], 42, order="F", dtype=np.float64) sdfg(d=a) - assert (a[0,0] == 42) - assert (a[1,0] == 5.5) - assert (a[2,0] == 42) + assert (a[0, 0] == 42) + assert (a[1, 0] == 5.5) + assert (a[2, 0] == 42) if __name__ == "__main__": - test_fortran_frontend_basic_allocate() diff --git a/tests/fortran/arg_extract_test.py b/tests/fortran/arg_extract_test.py new file mode 100644 index 0000000000..f0085d1f1a --- /dev/null +++ b/tests/fortran/arg_extract_test.py @@ -0,0 +1,165 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + +def test_fortran_frontend_arg_extract(): + test_string = """ + PROGRAM arg_extract + implicit none + real, dimension(2) :: d + real, dimension(2) :: res + CALL arg_extract_test_function(d,res) + end + + SUBROUTINE arg_extract_test_function(d,res) + real, dimension(2) :: d + real, dimension(2) :: res + + if (MIN(d(1),1) .EQ. 1 ) then + res(1) = 3 + res(2) = 7 + else + res(1) = 5 + res(2) = 10 + endif + + END SUBROUTINE arg_extract_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "arg_extract", normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + + + input = np.full([2], 42, order="F", dtype=np.float32) + res = np.full([2], 42, order="F", dtype=np.float32) + sdfg(d=input, res=res) + assert np.allclose(res, [3,7]) + + +def test_fortran_frontend_arg_extract2(): + test_string = """ + PROGRAM arg_extract2 + implicit none + real, dimension(2) :: d + real, dimension(2) :: res + CALL arg_extract2_test_function(d,res) + end + + SUBROUTINE arg_extract2_test_function(d,res) + real, dimension(2) :: d + real, dimension(2) :: res + + if (ALLOCATED(res)) then + res(1) = 3 + res(2) = 7 + else + res(1) = 5 + res(2) = 10 + endif + + END SUBROUTINE arg_extract2_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "arg_extract2", normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + + + input = np.full([2], 42, order="F", dtype=np.float32) + res = np.full([2], 42, order="F", dtype=np.float32) + sdfg(d=input, res=res) + assert np.allclose(res, [3,7]) + + +def test_fortran_frontend_arg_extract3(): + test_string = """ + PROGRAM arg_extract3 + implicit none + real, dimension(2) :: d + real, dimension(2) :: res + CALL arg_extract3_test_function(d,res) + end + + SUBROUTINE arg_extract3_test_function(d,res) + real, dimension(2) :: d + real, dimension(2) :: res + + integer :: jg + logical, 
dimension(2) :: is_cloud + + jg = 1 + is_cloud(1) = .true. + d(1)=10 + d(2)=20 + res(1) = MERGE(MERGE(d(1), d(2), d(1) < d(2) .AND. is_cloud(jg)), 0.0D0, is_cloud(jg)) + res(2) = 52 + + END SUBROUTINE arg_extract3_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "arg_extract3", normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + + + input = np.full([2], 42, order="F", dtype=np.float32) + res = np.full([2], 42, order="F", dtype=np.float32) + sdfg(d=input, res=res) + assert np.allclose(res, [10,52]) + + +def test_fortran_frontend_arg_extract4(): + test_string = """ + PROGRAM arg_extract4 + implicit none + real, dimension(2) :: d + real, dimension(2) :: res + CALL arg_extract4_test_function(d,res) + end + + SUBROUTINE arg_extract4_test_function(d,res) + real, dimension(2) :: d + real, dimension(2) :: res + real :: merge_val + real :: merge_val2 + + integer :: jg + logical, dimension(2) :: is_cloud + + jg = 1 + is_cloud(1) = .true. + d(1)=10 + d(2)=20 + merge_val = MERGE(d(1), d(2), d(1) < d(2) .AND. 
is_cloud(jg)) + merge_val2 = MERGE(merge_val, 0.0D0, is_cloud(jg)) + res(1)=merge_val2 + res(2) = 52 + + END SUBROUTINE arg_extract4_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "arg_extract4", normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + + + input = np.full([2], 42, order="F", dtype=np.float32) + res = np.full([2], 42, order="F", dtype=np.float32) + sdfg(d=input, res=res) + assert np.allclose(res, [10,52]) + +if __name__ == "__main__": + + #test_fortran_frontend_arg_extract() + #test_fortran_frontend_arg_extract2() + test_fortran_frontend_arg_extract3() + test_fortran_frontend_arg_extract4() + diff --git a/tests/fortran/array_attributes_test.py b/tests/fortran/array_attributes_test.py index af433905bc..115946d703 100644 --- a/tests/fortran/array_attributes_test.py +++ b/tests/fortran/array_attributes_test.py @@ -2,29 +2,22 @@ import numpy as np -from dace.frontend.fortran import fortran_parser +from tests.fortran.fortran_test_helper import SourceCodeBuilder, create_singular_sdfg_from_string + def test_fortran_frontend_array_attribute_no_offset(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
""" - test_string = """ - PROGRAM index_offset_test - implicit none - double precision, dimension(5) :: d - CALL index_test_function(d) - end - - SUBROUTINE index_test_function(d) - double precision, dimension(5) :: d - - do i=1,5 - d(i) = i * 2.0 - end do - - END SUBROUTINE index_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + double precision, dimension(5) :: d + do i = 1, 5 + d(i) = i*2.0 + end do +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', False) sdfg.simplify(verbose=True) sdfg.compile() @@ -35,31 +28,55 @@ def test_fortran_frontend_array_attribute_no_offset(): a = np.full([5], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(1,5): + for i in range(1, 5): + # offset -1 is already added + assert a[i - 1] == i * 2 + + +def test_fortran_frontend_array_attribute_no_offset_symbol(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d, arrsize) + integer :: arrsize + double precision, dimension(arrsize) :: d + do i = 1, arrsize + d(i) = i*2.0 + end do +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', False) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 1 + from dace.symbolic import symbol + assert isinstance(sdfg.data('d').shape[0], symbol) + assert len(sdfg.data('d').offset) == 1 + assert sdfg.data('d').offset[0] == -1 + + size = 10 + a = np.full([size], 42, order="F", dtype=np.float64) + sdfg(d=a, arrsize=size) + for i in range(1, size): # offset -1 is already added - assert a[i-1] == i * 2 + assert a[i - 1] == i * 2 + def test_fortran_frontend_array_attribute_offset(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. """ - test_string = """ - PROGRAM index_offset_test - implicit none - double precision, dimension(50:54) :: d - CALL index_test_function(d) - end - - SUBROUTINE index_test_function(d) - double precision, dimension(50:54) :: d - - do i=50,54 - d(i) = i * 2.0 - end do - - END SUBROUTINE index_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + double precision, dimension(50:54) :: d + do i = 50, 54 + d(i) = i*2.0 + end do +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', False) sdfg.simplify(verbose=True) sdfg.compile() @@ -70,31 +87,89 @@ def test_fortran_frontend_array_attribute_offset(): a = np.full([60], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(50,54): + for i in range(50, 54): # offset -1 is already added - assert a[i-1] == i * 2 + assert a[i - 1] == i * 2 + + +def test_fortran_frontend_array_attribute_offset_symbol(): + """ + Tests that the 
Fortran frontend can parse array accesses and that the accessed indices are correct. + """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d, arrsize) + integer :: arrsize + double precision, dimension(arrsize:arrsize + 4) :: d + do i = arrsize, arrsize + 4 + d(i) = i*2.0 + end do +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', False) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 1 + assert sdfg.data('d').shape[0] == 5 + assert len(sdfg.data('d').offset) == 1 + assert sdfg.data('d').offset[0] == -1 + + arrsize = 50 + a = np.full([arrsize + 10], 42, order="F", dtype=np.float64) + sdfg(d=a, arrsize=arrsize) + for i in range(arrsize, arrsize + 4): + # offset -1 is already added + assert a[i - 1] == i * 2 + + +def test_fortran_frontend_array_attribute_offset_symbol2(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ Compared to the previous one, this one should prevent simplification from removing symbols + """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d, arrsize, arrsize2) + integer :: arrsize + integer :: arrsize2 + double precision, dimension(arrsize:arrsize2) :: d + do i = arrsize, arrsize2 + d(i) = i*2.0 + end do +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', False) + sdfg.simplify(verbose=True) + sdfg.compile() + + from dace.symbolic import evaluate + + arrsize = 50 + arrsize2 = 54 + assert len(sdfg.data('d').shape) == 1 + assert evaluate(sdfg.data('d').shape[0], {'arrsize': arrsize, 'arrsize2': arrsize2}) == 5 + assert len(sdfg.data('d').offset) == 1 + assert sdfg.data('d').offset[0] == -1 + + a = np.full([arrsize + 10], 42, order="F", dtype=np.float64) + sdfg(d=a, arrsize=arrsize, arrsize2=arrsize2) + for i in range(arrsize, arrsize2): + # offset -1 is already added + assert a[i - 1] == i * 2 + def test_fortran_frontend_array_offset(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
""" - test_string = """ - PROGRAM index_offset_test - implicit none - double precision d(50:54) - CALL index_test_function(d) - end - - SUBROUTINE index_test_function(d) - double precision d(50:54) - - do i=50,54 - d(i) = i * 2.0 - end do - - END SUBROUTINE index_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + double precision d(50:54) + do i = 50, 54 + d(i) = i*2.0 + end do +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', False) sdfg.simplify(verbose=True) sdfg.compile() @@ -105,13 +180,139 @@ def test_fortran_frontend_array_offset(): a = np.full([60], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(50,54): + for i in range(50, 54): # offset -1 is already added - assert a[i-1] == i * 2 + assert a[i - 1] == i * 2 -if __name__ == "__main__": +def test_fortran_frontend_array_offset_symbol(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ Compared to the previous one, this one should prevent simplification from removing symbols + """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d, arrsize, arrsize2) + integer :: arrsize + integer :: arrsize2 + double precision :: d(arrsize:arrsize2) + do i = arrsize, arrsize2 + d(i) = i*2.0 + end do +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', False) + sdfg.simplify(verbose=True) + sdfg.compile() + + from dace.symbolic import evaluate + + arrsize = 50 + arrsize2 = 54 + assert len(sdfg.data('d').shape) == 1 + assert evaluate(sdfg.data('d').shape[0], {'arrsize': arrsize, 'arrsize2': arrsize2}) == 5 + assert len(sdfg.data('d').offset) == 1 + assert sdfg.data('d').offset[0] == -1 + + a = np.full([arrsize + 10], 42, order="F", dtype=np.float64) + sdfg(d=a, arrsize=arrsize, arrsize2=arrsize2) + for i in range(arrsize, arrsize2): + # offset -1 is already added + assert a[i - 1] == i * 2 + +def test_fortran_frontend_array_arbitrary(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d, arrsize, arrsize2) + integer :: arrsize + integer :: arrsize2 + double precision :: d(:, :) + do i = 1, arrsize + d(i, 1) = i*2.0 + end do +end subroutine main +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', normalize_offsets=False) + sdfg.simplify(verbose=True) + sdfg.compile() + + arrsize = 5 + arrsize2 = 10 + a = np.full([arrsize, arrsize2], 42, order="F", dtype=np.float64) + sdfg(d=a, __f2dace_A_d_d_0_s_0=arrsize, arrsize=arrsize, arrsize2=arrsize2) + for i in range(arrsize): + # offset -1 is already added + assert a[i, 0] == (i + 1) * 2 + + +def test_fortran_frontend_array_arbitrary_attribute(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d, arrsize, arrsize2) + integer :: arrsize + integer :: arrsize2 + double precision, dimension(:, :) :: d + do i = 1, arrsize + d(i, 1) = i*2.0 + end do +end 
subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', normalize_offsets=False) + sdfg.simplify(verbose=True) + sdfg.compile() + + arrsize = 5 + arrsize2 = 10 + a = np.full([arrsize, arrsize2], 42, order="F", dtype=np.float64) + sdfg(d=a, __f2dace_A_d_d_0_s_0=arrsize, arrsize=arrsize, arrsize2=arrsize2) + for i in range(arrsize): + # offset -1 is already added + assert a[i, 0] == (i + 1) * 2 + + +def test_fortran_frontend_array_arbitrary_attribute2(): + sources, main = SourceCodeBuilder().add_file(""" +module lib +contains + subroutine main(d, d2) + double precision, dimension(:, :) :: d, d2 + call other(d, d2) + end subroutine main + + subroutine other(d, d2) + double precision, dimension(:, :) :: d, d2 + d(1, 1) = size(d, 1) + d(1, 2) = size(d, 2) + d(1, 3) = size(d2, 1) + d(1, 4) = size(d2, 2) + end subroutine other +end module lib +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'lib.main', normalize_offsets=False) + sdfg.simplify(verbose=True) + sdfg.compile() + + arrsize = 5 + arrsize2 = 10 + arrsize3 = 3 + arrsize4 = 7 + a = np.full([arrsize, arrsize2], 42, order="F", dtype=np.float64) + b = np.full([arrsize3, arrsize4], 42, order="F", dtype=np.float64) + sdfg(d=a, __f2dace_A_d_d_0_s_0=arrsize, __f2dace_A_d_d_1_s_1=arrsize2, + d2=b, __f2dace_A_d2_d_0_s_2=arrsize3, __f2dace_A_d2_d_1_s_3=arrsize4, + arrsize=arrsize, arrsize2=arrsize2, arrsize3=arrsize3, arrsize4=arrsize4) + assert a[1, 1] == arrsize + assert a[1, 2] == arrsize2 + assert a[1, 3] == arrsize3 + assert a[1, 4] == arrsize4 + + +if __name__ == "__main__": test_fortran_frontend_array_offset() test_fortran_frontend_array_attribute_no_offset() test_fortran_frontend_array_attribute_offset() + test_fortran_frontend_array_attribute_no_offset_symbol() + test_fortran_frontend_array_attribute_offset_symbol() + test_fortran_frontend_array_attribute_offset_symbol2() + test_fortran_frontend_array_offset_symbol() + 
test_fortran_frontend_array_arbitrary() + test_fortran_frontend_array_arbitrary_attribute() + test_fortran_frontend_array_arbitrary_attribute2() diff --git a/tests/fortran/array_test.py b/tests/fortran/array_test.py index d5b8c5d669..61090457d0 100644 --- a/tests/fortran/array_test.py +++ b/tests/fortran/array_test.py @@ -1,34 +1,25 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. import numpy as np from dace import dtypes, symbolic -from dace.frontend.fortran import fortran_parser +from dace.frontend.fortran.fortran_parser import create_sdfg_from_string from dace.sdfg import utils as sdutil from dace.sdfg.nodes import AccessNode - -from dace.sdfg.state import LoopRegion +from tests.fortran.fortran_test_helper import SourceCodeBuilder, create_singular_sdfg_from_string def test_fortran_frontend_array_access(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. """ - test_string = """ - PROGRAM access_test - implicit none - double precision d(4) - CALL array_access_test_function(d) - end - - SUBROUTINE array_access_test_function(d) - double precision d(4) - - d(2)=5.5 - - END SUBROUTINE array_access_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "array_access_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + double precision d(4) + d(2) = 5.5 +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) sdfg(d=a) @@ -41,27 +32,19 @@ def test_fortran_frontend_array_ranges(): """ Tests that the Fortran frontend can parse multidimenstional arrays with vectorized ranges and that the accessed indices are correct. 
""" - test_string = """ - PROGRAM ranges_test - implicit none - double precision d(3,4,5) - CALL array_ranges_test_function(d) - end - - SUBROUTINE array_ranges_test_function(d) - double precision d(3,4,5),e(3,4,5),f(3,4,5) - - e(:,:,:)=1.0 - f(:,:,:)=2.0 - f(:,2:4,:)=3.0 - f(1,1,:)=4.0 - d(:,:,:)=e(:,:,:)+f(:,:,:) - d(1,2:4,1)=e(1,2:4,1)*10.0 - d(1,1,1)=SUM(e(:,1,:)) - - END SUBROUTINE array_ranges_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "array_access_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + double precision d(3, 4, 5), e(3, 4, 5), f(3, 4, 5) + e(:, :, :) = 1.0 + f(:, :, :) = 2.0 + f(:, 2:4, :) = 3.0 + f(1, 1, :) = 4.0 + d(:, :, :) = e(:, :, :) + f(:, :, :) + d(1, 2:4, 1) = e(1, 2:4, 1)*10.0 + d(1, 1, 1) = sum(e(:, 1, :)) +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) d = np.full([3, 4, 5], 42, order="F", dtype=np.float64) sdfg(d=d) @@ -72,25 +55,40 @@ def test_fortran_frontend_array_ranges(): assert (d[0, 0, 2] == 5) +def test_fortran_frontend_array_multiple_ranges_with_symbols(): + """ + Tests that the Fortran frontend can parse multidimenstional arrays with vectorized ranges and that the accessed indices are correct. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(a, lu, iend, m) + integer, intent(in) :: iend, m + double precision, intent(inout) :: a(iend, m, m), lu(iend, m, m) + lu(1:iend,1:m,1:m) = a(1:iend,1:m,1:m) +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') + sdfg.simplify(verbose=True) + sdfg.compile() + + iend, m = 3, 4 + lu = np.full([iend, m, m], 0, order="F", dtype=np.float64) + a = np.full([iend, m, m], 42, order="F", dtype=np.float64) + + sdfg(a=a, lu=lu, iend=iend, m=m) + assert np.allclose(lu, 42) + + def test_fortran_frontend_array_3dmap(): """ Tests that the normalization of multidimensional array indices works correctly. """ - test_string = """ - PROGRAM array_3dmap_test - implicit none - double precision d(4,4,4) - CALL array_3dmap_test_function(d) - end - - SUBROUTINE array_3dmap_test_function(d) - double precision d(4,4,4) - - d(:,:,:)=7 - - END SUBROUTINE array_3dmap_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "array_3dmap_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + double precision d(4, 4, 4) + d(:, :, :) = 7 +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) sdutil.normalize_offsets(sdfg) from dace.transformation.auto import auto_optimize as aopt @@ -105,21 +103,13 @@ def test_fortran_frontend_twoconnector(): """ Tests that the multiple connectors to one array are handled correctly. 
""" - test_string = """ - PROGRAM twoconnector_test - implicit none - double precision d(4) - CALL twoconnector_test_function(d) - end - - SUBROUTINE twoconnector_test_function(d) - double precision d(4) - - d(2)=d(1)+d(3) - - END SUBROUTINE twoconnector_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "twoconnector_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + double precision d(4) + d(2) = d(1) + d(3) +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) sdfg(d=a) @@ -132,25 +122,17 @@ def test_fortran_frontend_input_output_connector(): """ Tests that the presence of input and output connectors for the same array is handled correctly. """ - test_string = """ - PROGRAM ioc_test - implicit none - double precision d(2,3) - CALL ioc_test_function(d) - end - - SUBROUTINE ioc_test_function(d) - double precision d(2,3) - integer a,b - - a=1 - b=2 - d(:,:)=0.0 - d(a,b)=d(1,1)+5 - - END SUBROUTINE ioc_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "ioc_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + double precision d(2, 3) + integer a, b + a = 1 + b = 2 + d(:, :) = 0.0 + d(a, b) = d(1, 1) + 5 +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) a = np.full([2, 3], 42, order="F", dtype=np.float64) sdfg(d=a) @@ -163,37 +145,28 @@ def test_fortran_frontend_memlet_in_map_test(): """ Tests that no assumption is made where the iteration variable is inside a memlet subset """ - test_string = """ - PROGRAM memlet_range_test - implicit None - REAL INP(100, 10) - REAL OUT(100, 10) - CALL memlet_range_test_routine(INP, OUT) - END PROGRAM - - SUBROUTINE memlet_range_test_routine(INP, OUT) - REAL INP(100, 10) - REAL OUT(100, 10) - 
DO I=1,100 - CALL inner_loops(INP(I, :), OUT(I, :)) - ENDDO - END SUBROUTINE memlet_range_test_routine - - SUBROUTINE inner_loops(INP, OUT) - REAL INP(10) - REAL OUT(10) - DO J=1,10 - OUT(J) = INP(J) + 1 - ENDDO - END SUBROUTINE inner_loops - - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "memlet_range_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(INP, OUT) + real INP(100, 10) + real OUT(100, 10) + do I = 1, 100 + call inner_loops(INP(I, :), OUT(I, :)) + end do +end subroutine main + +subroutine inner_loops(INP, OUT) + real INP(10) + real OUT(10) + do J = 1, 10 + OUT(J) = INP(J) + 1 + end do +end subroutine inner_loops +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify() - # Expect that the start block is a loop - loop = sdfg.nodes()[0] - assert isinstance(loop, LoopRegion) - iter_var = symbolic.pystr_to_symbolic(loop.loop_variable) + # Expect that start is begin of for loop -> only one out edge to guard defining iterator variable + assert len(sdfg.out_edges(sdfg.start_state)) == 1 + iter_var = symbolic.symbol(list(sdfg.out_edges(sdfg.start_state)[0].data.assignments.keys())[0]) for state in sdfg.states(): if len(state.nodes()) > 1: @@ -209,10 +182,10 @@ def test_fortran_frontend_memlet_in_map_test(): if __name__ == "__main__": - test_fortran_frontend_array_3dmap() test_fortran_frontend_array_access() test_fortran_frontend_input_output_connector() test_fortran_frontend_array_ranges() + test_fortran_frontend_array_multiple_ranges_with_symbols() test_fortran_frontend_twoconnector() test_fortran_frontend_memlet_in_map_test() diff --git a/tests/fortran/array_to_loop_offset.py b/tests/fortran/array_to_loop_offset_test.py similarity index 99% rename from tests/fortran/array_to_loop_offset.py rename to tests/fortran/array_to_loop_offset_test.py index 5042859f8c..fe16ed1418 100644 --- a/tests/fortran/array_to_loop_offset.py +++ 
b/tests/fortran/array_to_loop_offset_test.py @@ -17,6 +17,7 @@ def test_fortran_frontend_arr2loop_without_offset(): SUBROUTINE index_test_function(d) double precision, dimension(5,3) :: d + integer :: i do i=1,5 d(i, :) = i * 2.0 @@ -88,6 +89,7 @@ def test_fortran_frontend_arr2loop_2d_offset(): SUBROUTINE index_test_function(d) double precision, dimension(5,7:9) :: d + integer :: i do i=1,5 d(i, :) = i * 2.0 diff --git a/tests/fortran/ast_desugaring_test.py b/tests/fortran/ast_desugaring_test.py new file mode 100644 index 0000000000..909170dda5 --- /dev/null +++ b/tests/fortran/ast_desugaring_test.py @@ -0,0 +1,1391 @@ +from typing import Dict + +from fparser.common.readfortran import FortranStringReader +from fparser.two.Fortran2003 import Program +from fparser.two.parser import ParserFactory + +from dace.frontend.fortran.ast_desugaring import correct_for_function_calls, deconstruct_enums, \ + deconstruct_interface_calls, deconstruct_procedure_calls, deconstruct_associations, \ + assign_globally_unique_subprogram_names, assign_globally_unique_variable_names, prune_branches, \ + const_eval_nodes +from dace.frontend.fortran.fortran_parser import recursive_ast_improver +from tests.fortran.fortran_test_helper import SourceCodeBuilder + + +def parse_and_improve(sources: Dict[str, str]): + parser = ParserFactory().create(std="f2008") + assert 'main.f90' in sources + reader = FortranStringReader(sources['main.f90']) + ast = parser(reader) + ast = recursive_ast_improver(ast, sources, [], parser) + assert isinstance(ast, Program) + return ast + + +def test_procedure_replacer(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type Square + real :: side + contains + procedure :: area + procedure :: area_alt => area + procedure :: get_area + end type Square +contains + real function area(this, m) + implicit none + class(Square), intent(in) :: this + real, intent(in) :: m + area = m * this%side * this%side + end function area + subroutine 
get_area(this, a) + implicit none + class(Square), intent(in) :: this + real, intent(out) :: a + a = area(this, 1.0) + end subroutine get_area +end module lib +""").add_file(""" +subroutine main + use lib, only: Square + implicit none + type(Square) :: s + real :: a + + s%side = 1.0 + a = s%area(1.0) + a = s%area_alt(1.0) + call s%get_area(a) +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: Square + REAL :: side + END TYPE Square + CONTAINS + REAL FUNCTION area(this, m) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(IN) :: m + area = m * this % side * this % side + END FUNCTION area + SUBROUTINE get_area(this, a) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(OUT) :: a + a = area(this, 1.0) + END SUBROUTINE get_area +END MODULE lib +SUBROUTINE main + USE lib, ONLY: get_area_deconproc_2 => get_area + USE lib, ONLY: area_deconproc_1 => area + USE lib, ONLY: area_deconproc_0 => area + USE lib, ONLY: Square + IMPLICIT NONE + TYPE(Square) :: s + REAL :: a + s % side = 1.0 + a = area_deconproc_0(s, 1.0) + a = area_deconproc_1(s, 1.0) + CALL get_area_deconproc_2(s, a) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_procedure_replacer_nested(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type Value + real :: val + contains + procedure :: get_value + end type Value + type Square + type(Value) :: side + contains + procedure :: get_area + end type Square +contains + real function get_value(this) + implicit none + class(Value), intent(in) :: this + get_value = this%val + end function get_value + real function get_area(this, m) + implicit none + class(Square), intent(in) :: this + real, intent(in) :: m + real :: side + side = this%side%get_value() + get_area = 
m*side*side + end function get_area +end module lib +""").add_file(""" +subroutine main + use lib, only: Square + implicit none + type(Square) :: s + real :: a + + s%side%val = 1.0 + a = s%get_area(1.0) +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: Value + REAL :: val + END TYPE Value + TYPE :: Square + TYPE(Value) :: side + END TYPE Square + CONTAINS + REAL FUNCTION get_value(this) + IMPLICIT NONE + CLASS(Value), INTENT(IN) :: this + get_value = this % val + END FUNCTION get_value + REAL FUNCTION get_area(this, m) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(IN) :: m + REAL :: side + side = get_value(this % side) + get_area = m * side * side + END FUNCTION get_area +END MODULE lib +SUBROUTINE main + USE lib, ONLY: get_area_deconproc_0 => get_area + USE lib, ONLY: Square + IMPLICIT NONE + TYPE(Square) :: s + REAL :: a + s % side % val = 1.0 + a = get_area_deconproc_0(s, 1.0) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_procedure_replacer_name_collision_with_exisiting_var(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type Square + real :: side + contains + procedure :: area + end type Square +contains + real function area(this, m) + implicit none + class(Square), intent(in) :: this + real, intent(in) :: m + area = m*this%side*this%side + end function area +end module lib +""").add_file(""" +subroutine main + use lib, only: Square + implicit none + type(Square) :: s + real :: area + + s%side = 1.0 + area = s%area(1.0) +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: Square + REAL :: side + END TYPE Square + 
CONTAINS + REAL FUNCTION area(this, m) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(IN) :: m + area = m * this % side * this % side + END FUNCTION area +END MODULE lib +SUBROUTINE main + USE lib, ONLY: area_deconproc_0 => area + USE lib, ONLY: Square + IMPLICIT NONE + TYPE(Square) :: s + REAL :: area + s % side = 1.0 + area = area_deconproc_0(s, 1.0) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_procedure_replacer_name_collision_with_another_import(): + sources, main = SourceCodeBuilder().add_file(""" +module lib_1 + implicit none + type Square + real :: side + contains + procedure :: area + end type Square +contains + real function area(this, m) + implicit none + class(Square), intent(in) :: this + real, intent(in) :: m + area = m*this%side*this%side + end function area +end module lib_1 +""").add_file(""" +module lib_2 + implicit none + type Circle + real :: rad + contains + procedure :: area + end type Circle +contains + real function area(this, m) + implicit none + class(Circle), intent(in) :: this + real, intent(in) :: m + area = m*this%rad*this%rad + end function area +end module lib_2 +""").add_file(""" +subroutine main + use lib_1, only: Square + use lib_2, only: Circle + implicit none + type(Square) :: s + type(Circle) :: c + real :: area + + s%side = 1.0 + area = s%area(1.0) + c%rad = 1.0 + area = c%area(1.0) +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib_1 + IMPLICIT NONE + TYPE :: Square + REAL :: side + END TYPE Square + CONTAINS + REAL FUNCTION area(this, m) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(IN) :: m + area = m * this % side * this % side + END FUNCTION area +END MODULE lib_1 +MODULE lib_2 + IMPLICIT NONE + TYPE :: Circle + REAL :: rad + END TYPE Circle + CONTAINS + REAL FUNCTION 
area(this, m) + IMPLICIT NONE + CLASS(Circle), INTENT(IN) :: this + REAL, INTENT(IN) :: m + area = m * this % rad * this % rad + END FUNCTION area +END MODULE lib_2 +SUBROUTINE main + USE lib_2, ONLY: area_deconproc_1 => area + USE lib_1, ONLY: area_deconproc_0 => area + USE lib_1, ONLY: Square + USE lib_2, ONLY: Circle + IMPLICIT NONE + TYPE(Square) :: s + TYPE(Circle) :: c + REAL :: area + s % side = 1.0 + area = area_deconproc_0(s, 1.0) + c % rad = 1.0 + area = area_deconproc_1(c, 1.0) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_generic_replacer(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type Square + real :: side + contains + procedure :: area_real + procedure :: area_integer + generic :: g_area => area_real, area_integer + end type Square +contains + real function area_real(this, m) + implicit none + class(Square), intent(in) :: this + real, intent(in) :: m + area_real = m*this%side*this%side + end function area_real + real function area_integer(this, m) + implicit none + class(Square), intent(in) :: this + integer, intent(in) :: m + area_integer = m*this%side*this%side + end function area_integer +end module lib +""").add_file(""" +subroutine main + use lib, only: Square + implicit none + type(Square) :: s + real :: a + real :: mr = 1.0 + integer :: mi = 1 + + s%side = 1.0 + a = s%g_area(mr) + a = s%g_area(mi) +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = correct_for_function_calls(ast) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: Square + REAL :: side + END TYPE Square + CONTAINS + REAL FUNCTION area_real(this, m) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(IN) :: m + area_real = m * this % side * this % side + END FUNCTION area_real + REAL FUNCTION area_integer(this, m) + IMPLICIT NONE + 
CLASS(Square), INTENT(IN) :: this + INTEGER, INTENT(IN) :: m + area_integer = m * this % side * this % side + END FUNCTION area_integer +END MODULE lib +SUBROUTINE main + USE lib, ONLY: area_integer_deconproc_1 => area_integer + USE lib, ONLY: area_real_deconproc_0 => area_real + USE lib, ONLY: Square + IMPLICIT NONE + TYPE(Square) :: s + REAL :: a + REAL :: mr = 1.0 + INTEGER :: mi = 1 + s % side = 1.0 + a = area_real_deconproc_0(s, mr) + a = area_integer_deconproc_1(s, mi) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_association_replacer(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type Square + real :: side + end type Square +contains + real function area(this, m) + implicit none + class(Square), intent(in) :: this + real, intent(in) :: m + area = m*this%side*this%side + end function area +end module lib +""").add_file(""" +subroutine main + use lib, only: Square, area + implicit none + type(Square) :: s + real :: a + + associate(side => s%side) + s%side = 0.5 + side = 1.0 + a = area(s, 1.0) + end associate +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_associations(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: Square + REAL :: side + END TYPE Square + CONTAINS + REAL FUNCTION area(this, m) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(IN) :: m + area = m * this % side * this % side + END FUNCTION area +END MODULE lib +SUBROUTINE main + USE lib, ONLY: Square, area + IMPLICIT NONE + TYPE(Square) :: s + REAL :: a + s % side = 0.5 + s % side = 1.0 + a = area(s, 1.0) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_association_replacer_array_access(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type Square + real :: 
sides(2, 2) + contains + procedure :: area => perim + end type Square +contains + real function perim(this, m) + implicit none + class(Square), intent(in) :: this + real, intent(in) :: m + perim = m * sum(this%sides) + end function perim +end module lib +""").add_file(""" +subroutine main + use lib, only: Square, perim + implicit none + type(Square) :: s + real :: a + + associate(sides => s%sides) + s%sides = 0.5 + s%sides(1, 1) = 1.0 + sides(2, 2) = 1.0 + a = perim(s, 1.0) + a = s%area(1.0) + end associate +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_enums(ast) + ast = deconstruct_associations(ast) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: Square + REAL :: sides(2, 2) + END TYPE Square + CONTAINS + REAL FUNCTION perim(this, m) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(IN) :: m + perim = m * SUM(this % sides) + END FUNCTION perim +END MODULE lib +SUBROUTINE main + USE lib, ONLY: perim_deconproc_0 => perim + USE lib, ONLY: Square, perim + IMPLICIT NONE + TYPE(Square) :: s + REAL :: a + s % sides = 0.5 + s % sides(1, 1) = 1.0 + s % sides(2, 2) = 1.0 + a = perim(s, 1.0) + a = perim_deconproc_0(s, 1.0) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_association_replacer_array_access_within_array_access(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type Square + real :: sides(2, 2) + contains + procedure :: area => perim + end type Square +contains + real function perim(this, m) + implicit none + class(Square), intent(in) :: this + real, intent(in) :: m + perim = m * sum(this%sides) + end function perim +end module lib +""").add_file(""" +subroutine main + use lib, only: Square, perim + implicit none + type(Square) :: s + real :: a + + associate(sides => s%sides(:, 1)) + s%sides = 0.5 + 
s%sides(1, 1) = 1.0 + sides(2) = 1.0 + a = perim(s, 1.0) + a = s%area(1.0) + end associate +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_associations(ast) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: Square + REAL :: sides(2, 2) + END TYPE Square + CONTAINS + REAL FUNCTION perim(this, m) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(IN) :: m + perim = m * SUM(this % sides) + END FUNCTION perim +END MODULE lib +SUBROUTINE main + USE lib, ONLY: perim_deconproc_0 => perim + USE lib, ONLY: Square, perim + IMPLICIT NONE + TYPE(Square) :: s + REAL :: a + s % sides = 0.5 + s % sides(1, 1) = 1.0 + s % sides(2, 1) = 1.0 + a = perim(s, 1.0) + a = perim_deconproc_0(s, 1.0) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_uses_allows_indirect_aliasing(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type Square + real :: sides(2, 2) + contains + procedure :: area => perim + end type Square +contains + real function perim(this, m) + implicit none + class(Square), intent(in) :: this + real, intent(in) :: m + perim = m * sum(this%sides) + end function perim +end module lib +""").add_file(""" +module lib2 + use lib + implicit none +end module lib2 +""").add_file(""" +subroutine main + use lib2, only: Square, perim + implicit none + type(Square) :: s + real :: a + + associate(sides => s%sides(:, 1)) + s%sides = 0.5 + s%sides(1, 1) = 1.0 + sides(2) = 1.0 + a = perim(s, 1.0) + a = s%area(1.0) + end associate +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_associations(ast) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: Square + REAL :: sides(2, 2) + END TYPE Square + CONTAINS + REAL 
FUNCTION perim(this, m) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this + REAL, INTENT(IN) :: m + perim = m * SUM(this % sides) + END FUNCTION perim +END MODULE lib +MODULE lib2 + USE lib + IMPLICIT NONE +END MODULE lib2 +SUBROUTINE main + USE lib, ONLY: perim_deconproc_0 => perim + USE lib2, ONLY: Square, perim + IMPLICIT NONE + TYPE(Square) :: s + REAL :: a + s % sides = 0.5 + s % sides(1, 1) = 1.0 + s % sides(2, 1) = 1.0 + a = perim(s, 1.0) + a = perim_deconproc_0(s, 1.0) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_enum_bindings_become_constants(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main + implicit none + integer, parameter :: k = 42 + enum, bind(c) + enumerator :: a, b, c + end enum + enum, bind(c) + enumerator :: d = a, e, f + end enum + enum, bind(c) + enumerator :: g = k, h = k, i = k + 1 + end enum +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_enums(ast) + + got = ast.tofortran() + want = """ +SUBROUTINE main + IMPLICIT NONE + INTEGER, PARAMETER :: k = 42 + INTEGER, PARAMETER :: a = 0 + 0 + INTEGER, PARAMETER :: b = 0 + 1 + INTEGER, PARAMETER :: c = 0 + 2 + INTEGER, PARAMETER :: d = a + 0 + INTEGER, PARAMETER :: e = a + 1 + INTEGER, PARAMETER :: f = a + 2 + INTEGER, PARAMETER :: g = k + 0 + INTEGER, PARAMETER :: h = k + 0 + INTEGER, PARAMETER :: i = k + 1 + 0 +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_aliasing_through_module_procedure(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + interface fun + module procedure real_fun + end interface fun +contains + real function real_fun() + implicit none + real_fun = 1.0 + end function real_fun +end module lib +""").add_file(""" +subroutine main + use lib, only: fun + implicit none + real d(4) + d(2) = fun() +end subroutine 
main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_associations(ast) + ast = correct_for_function_calls(ast) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + INTERFACE fun + MODULE PROCEDURE real_fun + END INTERFACE fun + CONTAINS + REAL FUNCTION real_fun() + IMPLICIT NONE + real_fun = 1.0 + END FUNCTION real_fun +END MODULE lib +SUBROUTINE main + USE lib, ONLY: fun + IMPLICIT NONE + REAL :: d(4) + d(2) = fun() +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_interface_replacer_with_module_procedures(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + public :: fun + interface fun + module procedure real_fun + end interface fun + interface not_fun + module procedure not_real_fun + end interface not_fun +contains + real function real_fun() + implicit none + real_fun = 1.0 + end function real_fun + subroutine not_real_fun(a) + implicit none + real, intent(out) :: a + a = 1.0 + end subroutine not_real_fun +end module lib +""").add_file(""" +subroutine main + use lib, only: fun, not_fun + implicit none + real d(4) + d(2) = fun() + call not_fun(d(3)) +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = correct_for_function_calls(ast) + ast = deconstruct_interface_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + CONTAINS + REAL FUNCTION real_fun() + IMPLICIT NONE + real_fun = 1.0 + END FUNCTION real_fun + SUBROUTINE not_real_fun(a) + IMPLICIT NONE + REAL, INTENT(OUT) :: a + a = 1.0 + END SUBROUTINE not_real_fun +END MODULE lib +SUBROUTINE main + USE lib, ONLY: not_real_fun_deconiface_1 => not_real_fun + USE lib, ONLY: real_fun_deconiface_0 => real_fun + IMPLICIT NONE + REAL :: d(4) + d(2) = real_fun_deconiface_0() + CALL not_real_fun_deconiface_1(d(3)) +END SUBROUTINE main +""".strip() 
+ assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_interface_replacer_with_subroutine_decls(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + interface + subroutine fun(z) + implicit none + real, intent(out) :: z + end subroutine fun + end interface +end module lib +""").add_file(""" +subroutine main + use lib, only: no_fun => fun + implicit none + real d(4) + call no_fun(d(3)) +end subroutine main + +subroutine fun(z) + implicit none + real, intent(out) :: z + z = 1.0 +end subroutine fun +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = correct_for_function_calls(ast) + ast = deconstruct_interface_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE +END MODULE lib +SUBROUTINE main + IMPLICIT NONE + REAL :: d(4) + CALL fun(d(3)) +END SUBROUTINE main +SUBROUTINE fun(z) + IMPLICIT NONE + REAL, INTENT(OUT) :: z + z = 1.0 +END SUBROUTINE fun +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_interface_replacer_with_optional_args(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + public :: fun + interface fun + module procedure real_fun, integer_fun + end interface fun +contains + real function real_fun(x) + implicit none + real, intent(in), optional :: x + if (.not.(present(x))) then + real_fun = 1.0 + else + real_fun = x + end if + end function real_fun + integer function integer_fun(x) + implicit none + integer, intent(in) :: x + integer_fun = x * 2 + end function integer_fun +end module lib +""").add_file(""" +subroutine main + use lib, only: fun + implicit none + real d(4) + d(2) = fun() + d(3) = fun(x=4) + d(4) = fun(x=5.0) +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = correct_for_function_calls(ast) + ast = deconstruct_interface_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + 
IMPLICIT NONE + CONTAINS + REAL FUNCTION real_fun(x) + IMPLICIT NONE + REAL, INTENT(IN), OPTIONAL :: x + IF (.NOT. (PRESENT(x))) THEN + real_fun = 1.0 + ELSE + real_fun = x + END IF + END FUNCTION real_fun + INTEGER FUNCTION integer_fun(x) + IMPLICIT NONE + INTEGER, INTENT(IN) :: x + integer_fun = x * 2 + END FUNCTION integer_fun +END MODULE lib +SUBROUTINE main + USE lib, ONLY: real_fun_deconiface_2 => real_fun + USE lib, ONLY: integer_fun_deconiface_1 => integer_fun + USE lib, ONLY: real_fun_deconiface_0 => real_fun + IMPLICIT NONE + REAL :: d(4) + d(2) = real_fun_deconiface_0() + d(3) = integer_fun_deconiface_1(x = 4) + d(4) = real_fun_deconiface_2(x = 5.0) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_interface_replacer_with_keyworded_args(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + public :: fun + interface fun + module procedure real_fun + end interface fun +contains + real function real_fun(w, x, y, z) + implicit none + real, intent(in) :: w + real, intent(in), optional :: x + real, intent(in) :: y + real, intent(in), optional :: z + if (.not.(present(x))) then + real_fun = 1.0 + else + real_fun = w + y + end if + end function real_fun +end module lib +""").add_file(""" +subroutine main + use lib, only: fun + implicit none + real d(3) + d(1) = fun(1.0, 2.0, 3.0, 4.0) ! all present, no keyword + d(2) = fun(y=1.1, w=3.1) ! only required ones, keyworded + d(3) = fun(1.2, 2.2, y=3.2) ! partially keyworded, last optional omitted. 
+end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = correct_for_function_calls(ast) + ast = deconstruct_interface_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + CONTAINS + REAL FUNCTION real_fun(w, x, y, z) + IMPLICIT NONE + REAL, INTENT(IN) :: w + REAL, INTENT(IN), OPTIONAL :: x + REAL, INTENT(IN) :: y + REAL, INTENT(IN), OPTIONAL :: z + IF (.NOT. (PRESENT(x))) THEN + real_fun = 1.0 + ELSE + real_fun = w + y + END IF + END FUNCTION real_fun +END MODULE lib +SUBROUTINE main + USE lib, ONLY: real_fun_deconiface_2 => real_fun + USE lib, ONLY: real_fun_deconiface_1 => real_fun + USE lib, ONLY: real_fun_deconiface_0 => real_fun + IMPLICIT NONE + REAL :: d(3) + d(1) = real_fun_deconiface_0(1.0, 2.0, 3.0, 4.0) + d(2) = real_fun_deconiface_1(y = 1.1, w = 3.1) + d(3) = real_fun_deconiface_2(1.2, 2.2, y = 3.2) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_generic_replacer_deducing_array_types(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type T + real :: val(2, 2) + contains + procedure :: copy_matrix + procedure :: copy_vector + procedure :: copy_scalar + generic :: copy => copy_matrix, copy_vector, copy_scalar + end type T +contains + subroutine copy_scalar(this, m) + implicit none + class(T), intent(in) :: this + real, intent(out) :: m + m = this%val(1, 1) + end subroutine copy_scalar + subroutine copy_vector(this, m) + implicit none + class(T), intent(in) :: this + real, dimension(:), intent(out) :: m + m = this%val(1, 1) + end subroutine copy_vector + subroutine copy_matrix(this, m) + implicit none + class(T), intent(in) :: this + real, dimension(:, :), intent(out) :: m + m = this%val(1, 1) + end subroutine copy_matrix +end module lib +""").add_file(""" +subroutine main + use lib, only: T + implicit none + type(T) :: s, s1 + real, dimension(4, 4) :: a + real :: b(4, 4) + + 
s%val = 1.0 + call s%copy(a) + call s%copy(a(2, 2)) + call s%copy(b(:, 2)) + call s%copy(b(:, :)) + call s%copy(s1%val(:, 1)) +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = deconstruct_procedure_calls(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: T + REAL :: val(2, 2) + END TYPE T + CONTAINS + SUBROUTINE copy_scalar(this, m) + IMPLICIT NONE + CLASS(T), INTENT(IN) :: this + REAL, INTENT(OUT) :: m + m = this % val(1, 1) + END SUBROUTINE copy_scalar + SUBROUTINE copy_vector(this, m) + IMPLICIT NONE + CLASS(T), INTENT(IN) :: this + REAL, DIMENSION(:), INTENT(OUT) :: m + m = this % val(1, 1) + END SUBROUTINE copy_vector + SUBROUTINE copy_matrix(this, m) + IMPLICIT NONE + CLASS(T), INTENT(IN) :: this + REAL, DIMENSION(:, :), INTENT(OUT) :: m + m = this % val(1, 1) + END SUBROUTINE copy_matrix +END MODULE lib +SUBROUTINE main + USE lib, ONLY: copy_vector_deconproc_4 => copy_vector + USE lib, ONLY: copy_matrix_deconproc_3 => copy_matrix + USE lib, ONLY: copy_vector_deconproc_2 => copy_vector + USE lib, ONLY: copy_scalar_deconproc_1 => copy_scalar + USE lib, ONLY: copy_matrix_deconproc_0 => copy_matrix + USE lib, ONLY: T + IMPLICIT NONE + TYPE(T) :: s, s1 + REAL, DIMENSION(4, 4) :: a + REAL :: b(4, 4) + s % val = 1.0 + CALL copy_matrix_deconproc_0(s, a) + CALL copy_scalar_deconproc_1(s, a(2, 2)) + CALL copy_vector_deconproc_2(s, b(:, 2)) + CALL copy_matrix_deconproc_3(s, b(:, :)) + CALL copy_vector_deconproc_4(s, s1 % val(:, 1)) +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_globally_unique_names(): + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type :: Square + real :: sides(2, 2) + end type Square + integer, parameter :: k = 4 + real :: circle = 2.0_k +contains + real function perim(this, m) + implicit none + class(Square), intent(IN) :: this + real, intent(IN) :: m + perim 
= m*sum(this%sides) + end function perim + function area(this, m) + implicit none + class(Square), intent(IN) :: this + real, intent(IN) :: m + real, dimension(2, 2) :: area + area = m*sum(this%sides) + end function area +end module lib +""").add_file(""" +subroutine main + use lib + use lib, only: perim + use lib, only: p2 => perim + use lib, only: circle + implicit none + type(Square) :: s + real :: a + integer :: i, j + s%sides = 0.5 + s%sides(1, 1) = 1.0 + s%sides(2, 1) = 1.0 + do i = 1, 2 + do j = 1, 2 + s%sides(i, j) = 7.0 + end do + end do + a = perim(s, 1.0) + a = p2(s, 1.0) + s%sides = area(s, 4.1) + circle = 5.0 +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = correct_for_function_calls(ast) + ast = assign_globally_unique_subprogram_names(ast, {('main',)}) + ast = assign_globally_unique_variable_names(ast, set()) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: Square + REAL :: sides(2, 2) + END TYPE Square + INTEGER, PARAMETER :: k_deconglobalvar_3 = 4 + REAL :: circle_deconglobalvar_4 = 2.0_k_deconglobalvar_3 + CONTAINS + REAL FUNCTION perim_deconglobalfn_5(this_deconglobalvar_6, m_deconglobalvar_7) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this_deconglobalvar_6 + REAL, INTENT(IN) :: m_deconglobalvar_7 + perim_deconglobalfn_5 = m_deconglobalvar_7 * SUM(this_deconglobalvar_6 % sides) + END FUNCTION perim_deconglobalfn_5 + FUNCTION area_deconglobalfn_8(this_deconglobalvar_9, m_deconglobalvar_10) + IMPLICIT NONE + CLASS(Square), INTENT(IN) :: this_deconglobalvar_9 + REAL, INTENT(IN) :: m_deconglobalvar_10 + REAL, DIMENSION(2, 2) :: area_deconglobalfn_8 + area_deconglobalfn_8 = m_deconglobalvar_10 * SUM(this_deconglobalvar_9 % sides) + END FUNCTION area_deconglobalfn_8 +END MODULE lib +SUBROUTINE main + USE lib, ONLY: circle_deconglobalvar_4 + USE lib, ONLY: area_deconglobalfn_8 + USE lib, ONLY: perim_deconglobalfn_5 + USE lib, ONLY: perim_deconglobalfn_5 + USE lib + IMPLICIT 
NONE + TYPE(Square) :: s_deconglobalvar_13 + REAL :: a_deconglobalvar_14 + INTEGER :: i_deconglobalvar_15, j_deconglobalvar_16 + s_deconglobalvar_13 % sides = 0.5 + s_deconglobalvar_13 % sides(1, 1) = 1.0 + s_deconglobalvar_13 % sides(2, 1) = 1.0 + DO i_deconglobalvar_15 = 1, 2 + DO j_deconglobalvar_16 = 1, 2 + s_deconglobalvar_13 % sides(i_deconglobalvar_15, j_deconglobalvar_16) = 7.0 + END DO + END DO + a_deconglobalvar_14 = perim_deconglobalfn_5(s_deconglobalvar_13, 1.0) + a_deconglobalvar_14 = perim_deconglobalfn_5(s_deconglobalvar_13, 1.0) + s_deconglobalvar_13 % sides = area_deconglobalfn_8(s_deconglobalvar_13, 4.1) + circle_deconglobalvar_4 = 5.0 +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_branch_pruning(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main + implicit none + integer, parameter :: k = 4 + integer :: a = -1, b = -1 + + if (k < 2) then + a = k + else if (k < 5) then + b = k + else + a = k + b = k + end if +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_branches(ast) + + got = ast.tofortran() + want = """ +SUBROUTINE main + IMPLICIT NONE + INTEGER, PARAMETER :: k = 4 + INTEGER :: a = - 1, b = - 1 + b = k +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_constant_resolving_expressions(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main + implicit none + integer, parameter :: k = 8 + integer :: a = -1, b = -1 + real, parameter :: pk = 4.1_k + real(kind=selected_real_kind(5, 5)) :: p = 1.0_k + + if (k < 2) then + a = k + p = k*pk + else if (k < 5) then + b = k + p = p + k*pk + else + a = k + b = k + p = a*p + k*pk + end if +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = correct_for_function_calls(ast) + ast = const_eval_nodes(ast) + + got = 
ast.tofortran() + want = """ +SUBROUTINE main + IMPLICIT NONE + INTEGER, PARAMETER :: k = 8 + INTEGER :: a = - 1, b = - 1 + REAL, PARAMETER :: pk = 4.1D0 + REAL(KIND = 4) :: p = 1.0D0 + IF (.FALSE.) THEN + a = 8 + p = 32.79999923706055D0 + ELSE IF (.FALSE.) THEN + b = 8 + p = p + 32.79999923706055D0 + ELSE + a = 8 + b = 8 + p = a * p + 32.79999923706055D0 + END IF +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_constant_resolving_non_expressions(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main + implicit none + integer, parameter :: k = 8 + integer :: i + real :: a = 1 + do i = 2, k + a = a + i * k + end do + a = fun(k) + call not_fun(k, a) + contains + real function fun(x) + integer, intent(in) :: x + fun = x * k + end function fun + subroutine not_fun(x, y) + integer, intent(in) :: x + real, intent(out) :: y + y = x * k + end subroutine not_fun +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = correct_for_function_calls(ast) + ast = const_eval_nodes(ast) + + got = ast.tofortran() + want = """ +SUBROUTINE main + IMPLICIT NONE + INTEGER, PARAMETER :: k = 8 + INTEGER :: i + REAL :: a = 1 + DO i = 2, 8 + a = a + i * 8 + END DO + a = fun(8) + CALL not_fun(8, a) + CONTAINS + REAL FUNCTION fun(x) + INTEGER, INTENT(IN) :: x + fun = x * 8 + END FUNCTION fun + SUBROUTINE not_fun(x, y) + INTEGER, INTENT(IN) :: x + REAL, INTENT(OUT) :: y + y = x * 8 + END SUBROUTINE not_fun +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() diff --git a/tests/fortran/call_extract_test.py b/tests/fortran/call_extract_test.py index eb1f2ac86d..c004083f31 100644 --- a/tests/fortran/call_extract_test.py +++ b/tests/fortran/call_extract_test.py @@ -17,17 +17,17 @@ def test_fortran_frontend_call_extract(): SUBROUTINE intrinsic_call_extract_test_function(d,res) real, dimension(2) :: d real, 
dimension(2) :: res - + res(1) = SQRT(SIGN(EXP(d(1)), LOG(d(1)))) res(2) = MIN(SQRT(EXP(d(1))), SQRT(EXP(d(1))) - 1) END SUBROUTINE intrinsic_call_extract_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_call_extract", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_call_extract", normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() - + input = np.full([2], 42, order="F", dtype=np.float32) res = np.full([2], 42, order="F", dtype=np.float32) sdfg(d=input, res=res) diff --git a/tests/fortran/cond_type_test.py b/tests/fortran/cond_type_test.py new file mode 100644 index 0000000000..a395047db1 --- /dev/null +++ b/tests/fortran/cond_type_test.py @@ -0,0 +1,67 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + +from dace.transformation.passes.lift_struct_views import LiftStructViews +from dace.transformation import pass_pipeline as ppl + + +def test_fortran_frontend_cond_type(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_test + implicit none + + TYPE simple_type + REAL :: w(5,5,5), z(5) + INTEGER :: id + REAL :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL cond_type_test_function(d) + end + + SUBROUTINE cond_type_test_function(d) + REAL d(5,5) + TYPE(simple_type) :: ptr_patch + LOGICAL :: bla=.TRUE. + ptr_patch%w(1,1,1) = 5.5 + ptr_patch%id = 6 + if (ptr_patch%id .GT. 5) then + d(2,1) = 5.5 + ptr_patch%w(1,1,1) + else + d(2,1) = 12 + endif + END SUBROUTINE cond_type_test_function + """ + sources={} + sources["type_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + + +if __name__ == "__main__": + test_fortran_frontend_cond_type() diff --git a/tests/fortran/create_internal_ast_test.py b/tests/fortran/create_internal_ast_test.py new file mode 100644 index 0000000000..47193f3445 --- /dev/null +++ b/tests/fortran/create_internal_ast_test.py @@ -0,0 +1,282 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Dict + +from dace.frontend.fortran.ast_internal_classes import Program_Node, Main_Program_Node, Subroutine_Subprogram_Node, \ + Module_Node, Specification_Part_Node +from dace.frontend.fortran.ast_transforms import Structures, Structure +from dace.frontend.fortran.fortran_parser import ParseConfig, create_internal_ast +from tests.fortran.fortran_test_helper import SourceCodeBuilder, InternalASTMatcher as M + + +def construct_internal_ast(sources: Dict[str, str]): + assert 'main.f90' in sources + cfg = ParseConfig(sources['main.f90'], sources, []) + iast, prog = create_internal_ast(cfg) + return iast, prog + + +def test_minimal(): + """ + A simple program to just verify that we can produce compilable SDFGs. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + double precision d(4) + d(2) = 5.5 +end program main +subroutine fun(d) + implicit none + double precision, intent(inout) :: d(4) + d(2) = 4.2 +end subroutine fun +""").check_with_gfortran().get() + # Construct + iast, prog = construct_internal_ast(sources) + + # Verify + assert not iast.fortran_intrinsics().transformations() + m = M(Program_Node, has_attr={ + 'main_program': M(Main_Program_Node), + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + ], + 'structures': M(Structures, has_empty_attr={'structures'}) + }, has_empty_attr={'function_definitions', 'modules', 'placeholders', 'placeholders_offsets'}) + m.check(prog) + + +def test_standalone_subroutines(): + """ + A standalone subroutine, with no program or module in sight. + """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine fun(d) + implicit none + double precision, intent(inout) :: d(4) + d(2) = 4.2 +end subroutine fun + +subroutine not_fun(d, val) + implicit none + double precision, intent(in) :: val + double precision, intent(inout) :: d(4) + d(4) = val +end subroutine not_fun +""").check_with_gfortran().get() + # Construct + iast, prog = construct_internal_ast(sources) + + # Verify + assert not iast.fortran_intrinsics().transformations() + m = M(Program_Node, has_attr={ + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('not_fun'), + 'args': [M.NAMED('d'), M.NAMED('val')], + }), + ], + 'structures': M(Structures, has_empty_attr={'structures'}) + }, has_empty_attr={'main_program', 'function_definitions', 'modules', 'placeholders', 'placeholders_offsets'}) + m.check(prog) + + +def test_subroutines_from_module(): + """ + A standalone subroutine, with no program or module in sight. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none +contains + subroutine fun(d) + implicit none + double precision, intent(inout) :: d(4) + d(2) = 4.2 + end subroutine fun + + subroutine not_fun(d, val) + implicit none + double precision, intent(in) :: val + double precision, intent(inout) :: d(4) + d(4) = val + end subroutine not_fun +end module lib +""").add_file(""" +program main + use lib + implicit none + double precision :: d(4) + call fun(d) + call not_fun(d, 2.1d0) +end program main +""").check_with_gfortran().get() + # Construct + iast, prog = construct_internal_ast(sources) + + # Verify + assert not iast.fortran_intrinsics().transformations() + m = M(Program_Node, has_attr={ + 'main_program': M(Main_Program_Node), + 'modules': [M(Module_Node, has_attr={ + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('not_fun'), + 'args': [M.NAMED('d'), M.NAMED('val')], + }), + ], + }, has_empty_attr={'function_definitions', 'interface_blocks'})], + 'structures': M(Structures, has_empty_attr={'structures'}) + }, has_empty_attr={'function_definitions', 'subroutine_definitions', 'placeholders', 'placeholders_offsets'}) + m.check(prog) + + +def test_subroutine_with_local_variable(): + """ + A standalone subroutine, with no program or module in sight. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine fun(d) + implicit none + double precision, intent(inout) :: d(4) + double precision :: e(4) + e(:) = 1.0 + e(2) = 4.2 + d(:) = e(:) +end subroutine fun +""").check_with_gfortran().get() + # Construct + iast, prog = construct_internal_ast(sources) + + # Verify + assert not iast.fortran_intrinsics().transformations() + m = M(Program_Node, has_attr={ + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + ], + 'structures': M(Structures, has_empty_attr={'structures'}) + }, has_empty_attr={'main_program', 'function_definitions', 'modules', 'placeholders', 'placeholders_offsets'}) + m.check(prog) + + +def test_subroutine_contains_function(): + """ + A function is defined inside a subroutine that calls it. A main program uses the top-level subroutine. + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none +contains + subroutine fun(d) + implicit none + double precision d(4) + d(2) = fun2() + + contains + real function fun2() + implicit none + fun2 = 5.5 + end function fun2 + end subroutine fun +end module lib +""").add_file(""" +program main + use lib, only: fun + implicit none + + double precision d(4) + call fun(d) +end program main +""").check_with_gfortran().get() + # Construct + iast, prog = construct_internal_ast(sources) + + # Verify + assert not iast.fortran_intrinsics().transformations() + m = M(Program_Node, has_attr={ + 'main_program': M(Main_Program_Node), + 'modules': [M(Module_Node, has_attr={ + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + ], + }, has_empty_attr={'function_definitions', 'interface_blocks'})], + 'structures': M(Structures, has_empty_attr={'structures'}) + }, has_empty_attr={'function_definitions', 'subroutine_definitions', 'placeholders', 'placeholders_offsets'}) + m.check(prog) + + # TODO: We cannot 
handle during the internal AST construction (it works just fine before during parsing etc.) when a + # subroutine contains other subroutines. This needs to be fixed. + mod = prog.modules[0] + # Where could `fun2`'s definition could be? + assert not mod.function_definitions # Not here! + assert 'fun2' not in [f.name.name for f in mod.subroutine_definitions] # Not here! + fn = mod.subroutine_definitions[0] + assert not hasattr(fn, 'function_definitions') # Not here! + assert not hasattr(fn, 'subroutine_definitions') # Not here! + + +def test_module_contains_types(): + """ + Module has type definition that the program does not use, so it gets pruned. + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + type used_type + real :: w(5, 5, 5), z(5) + integer :: a + real :: name + end type used_type +end module lib +""").add_file(""" +program main + implicit none + real :: d(5, 5) + call fun(d) +end program main +subroutine fun(d) + use lib, only : used_type + real d(5, 5) + type(used_type) :: s + s%w(1, 1, 1) = 5.5 + d(2, 1) = 5.5 + s%w(1, 1, 1) +end subroutine fun +""").check_with_gfortran().get() + # Construct + iast, prog = construct_internal_ast(sources) + + # Verify + assert not iast.fortran_intrinsics().transformations() + m = M(Program_Node, has_attr={ + 'main_program': M(Main_Program_Node), + 'modules': [M(Module_Node, has_attr={ + 'specification_part': M(Specification_Part_Node, {'typedecls': M.IGNORE(1)}) + }, has_empty_attr={'function_definitions', 'interface_blocks'})], + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + ], + 'structures': M(Structures, { + 'structures': {'used_type': M(Structure)}, + }) + }, has_empty_attr={'function_definitions', 'placeholders', 'placeholders_offsets'}) + m.check(prog) diff --git a/tests/fortran/empty_test.py b/tests/fortran/empty_test.py new file mode 100644 index 0000000000..7e07aa09df --- /dev/null +++ 
b/tests/fortran/empty_test.py @@ -0,0 +1,46 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from tests.fortran.fortran_test_helper import create_singular_sdfg_from_string, SourceCodeBuilder + + +def test_fortran_frontend_empty(): + """ + Test that empty subroutines and functions are correctly parsed. + """ + sources, main = SourceCodeBuilder().add_file(""" +module module_mpi + integer :: process_mpi_all_size = 0 +contains + logical function fun_with_no_arguments() + fun_with_no_arguments = (process_mpi_all_size <= 1) + end function fun_with_no_arguments +end module module_mpi +""").add_file(""" +subroutine main(d) + use module_mpi, only: fun_with_no_arguments + double precision d(2, 3) + logical :: bla = .false. + + bla = fun_with_no_arguments() + if (bla) then + d(1, 1) = 0 + d(1, 2) = 5 + d(2, 3) = 0 + else + d(1, 2) = 1 + end if +end subroutine main +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') + sdfg.simplify(verbose=True) + a = np.full([2, 3], 42, order="F", dtype=np.float64) + sdfg(d=a, process_mpi_all_size=0) + assert (a[0, 0] == 0) + assert (a[0, 1] == 5) + assert (a[1, 2] == 0) + + +if __name__ == "__main__": + test_fortran_frontend_empty() diff --git a/tests/fortran/fortran_language_test.py b/tests/fortran/fortran_language_test.py index 840f0bda0e..e784503e57 100644 --- a/tests/fortran/fortran_language_test.py +++ b/tests/fortran/fortran_language_test.py @@ -1,23 +1,17 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. import numpy as np from dace.frontend.fortran import fortran_parser +from tests.fortran.fortran_test_helper import create_singular_sdfg_from_string, SourceCodeBuilder def test_fortran_frontend_real_kind_selector(): """ Tests that the size intrinsics are correctly parsed and translated to DaCe. 
""" - test_string = """ -program real_kind_selector_test - implicit none - integer, parameter :: JPRB = selected_real_kind(13, 300) - real(KIND=JPRB) d(4) - call real_kind_selector_test_function(d) -end - -subroutine real_kind_selector_test_function(d) + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) implicit none integer, parameter :: JPRB = selected_real_kind(13, 300) integer, parameter :: JPIM = selected_int_kind(9) @@ -26,10 +20,9 @@ def test_fortran_frontend_real_kind_selector(): i = 7 d(2) = 5.5 + i - -end subroutine real_kind_selector_test_function -""" - sdfg = fortran_parser.create_sdfg_from_string(test_string, "real_kind_selector_test") +end subroutine main +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) sdfg(d=a) @@ -42,33 +35,27 @@ def test_fortran_frontend_if1(): """ Tests that the if/else construct is correctly parsed and translated to DaCe. """ - test_string = """ - PROGRAM if1_test - implicit none - double precision d(3,4,5) - CALL if1_test_function(d) - end - - SUBROUTINE if1_test_function(d) - double precision d(3,4,5),ZFAC(10) - integer JK,JL,RTT,NSSOPT - integer ZTP1(10,10) - JL=1 - JK=1 - ZTP1(JL,JK)=1.0 - RTT=2 - NSSOPT=1 - - IF (ZTP1(JL,JK)>=RTT .OR. NSSOPT==0) THEN - ZFAC(1) = 1.0 - ELSE - ZFAC(1) = 2.0 - ENDIF - d(1,1,1)=ZFAC(1) - - END SUBROUTINE if1_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "if1_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + implicit none + double precision d(3, 4, 5), ZFAC(10) + integer JK, JL, RTT, NSSOPT + integer ZTP1(10, 10) + JL = 1 + JK = 1 + ZTP1(JL, JK) = 1.0 + RTT = 2 + NSSOPT = 1 + + if (ZTP1(JL, JK) >= RTT .or. 
NSSOPT == 0) then + ZFAC(1) = 1.0 + else + ZFAC(1) = 2.0 + end if + d(1, 1, 1) = ZFAC(1) +end subroutine main +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) d = np.full([3, 4, 5], 42, order="F", dtype=np.float64) sdfg(d=d) @@ -79,19 +66,11 @@ def test_fortran_frontend_loop1(): """ Tests that the loop construct is correctly parsed and translated to DaCe. """ - - test_string = """ -program loop1_test - implicit none - logical :: d(3, 4, 5) - call loop1_test_function(d) -end - -subroutine loop1_test_function(d) - logical :: d(3, 4, 5), ZFAC(10) + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + logical d(3, 4, 5), ZFAC(10) integer :: a, JK, JL, JM integer, parameter :: KLEV = 10, N = 10, NCLV = 3 - integer :: tmp double precision :: RLMIN, ZVQX(NCLV) logical :: LLCOOLJ, LLFALL(NCLV) @@ -102,42 +81,24 @@ def test_fortran_frontend_loop1(): if (ZVQX(JM) > 0.0) LLFALL(JM) = .true. ! falling species end do - do I = 1, 3 - do J = 1, 4 - do K = 1, 5 - tmp = I+J+K-3 - tmp = mod(tmp, 2) - if (tmp == 1) then - d(I, J, K) = LLFALL(2) - else - d(I, J, K) = LLFALL(1) - end if - end do - end do - end do -end subroutine loop1_test_function -""" - sdfg = fortran_parser.create_sdfg_from_string(test_string, "loop1_test") + d(1, 1, 1) = LLFALL(1) + d(1, 1, 2) = LLFALL(2) +end subroutine main +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) - d = np.full([3, 4, 5], 42, order="F", dtype=np.int32) + d = np.full([3, 4, 5], 1, order="F", dtype=np.int32) sdfg(d=d) - # Verify the checkerboard pattern. - assert all(bool(v) == ((i+j+k) % 2 == 1) for (i, j, k), v in np.ndenumerate(d)) + assert (d[0, 0, 0] == 0) + assert (d[0, 0, 1] == 1) def test_fortran_frontend_function_statement1(): """ Tests that the function statement are correctly removed recursively. 
""" - - test_string = """ -program function_statement1_test - implicit none - double precision d(3, 4, 5) - call function_statement1_test_function(d) -end - -subroutine function_statement1_test_function(d) + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) double precision d(3, 4, 5) double precision :: PTARE, RTT(2), FOEDELTA, FOELDCP double precision :: RALVDCP(2), RALSDCP(2), RES @@ -151,9 +112,9 @@ def test_fortran_frontend_function_statement1(): d(1, 1, 1) = FOELDCP(3.d0) RES = FOELDCP(3.d0) d(1, 1, 2) = RES -end subroutine function_statement1_test_function -""" - sdfg = fortran_parser.create_sdfg_from_string(test_string, "function_statement1_test") +end subroutine main +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) d = np.full([3, 4, 5], 42, order="F", dtype=np.float64) sdfg(d=d) @@ -165,27 +126,22 @@ def test_fortran_frontend_pow1(): """ Tests that the power intrinsic is correctly parsed and translated to DaCe. 
(should become a*a) """ - test_string = """ - PROGRAM pow1_test - implicit none - double precision d(3,4,5) - CALL pow1_test_function(d) - end - - SUBROUTINE pow1_test_function(d) - double precision d(3,4,5) - double precision :: ZSIGK(2), ZHRC(2),RAMID(2) - - ZSIGK(1)=4.8 - RAMID(1)=0.0 - ZHRC(1)=12.34 - IF(ZSIGK(1) > 0.8) THEN - ZHRC(1)=RAMID(1)+(1.0-RAMID(1))*((ZSIGK(1)-0.8)/0.2)**2 - ENDIF - d(1,1,2)=ZHRC(1) - END SUBROUTINE pow1_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "pow1_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + implicit none + double precision d(3, 4, 5) + double precision :: ZSIGK(2), ZHRC(2), RAMID(2) + + ZSIGK(1) = 4.8 + RAMID(1) = 0.0 + ZHRC(1) = 12.34 + if (ZSIGK(1) > 0.8) then + ZHRC(1) = RAMID(1) + (1.0 - RAMID(1))*((ZSIGK(1) - 0.8)/0.2)**2 + end if + d(1, 1, 2) = ZHRC(1) +end subroutine main +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) d = np.full([3, 4, 5], 42, order="F", dtype=np.float64) sdfg(d=d) @@ -196,28 +152,22 @@ def test_fortran_frontend_pow2(): """ Tests that the power intrinsic is correctly parsed and translated to DaCe (this time it's p sqrt p). 
""" - - test_string = """ - PROGRAM pow2_test - implicit none - double precision d(3,4,5) - CALL pow2_test_function(d) - end - - SUBROUTINE pow2_test_function(d) - double precision d(3,4,5) - double precision :: ZSIGK(2), ZHRC(2),RAMID(2) - - ZSIGK(1)=4.8 - RAMID(1)=0.0 - ZHRC(1)=12.34 - IF(ZSIGK(1) > 0.8) THEN - ZHRC(1)=RAMID(1)+(1.0-RAMID(1))*((ZSIGK(1)-0.8)/0.01)**1.5 - ENDIF - d(1,1,2)=ZHRC(1) - END SUBROUTINE pow2_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "pow2_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + implicit none + double precision d(3, 4, 5) + double precision :: ZSIGK(2), ZHRC(2), RAMID(2) + + ZSIGK(1) = 4.8 + RAMID(1) = 0.0 + ZHRC(1) = 12.34 + if (ZSIGK(1) > 0.8) then + ZHRC(1) = RAMID(1) + (1.0 - RAMID(1))*((ZSIGK(1) - 0.8)/0.01)**1.5 + end if + d(1, 1, 2) = ZHRC(1) +end subroutine main +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) d = np.full([3, 4, 5], 42, order="F", dtype=np.float64) sdfg(d=d) @@ -228,24 +178,19 @@ def test_fortran_frontend_sign1(): """ Tests that the sign intrinsic is correctly parsed and translated to DaCe. 
""" - test_string = """ - PROGRAM sign1_test - implicit none - double precision d(3,4,5) - CALL sign1_test_function(d) - end - - SUBROUTINE sign1_test_function(d) - double precision d(3,4,5) - double precision :: ZSIGK(2), ZHRC(2),RAMID(2) - - ZSIGK(1)=4.8 - RAMID(1)=0.0 - ZHRC(1)=-12.34 - d(1,1,2)=SIGN(ZSIGK(1),ZHRC(1)) - END SUBROUTINE sign1_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "sign1_test") + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + implicit none + double precision d(3, 4, 5) + double precision :: ZSIGK(2), ZHRC(2), RAMID(2) + + ZSIGK(1) = 4.8 + RAMID(1) = 0.0 + ZHRC(1) = -12.34 + d(1, 1, 2) = sign(ZSIGK(1), ZHRC(1)) +end subroutine main +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify(verbose=True) d = np.full([3, 4, 5], 42, order="F", dtype=np.float64) sdfg(d=d) @@ -257,7 +202,6 @@ def test_fortran_frontend_sign1(): test_fortran_frontend_if1() test_fortran_frontend_loop1() test_fortran_frontend_function_statement1() - test_fortran_frontend_pow1() test_fortran_frontend_pow2() test_fortran_frontend_sign1() diff --git a/tests/fortran/fortran_loops_test.py b/tests/fortran/fortran_loops_test.py deleted file mode 100644 index b18a5e36e8..0000000000 --- a/tests/fortran/fortran_loops_test.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
- -import numpy as np - -from dace.frontend.fortran import fortran_parser - -def test_fortran_frontend_loop_region_basic_loop(): - test_name = "loop_test" - test_string = """ - PROGRAM loop_test_program - implicit none - double precision a(10,10) - double precision b(10,10) - double precision c(10,10) - - CALL loop_test_function(a,b,c) - end - - SUBROUTINE loop_test_function(a,b,c) - double precision :: a(10,10) - double precision :: b(10,10) - double precision :: c(10,10) - - INTEGER :: JK,JL - DO JK=1,10 - DO JL=1,10 - c(JK,JL) = a(JK,JL) + b(JK,JL) - ENDDO - ENDDO - end SUBROUTINE loop_test_function - """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, use_explicit_cf=True) - - a_test = np.full([10, 10], 2, order="F", dtype=np.float64) - b_test = np.full([10, 10], 3, order="F", dtype=np.float64) - c_test = np.zeros([10, 10], order="F", dtype=np.float64) - sdfg(a=a_test, b=b_test, c=c_test) - - validate = np.full([10, 10], 5, order="F", dtype=np.float64) - - assert np.allclose(c_test, validate) - - -if __name__ == '__main__': - test_fortran_frontend_loop_region_basic_loop() diff --git a/tests/fortran/fortran_test_helper.py b/tests/fortran/fortran_test_helper.py new file mode 100644 index 0000000000..a82f392f35 --- /dev/null +++ b/tests/fortran/fortran_test_helper.py @@ -0,0 +1,293 @@ +import re +import subprocess +from dataclasses import dataclass, field +from os import path +from tempfile import TemporaryDirectory +from typing import Dict, Optional, Tuple, Type, Union, List, Sequence, Collection + +from fparser.two.Fortran2003 import Name + +from dace.frontend.fortran.ast_internal_classes import Name_Node +from dace.frontend.fortran.fortran_parser import ParseConfig, create_internal_ast, SDFGConfig, \ + create_sdfg_from_internal_ast + + +@dataclass +class SourceCodeBuilder: + """ + A helper class that helps to construct the source code structure for frontend tests. 
+ + Example usage: + ```python + # Construct the builder, add files in the order you'd pass them to `gfortran`, (optional step) check if they all + # compile together, then get a dictionary mapping file names (possibly auto-inferred) to their content. + sources, main = SourceCodeBuilder().add_file(''' + module lib + end module lib + ''').add_file(''' + program main + use lib + implicit none + end program main + ''').check_with_gfortran().get() + # Then construct the SDFG. + sdfg = create_sdfg_from_string(main, "main", sources=sources) + ``` + """ + sources: Dict[str, str] = field(default_factory=dict) + + def add_file(self, content: str, name: Optional[str] = None): + """Add source file contents in the order you'd pass them to `gfortran`.""" + if not name: + name = SourceCodeBuilder._identify_name(content) + key = f"{name}.f90" + assert key not in self.sources, f"{key} in {list(self.sources.keys())}: {self.sources[key]}" + self.sources[key] = content + return self + + def check_with_gfortran(self): + """Assert that it all compiles with `gfortran -Wall -shared -fPIC`.""" + with TemporaryDirectory() as td: + # Create temporary Fortran source-file structure. + for fname, content in self.sources.items(): + with open(path.join(td, fname), 'w') as f: + f.write(content) + # Run `gfortran -Wall` to verify that it compiles. + # Note: we're relying on the fact that python dictionaries keep the insertion order when calling `keys()`.
+ cmd = ['gfortran', '-Wall', '-shared', '-fPIC', *self.sources.keys()] + + try: + subprocess.run(cmd, cwd=td, capture_output=True).check_returncode() + return self + except subprocess.CalledProcessError as e: + print("Fortran compilation failed!") + print(e.stderr.decode()) + raise e + + def get(self) -> Tuple[Dict[str, str], Optional[str]]: + """Get a dictionary mapping file names (possibly auto-inferred) to their content.""" + main = None + if 'main.f90' in self.sources: + main = self.sources['main.f90'] + return self.sources, main + + @staticmethod + def _identify_name(content: str) -> str: + PPAT = re.compile("^.*\\bprogram\\b\\s*\\b(?P<prog>[a-zA-Z0-9_]+)\\b.*$", re.I | re.M | re.S) + if PPAT.match(content): + return 'main' + MPAT = re.compile("^.*\\bmodule\\b\\s*\\b(?P<mod>[a-zA-Z0-9_]+)\\b.*$", re.I | re.M | re.S) + if MPAT.match(content): + match = MPAT.search(content) + return match.group('mod') + FPAT = re.compile("^.*\\bfunction\\b\\s*\\b(?P<fun>[a-zA-Z0-9_]+)\\b.*$", re.I | re.M | re.S) + if FPAT.match(content): + return 'main' + SPAT = re.compile("^.*\\bsubroutine\\b\\s*\\b(?P<sub>[a-zA-Z0-9_]+)\\b.*$", re.I | re.M | re.S) + if SPAT.match(content): + return 'main' + assert not any(PAT.match(content) for PAT in (PPAT, MPAT, FPAT, SPAT)) + + +class FortranASTMatcher: + """ + A "matcher" class that asserts if a given `node` has the right type, and its children, attributes etc. also match + the submatchers. + + Example usage: + ```python + # Construct a matcher that looks for specific patterns in the AST structure, while ignoring unnecessary details.
+ m = M(Program, [ + M(Main_Program, [ + M.IGNORE(), # program main + M(Specification_Part), # implicit none; double precision d(4) + M(Execution_Part, [M(Call_Stmt)]), # call fun(d) + M.IGNORE(), # end program main + ]), + M(Subroutine_Subprogram, [ + M(Subroutine_Stmt), # subroutine fun(d) + M(Specification_Part, [ + M(Implicit_Part), # implicit none + M(Type_Declaration_Stmt), # double precision d(4) + ]), + M(Execution_Part, [M(Assignment_Stmt)]), # d(2) = 5.5 + M(End_Subroutine_Stmt), # end subroutine fun + ]), + ]) + # Check that a given Fortran AST matches that pattern. + m.check(ast) + ``` + """ + + def __init__(self, + is_type: Union[None, Type, str] = None, + has_children: Union[None, list] = None, + has_attr: Optional[Dict[str, Union["FortranASTMatcher", List["FortranASTMatcher"]]]] = None, + has_value: Optional[str] = None): + # TODO: Include Set[Self] to `has_children` type? + assert not ((set() if has_attr is None else has_attr.keys()) + & {'children'}) + self.is_type = is_type + self.has_children = has_children + self.has_attr = has_attr + self.has_value = has_value + + def check(self, node): + if self.is_type is not None: + if isinstance(self.is_type, type): + assert isinstance(node, self.is_type), \ + f"type mismatch at {node}; want: {self.is_type}, got: {type(node)}" + elif isinstance(self.is_type, str): + assert node.__class__.__name__ == self.is_type, \ + f"type mismatch at {node}; want: {self.is_type}, got: {type(node)}" + if self.has_value is not None: + assert node == self.has_value + if self.has_children is not None and len(self.has_children) > 0: + assert hasattr(node, 'children') + all_children = getattr(node, 'children') + assert len(self.has_children) == len(all_children), \ + f"#children mismatch at {node}; want: {len(self.has_children)}, got: {len(all_children)}" + for (c, m) in zip(all_children, self.has_children): + m.check(c) + if self.has_attr is not None and len(self.has_attr.keys()) > 0: + for key, subm in self.has_attr.items(): 
+ assert hasattr(node, key) + attr = getattr(node, key) + + if isinstance(subm, Sequence): + assert isinstance(attr, Sequence) + assert len(attr) == len(subm) + for (c, m) in zip(attr, subm): + m.check(c) + else: + subm.check(attr) + + @classmethod + def IGNORE(cls, times: Optional[int] = None) -> Union["FortranASTMatcher", List["FortranASTMatcher"]]: + """ + A placeholder matcher to not check further down the tree. + If `times` is `None` (which is the default), returns a single matcher. + If `times` is an integer value, then returns a list of `IGNORE()` matchers of that size, indicating that many + nodes on a row should be ignored. + """ + if times is None: + return cls() + else: + return [cls()] * times + + @classmethod + def NAMED(cls, name: str): + return cls(Name, has_attr={'string': cls(has_value=name)}) + + +class InternalASTMatcher: + """ + A "matcher" class that asserts if a given `node` has the right type, and its children, attributes etc. also matches + the submatchers. + + Example usage: + ```python + # Construct a matcher that looks for specific patterns in the AST structure, while ignoring unnecessary details. + m = M(Program_Node, { + 'main_program': M(Main_Program_Node, { + 'name': M(Program_Stmt_Node), + 'specification_part': M(Specification_Part_Node, { + 'specifications': [ + M(Decl_Stmt_Node, { + 'vardecl': [M(Var_Decl_Node)], + }) + ], + }, {'interface_blocks', 'symbols', 'typedecls', 'uses'}), + 'execution_part': M(Execution_Part_Node, { + 'execution': [ + M(Call_Expr_Node, { + 'name': M(Name_Node), + 'args': [M(Name_Node, { + 'name': M(has_value='d'), + 'type': M(has_value='DOUBLE'), + })], + 'type': M(has_value='VOID'), + }) + ], + }), + }, {'parent'}), + 'structures': M(Structures, None, {'structures'}), + }, {'function_definitions', 'module_declarations', 'modules'}) + # Check that a given internal AST matches that pattern. 
+ m.check(prog) + ``` + """ + + def __init__(self, + is_type: Optional[Type] = None, + has_attr: Optional[Dict[str, Union["InternalASTMatcher", List["InternalASTMatcher"], Dict[str, "InternalASTMatcher"]]]] = None, + has_empty_attr: Optional[Collection[str]] = None, + has_value: Optional[str] = None): + # TODO: Include Set[Self] to `has_children` type? + assert not ((set() if has_attr is None else has_attr.keys()) + & (set() if has_empty_attr is None else has_empty_attr)) + self.is_type: Type = is_type + self.has_attr = has_attr + self.has_empty_attr = has_empty_attr + self.has_value = has_value + + def check(self, node): + if self.is_type is not None: + assert isinstance(node, self.is_type) + if self.has_value is not None: + assert node == self.has_value + if self.has_empty_attr is not None: + for key in self.has_empty_attr: + assert not hasattr(node, key) or not getattr(node, key), f"{node} is expected to not have key: {key}" + if self.has_attr is not None and len(self.has_attr.keys()) > 0: + for key, subm in self.has_attr.items(): + assert hasattr(node, key), f"{node} doesn't have key: {key}" + attr = getattr(node, key) + + if isinstance(subm, Sequence): + assert isinstance(attr, Sequence), f"{attr} must be a sequence, since {subm} is." + assert len(attr) == len(subm), f"{attr} must have the same length as {subm}." + for (c, m) in zip(attr, subm): + m.check(c) + elif isinstance(subm, Dict): + assert isinstance(attr, Dict) + assert len(attr) == len(subm) + assert subm.keys() <= attr.keys() + for k in subm.keys(): + subm[k].check(attr[k]) + else: + subm.check(attr) + + @classmethod + def IGNORE(cls, times: Optional[int] = None) -> Union["InternalASTMatcher", List["InternalASTMatcher"]]: + """ + A placeholder matcher to not check further down the tree. + If `times` is `None` (which is the default), returns a single matcher. 
+ If `times` is an integer value, then returns a list of `IGNORE()` matchers of that size, indicating that many + nodes on a row should be ignored. + """ + if times is None: + return cls() + else: + return [cls()] * times + + @classmethod + def NAMED(cls, name: str): + return cls(Name_Node, {'name': cls(has_value=name)}) + + +def create_singular_sdfg_from_string( + sources: Dict[str, str], + entry_point: str, + normalize_offsets: bool = True): + entry_point = entry_point.split('.') + + cfg = ParseConfig(main=sources['main.f90'], sources=sources, entry_points=tuple(entry_point)) + own_ast, program = create_internal_ast(cfg) + + cfg = SDFGConfig({entry_point[-1]: entry_point}, normalize_offsets, False) + gmap = create_sdfg_from_internal_ast(own_ast, program, cfg) + assert gmap.keys() == {entry_point[-1]} + g = list(gmap.values())[0] + + return g diff --git a/tests/fortran/future/fortran_class_test.py b/tests/fortran/future/fortran_class_test.py new file mode 100644 index 0000000000..7e6ab50577 --- /dev/null +++ b/tests/fortran/future/fortran_class_test.py @@ -0,0 +1,117 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + + + +def test_fortran_frontend_class(): + """ + Tests that whether clasess are translated correctly + """ + test_string = """ + PROGRAM class_test + + TYPE, ABSTRACT :: t_comm_pattern + + CONTAINS + + PROCEDURE(interface_setup_comm_pattern), DEFERRED :: setup + PROCEDURE(interface_exchange_data_r3d), DEFERRED :: exchange_data_r3d +END TYPE t_comm_pattern + +TYPE, EXTENDS(t_comm_pattern) :: t_comm_pattern_orig + INTEGER :: n_pnts ! Number of points we output into local array; + ! this may be bigger than n_recv due to + ! duplicate entries + + INTEGER, ALLOCATABLE :: recv_limits(:) + + CONTAINS + + PROCEDURE :: setup => setup_comm_pattern + PROCEDURE :: exchange_data_r3d => exchange_data_r3d + +END TYPE t_comm_pattern_orig + + + + implicit none + integer d(2) + CALL class_test_function(d) + end + + +SUBROUTINE setup_comm_pattern(p_pat, dst_n_points) + + CLASS(t_comm_pattern_orig), TARGET, INTENT(OUT) :: p_pat + + INTEGER, INTENT(IN) :: dst_n_points ! 
Total number of points + + p_pat%n_pnts = dst_n_points + END SUBROUTINE setup_comm_pattern + + SUBROUTINE exchange_data_r3d(p_pat, recv) + + CLASS(t_comm_pattern_orig), TARGET, INTENT(INOUT) :: p_pat + REAL, INTENT(INOUT), TARGET :: recv(:,:,:) + + recv(1,1,1)=recv(1,1,1)+p_pat%n_pnts + + END SUBROUTINE exchange_data_r3d + + SUBROUTINE class_test_function(d) + integer d(2) + real recv(2,2,2) + + CLASS(t_comm_pattern_orig) :: p_pat + + CALL setup_comm_pattern(p_pat, 42) + CALL exchange_data_r3d(p_pat, recv) + d(1)=p_pat%n_pnts + END SUBROUTINE class_test_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "class_test",False,False) + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.NestedSDFG): + if node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break + sdfg.parent = None + sdfg.parent_sdfg = None + sdfg.parent_nsdfg_node = None + sdfg.reset_sdfg_list() + sdfg.simplify(verbose=True) + sdfg.view() + sdfg.compile() + # sdfg = fortran_parser.create_sdfg_from_string(test_string, "int_init_test") + # sdfg.simplify(verbose=True) + # d = np.full([2], 42, order="F", dtype=np.int64) + # sdfg(d=d) + # assert (d[0] == 400) + + + +if __name__ == "__main__": + + + + test_fortran_frontend_class() + diff --git a/tests/fortran/global_test.py b/tests/fortran/global_test.py new file mode 100644 index 0000000000..82dcc46db0 --- /dev/null +++ b/tests/fortran/global_test.py @@ -0,0 +1,124 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil +from dace.sdfg.nodes import AccessNode + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +def test_fortran_frontend_global(): + """ + Tests that the Fortran frontend can parse complex global includes. + """ + test_string = """ + PROGRAM global_test + implicit none + USE global_test_module_subroutine, ONLY: global_test_function + + REAL :: d(4), a(4,4,4) + + CALL global_test_function(d) + + end + + + """ + sources={} + sources["global_test"]=test_string + sources["global_test_module_subroutine.f90"]=""" + MODULE global_test_module_subroutine + + CONTAINS + + SUBROUTINE global_test_function(d) + USE global_test_module, ONLY: outside_init,simple_type + USE nested_one, ONLY: nested + double precision d(4) + double precision :: a(4,4,4) + integer :: i + + + + + TYPE(simple_type) :: ptr_patch + + double precision d(4) + ptr_patch%w(:,:,:)=5.5 + + i=outside_init + CALL nested(i,ptr_patch%w) + d(i+1)=5.5+ptr_patch%w(3,3,3) + + END SUBROUTINE global_test_function + END MODULE global_test_module_subroutine + """ + sources["global_test_module.f90"]=""" + MODULE global_test_module + IMPLICIT NONE + TYPE simple_type + double precision,POINTER :: w(:,:,:) + integer a + + END TYPE simple_type + integer outside_init=1 + END MODULE global_test_module + """ + + sources["nested_one.f90"]=""" + MODULE nested_one + IMPLICIT NONE + CONTAINS + SUBROUTINE 
nested(i,a) + USE nested_two, ONLY: nestedtwo + integer :: i + double precision :: a(:,:,:) + + i=0 + CALL nestedtwo(i) + a(i+1,i+1,i+1)=5.5 + END SUBROUTINE nested + + END MODULE nested_one + """ + sources["nested_two.f90"]=""" + MODULE nested_two + IMPLICIT NONE + CONTAINS + SUBROUTINE nestedtwo(i) + USE global_test_module, ONLY: outside_init + integer :: i + + i = outside_init+1 + + END SUBROUTINE nestedtwo + + END MODULE nested_two + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "global_test",sources=sources,normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.save('test.sdfg') + a = np.full([4], 42, order="F", dtype=np.float64) + a2 = np.full([4,4,4], 42, order="F", dtype=np.float64) + #TODO Add validation - but we need python structs for this. + #sdfg(d=a,a=a2) + #assert (a[0] == 42) + #assert (a[1] == 5.5) + #assert (a[2] == 42) + +if __name__ == "__main__": + + test_fortran_frontend_global() diff --git a/tests/fortran/ifcycle_test.py b/tests/fortran/ifcycle_test.py new file mode 100644 index 0000000000..ae7a943721 --- /dev/null +++ b/tests/fortran/ifcycle_test.py @@ -0,0 +1,107 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +def test_fortran_frontend_if_cycle(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. + """ + test_string = """ + PROGRAM if_cycle_test + implicit none + double precision :: d(4) + CALL if_cycle_test_function(d) + end + + SUBROUTINE if_cycle_test_function(d) + double precision d(4,5) + integer :: i + DO i=1,4 + if (i .eq. 2) CYCLE + d(i)=5.5 + ENDDO + if (d(2) .eq. 42) d(2)=6.5 + + + END SUBROUTINE if_cycle_test_function + """ + sources={} + sources["if_cycle"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "if_cycle",normalize_offsets=True,multiple_sdfgs=False,sources=sources) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=a) + assert (a[0] == 5.5) + assert (a[1] == 6.5) + assert (a[2] == 5.5) + + + +def test_fortran_frontend_if_nested_cycle(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM if_nested_cycle_test + implicit none + double precision :: d(4,4) + + CALL if_nested_cycle_test_function(d) + end + + SUBROUTINE if_nested_cycle_test_function(d) + double precision d(4,4) + double precision :: tmp + integer :: i,limit,start,count + limit=4 + start=1 + DO i=start,limit + count=0 + + DO j=start,limit + if (j .eq. 2) count=count+2 + ENDDO + if (count .eq. 2) CYCLE + if (count .eq. 3) CYCLE + DO j=start,limit + + d(i,j)=d(i,j)+1.5 + ENDDO + d(i,1)=5.5 + ENDDO + + if (d(2,1) .eq. 42.0) d(2,1)=6.5 + + + END SUBROUTINE if_nested_cycle_test_function + """ + sources={} + sources["if_nested_cycle"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "if_nested_cycle",normalize_offsets=True,multiple_sdfgs=False,sources=sources) + sdfg.simplify(verbose=True) + a = np.full([4,4], 42, order="F", dtype=np.float64) + sdfg(d=a) + assert (a[0,0] == 42) + assert (a[1,0] == 6.5) + assert (a[2,0] == 42) + + +if __name__ == "__main__": + + test_fortran_frontend_if_nested_cycle() diff --git a/tests/fortran/init_test.py b/tests/fortran/init_test.py new file mode 100644 index 0000000000..f23446bf14 --- /dev/null +++ b/tests/fortran/init_test.py @@ -0,0 +1,113 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil +from dace.sdfg.nodes import AccessNode + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +def test_fortran_frontend_init(): + """ + Tests that the Fortran frontend can parse complex initializations. + """ + test_string = """ + PROGRAM init_test + implicit none + USE init_test_module_subroutine, ONLY: init_test_function + double precision d(4) + CALL init_test_function(d) + end + + + """ + sources={} + sources["init_test"]=test_string + sources["init_test_module_subroutine.f90"]=""" + MODULE init_test_module_subroutine + CONTAINS + SUBROUTINE init_test_function(d) + USE init_test_module, ONLY: outside_init + double precision d(4) + REAL bob=EPSILON(1.0) + + + d(2)=5.5 +bob +outside_init + + END SUBROUTINE init_test_function + END MODULE init_test_module_subroutine + """ + sources["init_test_module.f90"]=""" + MODULE init_test_module + IMPLICIT NONE + REAL outside_init=EPSILON(1.0) + END MODULE init_test_module + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=a,outside_init=0) + assert (a[0] == 42) + assert (a[1] == 5.5) + assert (a[2] == 42) + + +def test_fortran_frontend_init2(): + """ + Tests that the Fortran frontend can parse complex initializations. 
+ """ + test_string = """ + PROGRAM init2_test + implicit none + USE init2_test_module_subroutine, ONLY: init2_test_function + double precision d(4) + CALL init2_test_function(d) + end + + + """ + sources={} + sources["init2_test"]=test_string + sources["init2_test_module_subroutine.f90"]=""" + MODULE init2_test_module_subroutine + CONTAINS + SUBROUTINE init2_test_function(d) + USE init2_test_module, ONLY: TORUS_MAX_LAT + double precision d(4) + + + d(2)=5.5 + TORUS_MAX_LAT + + END SUBROUTINE init2_test_function + END MODULE init2_test_module_subroutine + """ + sources["init2_test_module.f90"]=""" + MODULE init2_test_module + IMPLICIT NONE + REAL, PARAMETER :: TORUS_MAX_LAT = 4.0 / 18.0 * ATAN(1.0) + END MODULE init2_test_module + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init2_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=a, torus_max_lat=4.0 / 18.0 * np.arctan(1.0)) + assert (a[0] == 42) + assert (a[1] == 5.674532920122147) + assert (a[2] == 42) + +if __name__ == "__main__": + + test_fortran_frontend_init() + test_fortran_frontend_init2() diff --git a/tests/fortran/intrinsic_all_test.py b/tests/fortran/intrinsic_all_test.py index 4a368aff2c..dc0c76c677 100644 --- a/tests/fortran/intrinsic_all_test.py +++ b/tests/fortran/intrinsic_all_test.py @@ -24,7 +24,7 @@ def test_fortran_frontend_all_array(): END SUBROUTINE intrinsic_all_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -60,7 +60,7 @@ def test_fortran_frontend_all_array_dim(): """ with pytest.raises(NotImplementedError): - fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") def test_fortran_frontend_all_array_comparison(): @@ -91,7 
+91,7 @@ def test_fortran_frontend_all_array_comparison(): END SUBROUTINE intrinsic_all_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -134,7 +134,7 @@ def test_fortran_frontend_all_array_scalar_comparison(): END SUBROUTINE intrinsic_all_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -181,7 +181,7 @@ def test_fortran_frontend_all_array_comparison_wrong_subset(): """ with pytest.raises(TypeError): - fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") def test_fortran_frontend_all_array_2d(): test_string = """ @@ -201,7 +201,7 @@ def test_fortran_frontend_all_array_2d(): END SUBROUTINE intrinsic_all_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -244,7 +244,7 @@ def test_fortran_frontend_all_array_comparison_2d(): END SUBROUTINE intrinsic_all_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -287,7 +287,7 @@ def test_fortran_frontend_all_array_comparison_2d_subset(): END SUBROUTINE intrinsic_all_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") sdfg.simplify(verbose=True) 
sdfg.compile() @@ -329,7 +329,7 @@ def test_fortran_frontend_all_array_comparison_2d_subset_offset(): END SUBROUTINE intrinsic_all_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") sdfg.simplify(verbose=True) sdfg.compile() diff --git a/tests/fortran/intrinsic_any_test.py b/tests/fortran/intrinsic_any_test.py index c1d82cd2e0..49d0b5c12c 100644 --- a/tests/fortran/intrinsic_any_test.py +++ b/tests/fortran/intrinsic_any_test.py @@ -91,7 +91,7 @@ def test_fortran_frontend_any_array_comparison(): END SUBROUTINE intrinsic_any_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -136,7 +136,7 @@ def test_fortran_frontend_any_array_scalar_comparison(): END SUBROUTINE intrinsic_any_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -247,7 +247,7 @@ def test_fortran_frontend_any_array_comparison_2d(): END SUBROUTINE intrinsic_any_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -290,7 +290,7 @@ def test_fortran_frontend_any_array_comparison_2d_subset(): END SUBROUTINE intrinsic_any_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test") sdfg.simplify(verbose=True) sdfg.compile() diff --git a/tests/fortran/intrinsic_basic_test.py 
b/tests/fortran/intrinsic_basic_test.py index 9ef31dd108..4a2e10d8d6 100644 --- a/tests/fortran/intrinsic_basic_test.py +++ b/tests/fortran/intrinsic_basic_test.py @@ -4,6 +4,7 @@ import pytest from dace.frontend.fortran import fortran_parser +from tests.fortran.fortran_test_helper import create_singular_sdfg_from_string, SourceCodeBuilder def test_fortran_frontend_bit_size(): test_string = """ @@ -42,9 +43,9 @@ def test_fortran_frontend_bit_size_symbolic(): test_string = """ PROGRAM intrinsic_math_test_bit_size implicit none - integer, parameter :: arrsize = 2 - integer, parameter :: arrsize2 = 3 - integer, parameter :: arrsize3 = 4 + integer :: arrsize + integer :: arrsize2 + integer :: arrsize3 integer :: res(arrsize) integer :: res2(arrsize, arrsize2, arrsize3) integer :: res3(arrsize+arrsize2, arrsize2 * 5, arrsize3 + arrsize2*arrsize) @@ -82,7 +83,6 @@ def test_fortran_frontend_bit_size_symbolic(): res2 = np.full([size, size2, size3], 42, order="F", dtype=np.int32) res3 = np.full([size+size2, size2*5, size3 + size*size2], 42, order="F", dtype=np.int32) sdfg(res=res, res2=res2, res3=res3, arrsize=size, arrsize2=size2, arrsize3=size3) - print(res) assert res[0] == size assert res[1] == size*size2*size3 @@ -92,7 +92,291 @@ def test_fortran_frontend_bit_size_symbolic(): assert res[5] == size + size2 + size3 assert res[6] == size + size2 + size2*5 + size3 + size*size2 +def test_fortran_frontend_size_arbitrary(): + test_string = """ + PROGRAM intrinsic_basic_size_arbitrary + implicit none + integer :: arrsize + integer :: arrsize2 + integer :: res(arrsize, arrsize2) + CALL intrinsic_basic_size_arbitrary_test_function(res) + end + + SUBROUTINE intrinsic_basic_size_arbitrary_test_function(res) + implicit none + integer :: res(:, :) + + res(1,1) = SIZE(res) + res(2,1) = SIZE(res, 1) + res(3,1) = SIZE(res, 2) + + END SUBROUTINE intrinsic_basic_size_arbitrary_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, 
"intrinsic_basic_size_arbitrary_test", True,) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + size2 = 5 + res = np.full([size, size2], 42, order="F", dtype=np.int32) + sdfg(res=res,arrsize=size,arrsize2=size2) + + assert res[0,0] == size*size2 + assert res[1,0] == size + assert res[2,0] == size2 + +def test_fortran_frontend_present(): + test_string = """ + PROGRAM intrinsic_basic_present + implicit none + integer, dimension(4) :: res + integer, dimension(4) :: res2 + integer :: a + CALL intrinsic_basic_present_test_function(res, res2, a) + end + + SUBROUTINE intrinsic_basic_present_test_function(res, res2, a) + integer, dimension(4) :: res + integer, dimension(4) :: res2 + integer :: a + + CALL tf2(res, a=a) + CALL tf2(res2) + + END SUBROUTINE intrinsic_basic_present_test_function + + SUBROUTINE tf2(res, a) + integer, dimension(4) :: res + integer, optional :: a + + res(1) = PRESENT(a) + + END SUBROUTINE tf2 + """ + #test_string = """ + # PROGRAM intrinsic_basic_present + # implicit none + # integer, dimension(4) :: res + # integer, dimension(4) :: res2 + # integer :: a + # CALL test_intrinsic_basic_pre2sent_function(res, res2,a) + # end + + # SUBROUTINE test_intrinsic_basic_pre2sent_function(res, res2,a) + # integer, dimension(4) :: res + # integer, dimension(4) :: res2 + # integer :: a + + # res(1) = 1 + # END SUBROUTINE test_intrinsic_basic_pre2sent_function + # !PROGRAM intrinsic_basic_present + # !implicit none + # !integer, dimension(4) :: res + # !integer, dimension(4) :: res2 + # !integer :: a + # !CALL intrinsic_basic_present_function(res, res2, a) + # !end + + # !SUBROUTINE intrinsic_basic_present_function(res, res2, a) + # !integer, dimension(4) :: res + # !integer, dimension(4) :: res2 + # !integer :: a + + # !res(1) = 5 + # !!CALL intrinsic_basic_present_function2(res, a) + # !!CALL intrinsic_basic_present_function2(res2) + + # !END SUBROUTINE intrinsic_basic_present_function + + # !SUBROUTINE intrinsic_basic_present_function2(res, a) + # 
!integer, dimension(4) :: res + # !integer :: a + + # !!res(1) = PRESENT(a) + # !res(1) = 2 + # !res(2) = 2 + + # !END SUBROUTINE intrinsic_basic_present_function2 + # """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_basic_present_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + res2 = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res, res2=res2, a=5) + + assert res[0] == 1 + assert res2[0] == 0 + +def test_fortran_frontend_bitwise_ops(): + sources, main = SourceCodeBuilder().add_file(""" + SUBROUTINE bitwise_ops(input, res) + + integer, dimension(11) :: input + integer, dimension(11) :: res + + res(1) = IBSET(input(1), 0) + res(2) = IBSET(input(2), 30) + + res(3) = IBCLR(input(3), 0) + res(4) = IBCLR(input(4), 30) + + res(5) = IEOR(input(5), 63) + res(6) = IEOR(input(6), 480) + + res(7) = ISHFT(input(7), 5) + res(8) = ISHFT(input(8), 30) + + res(9) = ISHFT(input(9), -5) + res(10) = ISHFT(input(10), -30) + + res(11) = ISHFT(input(11), 0) + + END SUBROUTINE bitwise_ops +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'bitwise_ops', normalize_offsets=True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 11 + input = np.full([size], 42, order="F", dtype=np.int32) + res = np.full([size], 42, order="F", dtype=np.int32) + + input = [32, 32, 33, 1073741825, 53, 530, 12, 1, 128, 1073741824, 12 ] + + sdfg(input=input, res=res) + + assert np.allclose(res, [33, 1073741856, 32, 1, 10, 1010, 384, 1073741824, 4, 1, 12]) + +def test_fortran_frontend_bitwise_ops2(): + sources, main = SourceCodeBuilder().add_file(""" + SUBROUTINE bitwise_ops(input, res) + + integer, dimension(6) :: input + integer, dimension(6) :: res + + res(1) = IAND(input(1), 0) + res(2) = IAND(input(2), 31) + + res(3) = BTEST(input(3), 0) + res(4) = BTEST(input(4), 5) + + res(5) = IBITS(input(5), 0, 5) + res(6) = IBITS(input(6), 3, 10) + + END 
SUBROUTINE bitwise_ops +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'bitwise_ops', normalize_offsets=True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 6 + input = np.full([size], 42, order="F", dtype=np.int32) + res = np.full([size], 42, order="F", dtype=np.int32) + + input = [2147483647, 16, 3, 31, 30, 630] + + sdfg(input=input, res=res) + + assert np.allclose(res, [0, 16, 1, 0, 30, 78]) + +def test_fortran_frontend_allocated(): + # FIXME: this pattern is generally not supported. + # this needs an update once defered allocs are merged + + sources, main = SourceCodeBuilder().add_file(""" + SUBROUTINE allocated_test(res) + + integer, allocatable, dimension(:) :: data + integer, dimension(3) :: res + + res(1) = ALLOCATED(data) + + ALLOCATE(data(6)) + + res(2) = ALLOCATED(data) + + DEALLOCATE(data) + + res(3) = ALLOCATED(data) + + END SUBROUTINE allocated_test +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'allocated_test', normalize_offsets=True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 3 + res = np.full([size], 42, order="F", dtype=np.int32) + + sdfg(res=res) + + assert np.allclose(res, [0, 1, 0]) + +def test_fortran_frontend_allocated_nested(): + + # FIXME: this pattern is generally not supported. 
+ # this needs an update once defered allocs are merged + + sources, main = SourceCodeBuilder().add_file(""" + MODULE allocated_test_interface + INTERFACE + SUBROUTINE allocated_test_nested(data, res) + integer, allocatable, dimension(:) :: data + integer, dimension(3) :: res + END SUBROUTINE allocated_test_nested + END INTERFACE + END MODULE + + SUBROUTINE allocated_test(res) + USE allocated_test_interface + implicit none + integer, allocatable, dimension(:) :: data + integer, dimension(3) :: res + + res(1) = ALLOCATED(data) + + ALLOCATE(data(6)) + + CALL allocated_test_nested(data, res) + + END SUBROUTINE allocated_test + + SUBROUTINE allocated_test_nested(data, res) + + integer, allocatable, dimension(:) :: data + integer, dimension(3) :: res + + res(2) = ALLOCATED(data) + + DEALLOCATE(data) + + res(3) = ALLOCATED(data) + + END SUBROUTINE allocated_test_nested +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'allocated_test', normalize_offsets=True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 3 + res = np.full([size], 42, order="F", dtype=np.int32) + + sdfg(res=res, __f2dace_A_data_d_0_s_0=0) + + assert np.allclose(res, [0, 1, 0]) if __name__ == "__main__": + test_fortran_frontend_bit_size() test_fortran_frontend_bit_size_symbolic() + test_fortran_frontend_size_arbitrary() + test_fortran_frontend_present() + test_fortran_frontend_bitwise_ops() + test_fortran_frontend_bitwise_ops2() + test_fortran_frontend_allocated() + test_fortran_frontend_allocated_nested() diff --git a/tests/fortran/intrinsic_blas_test.py b/tests/fortran/intrinsic_blas_test.py new file mode 100644 index 0000000000..2a04c7e1f3 --- /dev/null +++ b/tests/fortran/intrinsic_blas_test.py @@ -0,0 +1,174 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import numpy as np + +from tests.fortran.fortran_test_helper import create_singular_sdfg_from_string, SourceCodeBuilder + + +def test_fortran_frontend_dot(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(arg1, arg2, res1) + double precision, dimension(5) :: arg1 + double precision, dimension(5) :: arg2 + double precision, dimension(2) :: res1 + res1(1) = dot_product(arg1, arg2) +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', normalize_offsets=False) + # TODO: We should re-enable `simplify()` once we merge it. + # sdfg.simplify() + sdfg.compile() + + size = 5 + arg1 = np.full([size], 42, order="F", dtype=np.float64) + arg2 = np.full([size], 42, order="F", dtype=np.float64) + res1 = np.full([2], 0, order="F", dtype=np.float64) + + for i in range(size): + arg1[i] = i + 1 + arg2[i] = i + 5 + + sdfg(arg1=arg1, arg2=arg2, res1=res1) + + assert res1[0] == np.dot(arg1, arg2) + + +def test_fortran_frontend_dot_range(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(arg1, arg2, res1) + double precision, dimension(5) :: arg1 + double precision, dimension(5) :: arg2 + double precision, dimension(2) :: res1 + res1(1) = dot_product(arg1(1:3), arg2(1:3)) +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', normalize_offsets=False) + # TODO: We should re-enable `simplify()` once we merge it. 
+ # sdfg.simplify() + sdfg.compile() + + size = 5 + arg1 = np.full([size], 42, order="F", dtype=np.float64) + arg2 = np.full([size], 42, order="F", dtype=np.float64) + res1 = np.full([2], 0, order="F", dtype=np.float64) + + for i in range(size): + arg1[i] = i + 1 + arg2[i] = i + 5 + + sdfg(arg1=arg1, arg2=arg2, res1=res1) + assert res1[0] == np.dot(arg1, arg2) + +def test_fortran_frontend_transpose(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(arg1, arg2, res1) + double precision, dimension(5,4) :: arg1 + double precision, dimension(4,5) :: res1 + res1 = transpose(arg1) +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', normalize_offsets=False) + # TODO: We should re-enable `simplify()` once we merge it. + # sdfg.simplify() + sdfg.compile() + + size_x = 5 + size_y = 4 + arg1 = np.full([size_x, size_y], 42, order="F", dtype=np.float64) + res1 = np.full([size_y, size_x], 42, order="F", dtype=np.float64) + + for i in range(size_x): + for j in range(size_y): + arg1[i, j] = i + 1 + + sdfg(arg1=arg1, res1=res1) + + assert np.all(np.transpose(res1) == arg1) + +def test_fortran_frontend_transpose_struct(): + sources, main = SourceCodeBuilder().add_file(""" + +MODULE test_types + IMPLICIT NONE + TYPE array_container + double precision, dimension(5,4) :: arg1 + END TYPE array_container +END MODULE + +MODULE test_transpose + + contains + + subroutine test_function(arg1, res1) + USE test_types + IMPLICIT NONE + TYPE(array_container) :: container + double precision, dimension(5,4) :: arg1 + double precision, dimension(4,5) :: res1 + + container%arg1 = arg1 + + res1 = transpose(container%arg1) + end subroutine test_function + +END MODULE +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'test_transpose.test_function', normalize_offsets=True) + # TODO: We should re-enable `simplify()` once we merge it. 
+ sdfg.simplify() + sdfg.compile() + sdfg.save('test.sdfg') + + size_x = 5 + size_y = 4 + arg1 = np.full([size_x, size_y], 42, order="F", dtype=np.float64) + res1 = np.full([size_y, size_x], 42, order="F", dtype=np.float64) + + for i in range(size_x): + for j in range(size_y): + arg1[i, j] = i + 1 + + sdfg(arg1=arg1, res1=res1) + print(arg1) + print(res1) + + assert np.all(np.transpose(res1) == arg1) + +def test_fortran_frontend_matmul(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(arg1, arg2, res1) + double precision, dimension(5,3) :: arg1 + double precision, dimension(3,7) :: arg2 + double precision, dimension(5,7) :: res1 + res1 = matmul(arg1, arg2) +end subroutine main +""").check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'main', normalize_offsets=False) + # TODO: We should re-enable `simplify()` once we merge it. + # sdfg.simplify() + sdfg.compile() + + size_x = 5 + size_y = 3 + size_z = 7 + arg1 = np.full([size_x, size_y], 42, order="F", dtype=np.float64) + arg2 = np.full([size_y, size_z], 42, order="F", dtype=np.float64) + res1 = np.full([size_x, size_z], 42, order="F", dtype=np.float64) + + for i in range(size_x): + for j in range(size_y): + arg1[i, j] = i + j + 1 + for i in range(size_y): + for j in range(size_z): + arg2[i, j] = i + j + 7 + + sdfg(arg1=arg1, arg2=arg2, res1=res1) + + assert np.all(np.matmul(arg1, arg2) == res1) + +if __name__ == "__main__": + #test_fortran_frontend_dot() + #test_fortran_frontend_dot_range() + #test_fortran_frontend_transpose() + test_fortran_frontend_transpose_struct() + #test_fortran_frontend_matmul() diff --git a/tests/fortran/intrinsic_bound_test.py b/tests/fortran/intrinsic_bound_test.py new file mode 100644 index 0000000000..af77aba186 --- /dev/null +++ b/tests/fortran/intrinsic_bound_test.py @@ -0,0 +1,429 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser +from tests.fortran.fortran_test_helper import SourceCodeBuilder, create_singular_sdfg_from_string + +""" + Test the implementation of LBOUND/UBOUND functions. + * Standard-sized arrays. + * Standard-sized arrays with offsets. + * Arrays with assumed shape. + * Arrays with assumed shape - passed externally. + * Arrays with assumed shape with offsets. + * Arrays inside structures. + * Arrays inside structures with local override. + * Arrays inside structures with multiple layers of indirection. + * Arrays inside structures with multiple layers of indirection + assumed size. +""" + +def test_fortran_frontend_bound(): + test_string = """ + PROGRAM intrinsic_bound_test + implicit none + integer, dimension(4,7) :: input + integer, dimension(4) :: res + CALL intrinsic_bound_test_function(res) + end + + SUBROUTINE intrinsic_bound_test_function(res) + integer, dimension(4,7) :: input + integer, dimension(4) :: res + + res(1) = LBOUND(input, 1) + res(2) = LBOUND(input, 2) + res(3) = UBOUND(input, 1) + res(4) = UBOUND(input, 2) + + END SUBROUTINE intrinsic_bound_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_bound_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res) + + assert np.allclose(res, [1, 1, 4, 7]) + +def test_fortran_frontend_bound_offsets(): + test_string = """ + PROGRAM intrinsic_bound_test + implicit none + integer, dimension(3:8, 9:12) :: input + integer, dimension(4) :: res + CALL intrinsic_bound_test_function(res) + end + + SUBROUTINE intrinsic_bound_test_function(res) + integer, dimension(3:8, 9:12) :: input + integer, dimension(4) :: res + + res(1) = LBOUND(input, 1) + res(2) = LBOUND(input, 2) + res(3) = UBOUND(input, 1) + res(4) = UBOUND(input, 2) + + END SUBROUTINE intrinsic_bound_test_function + """ + + sdfg = 
fortran_parser.create_sdfg_from_string(test_string, "intrinsic_bound_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res) + + assert np.allclose(res, [3, 9, 8, 12]) + +def test_fortran_frontend_bound_assumed(): + sources, main = SourceCodeBuilder().add_file(""" +MODULE intrinsic_bound_interfaces + INTERFACE + SUBROUTINE intrinsic_bound_test_function2(input, res) + integer, dimension(:,:) :: input + integer, dimension(4) :: res + END SUBROUTINE intrinsic_bound_test_function2 + END INTERFACE +END MODULE + +SUBROUTINE intrinsic_bound_test_function(res) +USE intrinsic_bound_interfaces +implicit none +integer, dimension(4,7) :: input +integer, dimension(4) :: res + +CALL intrinsic_bound_test_function2(input, res) + +END SUBROUTINE intrinsic_bound_test_function + +SUBROUTINE intrinsic_bound_test_function2(input, res) +integer, dimension(:,:) :: input +integer, dimension(4) :: res + +res(1) = LBOUND(input, 1) +res(2) = LBOUND(input, 2) +res(3) = UBOUND(input, 1) +res(4) = UBOUND(input, 2) + +END SUBROUTINE intrinsic_bound_test_function2 +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'intrinsic_bound_test_function', normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res) + + assert np.allclose(res, [1, 1, 4, 7]) + +def test_fortran_frontend_bound_assumed_offsets(): + sources, main = SourceCodeBuilder().add_file(""" +MODULE intrinsic_bound_interfaces + INTERFACE + SUBROUTINE intrinsic_bound_test_function2(input, res) + integer, dimension(:,:) :: input + integer, dimension(4) :: res + END SUBROUTINE intrinsic_bound_test_function2 + END INTERFACE +END MODULE + +SUBROUTINE intrinsic_bound_test_function(res) +USE intrinsic_bound_interfaces +implicit none +integer, dimension(42:45,13:19) :: input +integer, dimension(4) :: res + +CALL 
intrinsic_bound_test_function2(input, res) + +END SUBROUTINE intrinsic_bound_test_function + +SUBROUTINE intrinsic_bound_test_function2(input, res) +integer, dimension(:,:) :: input +integer, dimension(4) :: res + +res(1) = LBOUND(input, 1) +res(2) = LBOUND(input, 2) +res(3) = UBOUND(input, 1) +res(4) = UBOUND(input, 2) + +END SUBROUTINE intrinsic_bound_test_function2 +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'intrinsic_bound_test_function', normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res) + + assert np.allclose(res, [1, 1, 4, 7]) + +def test_fortran_frontend_bound_allocatable_offsets(): + sources, main = SourceCodeBuilder().add_file(""" +MODULE intrinsic_bound_interfaces + INTERFACE + SUBROUTINE intrinsic_bound_test_function3(input, res) + integer, allocatable, dimension(:,:) :: input + integer, dimension(4) :: res + END SUBROUTINE intrinsic_bound_test_function3 + END INTERFACE +END MODULE + +SUBROUTINE intrinsic_bound_test_function(res) +USE intrinsic_bound_interfaces +implicit none +integer, allocatable, dimension(:,:) :: input +integer, dimension(4) :: res + +allocate(input(42:45, 13:19)) +CALL intrinsic_bound_test_function3(input, res) +deallocate(input) + +END SUBROUTINE intrinsic_bound_test_function + +SUBROUTINE intrinsic_bound_test_function3(input, res) +integer, allocatable, dimension(:,:) :: input +integer, dimension(4) :: res + +res(1) = LBOUND(input, 1) +res(2) = LBOUND(input, 2) +res(3) = UBOUND(input, 1) +res(4) = UBOUND(input, 2) + +END SUBROUTINE intrinsic_bound_test_function3 +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'intrinsic_bound_test_function', normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg( + res=res, + __f2dace_A_input_d_0_s_0=4, + 
__f2dace_A_input_d_1_s_1=7, + __f2dace_OA_input_d_0_s_0=42, + __f2dace_OA_input_d_1_s_1=13 + ) + + assert np.allclose(res, [42, 13, 45, 19]) + +def test_fortran_frontend_bound_structure(): + sources, main = SourceCodeBuilder().add_file(""" +MODULE test_types + IMPLICIT NONE + TYPE array_container + INTEGER, DIMENSION(2:5, 3:9) :: data + END TYPE array_container +END MODULE + +MODULE test_bounds + USE test_types + IMPLICIT NONE + + CONTAINS + + SUBROUTINE intrinsic_bound_test_function( res) + TYPE(array_container) :: container + INTEGER, DIMENSION(4) :: res + + CALL intrinsic_bound_test_function_impl(container, res) + END SUBROUTINE + + SUBROUTINE intrinsic_bound_test_function_impl(container, res) + TYPE(array_container), INTENT(IN) :: container + INTEGER, DIMENSION(4) :: res + + res(1) = LBOUND(container%data, 1) ! Should return 2 + res(2) = LBOUND(container%data, 2) ! Should return 3 + res(3) = UBOUND(container%data, 1) ! Should return 5 + res(4) = UBOUND(container%data, 2) ! Should return 9 + END SUBROUTINE +END MODULE +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'test_bounds.intrinsic_bound_test_function', normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res) + + assert np.allclose(res, [2, 3, 5, 9]) + +def test_fortran_frontend_bound_structure_override(): + sources, main = SourceCodeBuilder().add_file(""" +MODULE test_types + IMPLICIT NONE + TYPE array_container + INTEGER, DIMENSION(2:5, 3:9) :: data + END TYPE array_container +END MODULE + +MODULE test_bounds + USE test_types + IMPLICIT NONE + + CONTAINS + + SUBROUTINE intrinsic_bound_test_function( res) + TYPE(array_container) :: container + INTEGER, DIMENSION(4) :: res + + CALL intrinsic_bound_test_function_impl(container, res) + END SUBROUTINE + + SUBROUTINE intrinsic_bound_test_function_impl(container, res) + TYPE(array_container), INTENT(IN) :: container + 
INTEGER, DIMENSION(4) :: res + ! if we handle the refs correctly, this override won't fool us + integer, dimension(3, 10) :: data + + res(1) = LBOUND(container%data, 1) ! Should return 2 + res(2) = LBOUND(container%data, 2) ! Should return 3 + res(3) = UBOUND(container%data, 1) ! Should return 5 + res(4) = UBOUND(container%data, 2) ! Should return 9 + END SUBROUTINE +END MODULE +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'test_bounds.intrinsic_bound_test_function', normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res) + + assert np.allclose(res, [2, 3, 5, 9]) + +def test_fortran_frontend_bound_structure_recursive(): + sources, main = SourceCodeBuilder().add_file(""" +MODULE test_types + IMPLICIT NONE + + TYPE inner_container + INTEGER, DIMENSION(-1:2, 0:3) :: inner_data + END TYPE + + TYPE middle_container + INTEGER, DIMENSION(2:5, 3:9) :: middle_data + TYPE(inner_container) :: inner + END TYPE + + TYPE array_container + INTEGER, DIMENSION(0:3, -2:4) :: outer_data + TYPE(middle_container) :: middle + END TYPE +END MODULE + +MODULE test_bounds + USE test_types + IMPLICIT NONE + + CONTAINS + + SUBROUTINE intrinsic_bound_test_function( res) + TYPE(array_container) :: container + INTEGER, DIMENSION(4) :: res + + CALL intrinsic_bound_test_function_impl(container, res) + END SUBROUTINE + + SUBROUTINE intrinsic_bound_test_function_impl(container, res) + TYPE(array_container), INTENT(IN) :: container + INTEGER, DIMENSION(4) :: res + + res(1) = LBOUND(container%middle%inner%inner_data, 1) ! Should return -1 + res(2) = LBOUND(container%middle%inner%inner_data, 2) ! Should return 0 + res(3) = UBOUND(container%middle%inner%inner_data, 1) ! Should return 2 + res(4) = UBOUND(container%middle%inner%inner_data, 2) ! 
Should return 3 + END SUBROUTINE +END MODULE +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'test_bounds.intrinsic_bound_test_function', normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res) + + assert np.allclose(res, [-1, 0, 2, 3]) + +def test_fortran_frontend_bound_structure_recursive_allocatable(): + sources, main = SourceCodeBuilder().add_file(""" +MODULE test_types + IMPLICIT NONE + + TYPE inner_container + INTEGER, ALLOCATABLE, DIMENSION(:, :) :: inner_data + END TYPE + + TYPE middle_container + INTEGER, ALLOCATABLE, DIMENSION(:, :) :: middle_data + TYPE(inner_container) :: inner + END TYPE + + TYPE array_container + INTEGER, ALLOCATABLE, DIMENSION(:, :) :: outer_data + TYPE(middle_container) :: middle + END TYPE +END MODULE + +MODULE test_bounds + USE test_types + IMPLICIT NONE + + CONTAINS + + SUBROUTINE intrinsic_bound_test_function(res) + IMPLICIT NONE + TYPE(array_container) :: container + INTEGER, DIMENSION(4) :: res + + ALLOCATE(container%middle%inner%inner_data(-1:2, 0:3)) + CALL intrinsic_bound_test_function_impl(container, res) + DEALLOCATE(container%middle%inner%inner_data) + + END SUBROUTINE + + SUBROUTINE intrinsic_bound_test_function_impl(container, res) + IMPLICIT NONE + TYPE(array_container), INTENT(IN) :: container + INTEGER, DIMENSION(4) :: res + + res(1) = LBOUND(container%middle%inner%inner_data, 1) ! Should return -1 + res(2) = LBOUND(container%middle%inner%inner_data, 2) ! Should return 0 + res(3) = UBOUND(container%middle%inner%inner_data, 1) ! Should return 2 + res(4) = UBOUND(container%middle%inner%inner_data, 2) ! 
Should return 3 + END SUBROUTINE +END MODULE +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'test_bounds.intrinsic_bound_test_function', normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res) + + assert np.allclose(res, [-1, 0, 2, 3]) + +if __name__ == "__main__": + + test_fortran_frontend_bound() + test_fortran_frontend_bound_offsets() + test_fortran_frontend_bound_assumed() + test_fortran_frontend_bound_assumed_offsets() + test_fortran_frontend_bound_allocatable_offsets() + test_fortran_frontend_bound_structure() + test_fortran_frontend_bound_structure_override() + test_fortran_frontend_bound_structure_recursive() + #test_fortran_frontend_bound_structure_recursive_allocatable() diff --git a/tests/fortran/intrinsic_count_test.py b/tests/fortran/intrinsic_count_test.py index ef55f9dd55..ced135d1a6 100644 --- a/tests/fortran/intrinsic_count_test.py +++ b/tests/fortran/intrinsic_count_test.py @@ -24,7 +24,7 @@ def test_fortran_frontend_count_array(): END SUBROUTINE intrinsic_count_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -60,7 +60,7 @@ def test_fortran_frontend_count_array_dim(): """ with pytest.raises(NotImplementedError): - fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test") def test_fortran_frontend_count_array_comparison(): @@ -91,7 +91,7 @@ def test_fortran_frontend_count_array_comparison(): END SUBROUTINE intrinsic_count_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test") 
sdfg.simplify(verbose=True) sdfg.compile() @@ -141,7 +141,7 @@ def test_fortran_frontend_count_array_scalar_comparison(): END SUBROUTINE intrinsic_count_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -187,7 +187,7 @@ def test_fortran_frontend_count_array_comparison_wrong_subset(): """ with pytest.raises(TypeError): - fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test") def test_fortran_frontend_count_array_2d(): test_string = """ @@ -207,7 +207,7 @@ def test_fortran_frontend_count_array_2d(): END SUBROUTINE intrinsic_count_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -256,7 +256,7 @@ def test_fortran_frontend_count_array_comparison_2d(): END SUBROUTINE intrinsic_count_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test") sdfg.simplify(verbose=True) sdfg.compile() @@ -297,7 +297,7 @@ def test_fortran_frontend_count_array_comparison_2d_subset(): END SUBROUTINE intrinsic_count_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test") sdfg.simplify(verbose=True) sdfg.compile() diff --git a/tests/fortran/intrinsic_math_test.py b/tests/fortran/intrinsic_math_test.py index e1fc469beb..d407ba3bac 100644 --- a/tests/fortran/intrinsic_math_test.py +++ b/tests/fortran/intrinsic_math_test.py @@ -31,7 +31,7 @@ def 
test_fortran_frontend_min_max(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_min_max", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_min_max", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -73,7 +73,7 @@ def test_fortran_frontend_sqrt(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_sqrt", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_sqrt", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -88,42 +88,74 @@ def test_fortran_frontend_sqrt(): for f_res, p_res in zip(res, py_res): assert abs(f_res - p_res) < 10**-9 -def test_fortran_frontend_abs(): +def test_fortran_frontend_sqrt_structure(): test_string = """ - PROGRAM intrinsic_math_test_abs + module lib + implicit none + type test_type + double precision, dimension(2) :: input_data + end type + + type test_type2 + type(test_type) :: var + integer :: test_variable + end type + end module lib + + PROGRAM intrinsic_math_test_sqrt + use lib, only: test_type2 implicit none + double precision, dimension(2) :: d double precision, dimension(2) :: res CALL intrinsic_math_test_function(d, res) end SUBROUTINE intrinsic_math_test_function(d, res) + use lib, only: test_type2 + implicit none + double precision, dimension(2) :: d double precision, dimension(2) :: res + type(test_type2) :: data - res(1) = ABS(d(1)) - res(2) = ABS(d(2)) + data%var%input_data = d + + CALL intrinsic_math_test_function2(res, data) END SUBROUTINE intrinsic_math_test_function + + SUBROUTINE intrinsic_math_test_function2(res, data) + use lib, only: test_type2 + implicit none + double precision, dimension(2) :: res + type(test_type2) :: data + + res(1) = MOD(data%var%input_data(1), 5.0D0) + res(2) = MOD(data%var%input_data(2), 5.0D0) + + END SUBROUTINE intrinsic_math_test_function2 """ - sdfg = 
fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_abs", False) - sdfg.simplify(verbose=True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_sqrt", True) + sdfg.validate() + #sdfg.simplify(verbose=True) sdfg.compile() size = 2 d = np.full([size], 42, order="F", dtype=np.float64) - d[0] = -30 - d[1] = 40 + d[0] = 2 + d[1] = 5 res = np.full([2], 42, order="F", dtype=np.float64) sdfg(d=d, res=res) + py_res = np.sqrt(d) - assert res[0] == 30 - assert res[1] == 40 + for f_res, p_res in zip(res, py_res): + assert abs(f_res - p_res) < 10**-9 -def test_fortran_frontend_exp(): +def test_fortran_frontend_abs(): test_string = """ - PROGRAM intrinsic_math_test_exp + PROGRAM intrinsic_math_test_abs implicit none double precision, dimension(2) :: d double precision, dimension(2) :: res @@ -134,30 +166,29 @@ def test_fortran_frontend_exp(): double precision, dimension(2) :: d double precision, dimension(2) :: res - res(1) = EXP(d(1)) - res(2) = EXP(d(2)) + res(1) = ABS(d(1)) + res(2) = ABS(d(2)) END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_exp", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_abs", True) sdfg.simplify(verbose=True) sdfg.compile() size = 2 d = np.full([size], 42, order="F", dtype=np.float64) - d[0] = 2 - d[1] = 4.5 + d[0] = -30 + d[1] = 40 res = np.full([2], 42, order="F", dtype=np.float64) sdfg(d=d, res=res) - py_res = np.exp(d) - for f_res, p_res in zip(res, py_res): - assert abs(f_res - p_res) < 10**-9 + assert res[0] == 30 + assert res[1] == 40 -def test_fortran_frontend_log(): +def test_fortran_frontend_exp(): test_string = """ - PROGRAM intrinsic_math_test_log + PROGRAM intrinsic_math_test_exp implicit none double precision, dimension(2) :: d double precision, dimension(2) :: res @@ -168,23 +199,23 @@ def test_fortran_frontend_log(): double precision, dimension(2) :: d double 
precision, dimension(2) :: res - res(1) = LOG(d(1)) - res(2) = LOG(d(2)) + res(1) = EXP(d(1)) + res(2) = EXP(d(2)) END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_exp", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_exp", True) sdfg.simplify(verbose=True) sdfg.compile() size = 2 d = np.full([size], 42, order="F", dtype=np.float64) - d[0] = 2.71 + d[0] = 2 d[1] = 4.5 res = np.full([2], 42, order="F", dtype=np.float64) sdfg(d=d, res=res) - py_res = np.log(d) + py_res = np.exp(d) for f_res, p_res in zip(res, py_res): assert abs(f_res - p_res) < 10**-9 @@ -208,7 +239,7 @@ def test_fortran_frontend_log(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_exp", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_exp", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -246,7 +277,7 @@ def test_fortran_frontend_mod_float(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_mod", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_mod", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -295,7 +326,7 @@ def test_fortran_frontend_mod_integer(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -339,7 +370,7 @@ def test_fortran_frontend_modulo_float(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", 
True) sdfg.simplify(verbose=True) sdfg.compile() @@ -388,7 +419,7 @@ def test_fortran_frontend_modulo_integer(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -431,7 +462,7 @@ def test_fortran_frontend_floor(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -474,7 +505,7 @@ def test_fortran_frontend_scale(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -519,7 +550,7 @@ def test_fortran_frontend_exponent(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -583,7 +614,7 @@ def test_fortran_frontend_int(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -653,7 +684,7 @@ def test_fortran_frontend_real(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = 
fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -700,7 +731,7 @@ def test_fortran_frontend_trig(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -744,7 +775,7 @@ def test_fortran_frontend_hyperbolic(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -796,7 +827,7 @@ def test_fortran_frontend_trig_inverse(): END SUBROUTINE intrinsic_math_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", True) sdfg.simplify(verbose=True) sdfg.compile() @@ -828,6 +859,7 @@ def test_fortran_frontend_trig_inverse(): test_fortran_frontend_min_max() test_fortran_frontend_sqrt() + #test_fortran_frontend_sqrt_structure() test_fortran_frontend_abs() test_fortran_frontend_exp() test_fortran_frontend_log() diff --git a/tests/fortran/intrinsic_merge_test.py b/tests/fortran/intrinsic_merge_test.py index 1778b9c2fb..95421d843e 100644 --- a/tests/fortran/intrinsic_merge_test.py +++ b/tests/fortran/intrinsic_merge_test.py @@ -273,6 +273,240 @@ def test_fortran_frontend_merge_array_shift(): for val in res: assert val == 100 +def test_fortran_frontend_merge_nonarray(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM merge_test + implicit none + logical :: val(2) + double precision :: res(2) + CALL merge_test_function(val, res) + end + + SUBROUTINE merge_test_function(val, res) + logical :: val(2) + double precision :: res(2) + double precision :: input1 + double precision :: input2 + + input1 = 1 + input2 = 5 + + res(1) = MERGE(input1, input2, val(1)) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + val = np.full([1], 1, order="F", dtype=np.int32) + res = np.full([1], 40, order="F", dtype=np.float64) + + sdfg(val=val, res=res) + assert res[0] == 1 + + val[0] = 0 + sdfg(val=val, res=res) + assert res[0] == 5 + +def test_fortran_frontend_merge_recursive(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM merge_test + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7) :: input3 + integer, dimension(7) :: mask1 + integer, dimension(7) :: mask2 + double precision, dimension(7) :: res + CALL merge_test_function(input1, input2, input3, mask1, mask2, res) + end + + SUBROUTINE merge_test_function(input1, input2, input3, mask1, mask2, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7) :: input3 + integer, dimension(7) :: mask1 + integer, dimension(7) :: mask2 + double precision, dimension(7) :: res + + res = MERGE(MERGE(input1, input2, mask1), input3, mask2) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the 
beginning + first = np.full([size], 13, order="F", dtype=np.float64) + second = np.full([size], 42, order="F", dtype=np.float64) + third = np.full([size], 43, order="F", dtype=np.float64) + mask1 = np.full([size], 0, order="F", dtype=np.int32) + mask2 = np.full([size], 1, order="F", dtype=np.int32) + res = np.full([size], 40, order="F", dtype=np.float64) + + for i in range(int(size/2)): + mask1[i] = 1 + + mask2[-1] = 0 + + sdfg(input1=first, input2=second, input3=third, mask1=mask1, mask2=mask2, res=res) + + assert np.allclose(res, [13, 13, 13, 42, 42, 42, 43]) + +def test_fortran_frontend_merge_scalar(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM merge_test + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + integer, dimension(7) :: mask + double precision, dimension(7) :: res + CALL merge_test_function(input1, input2, mask, res) + end + + SUBROUTINE merge_test_function(input1, input2, mask, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + integer, dimension(7) :: mask + double precision, dimension(7) :: res + + res(1) = MERGE(input1(1), input2(1), mask(1)) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + sdfg.save('test.sdfg') + #sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + first = np.full([size], 13, order="F", dtype=np.float64) + second = np.full([size], 42, order="F", dtype=np.float64) + mask = np.full([size], 0, order="F", dtype=np.int32) + res = np.full([size], 40, order="F", dtype=np.float64) + + sdfg(input1=first, input2=second, mask=mask, res=res) + + assert res[0] == 42 + for val in res[1:]: + assert val == 40 + + mask[0] = 1 + sdfg(input1=first, input2=second, mask=mask, res=res) + assert res[0] == 
13 + for val in res[1:]: + assert val == 40 + + +def test_fortran_frontend_merge_scalar2(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM merge_test + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + integer, dimension(7) :: mask + double precision, dimension(7) :: res + CALL merge_test_function(input1, input2, mask, res) + end + + SUBROUTINE merge_test_function(input1, input2, mask, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + integer, dimension(7) :: mask + double precision, dimension(7) :: res + + res(1) = MERGE(input1(1), 0.0, mask(1)) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + first = np.full([size], 13, order="F", dtype=np.float64) + second = np.full([size], 42, order="F", dtype=np.float64) + mask = np.full([size], 0, order="F", dtype=np.int32) + res = np.full([size], 40, order="F", dtype=np.float64) + + sdfg(input1=first, input2=second, mask=mask, res=res) + assert res[0] == 0 + + mask[:] = 1 + sdfg(input1=first, input2=second, mask=mask, res=res) + assert res[0] == 13 + +def test_fortran_frontend_merge_scalar3(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM merge_test + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + integer, dimension(7) :: mask + integer, dimension(7) :: mask2 + double precision, dimension(7) :: res + CALL merge_test_function(input1, input2, mask, mask2, res) + end + + SUBROUTINE merge_test_function(input1, input2, mask, mask2, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + integer, dimension(7) :: mask + integer, dimension(7) :: mask2 + double precision, dimension(7) :: res + + res(1) = MERGE(input1(1), 0.0, mask(1) > mask2(1) .AND. mask2(2) == 0) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + first = np.full([size], 13, order="F", dtype=np.float64) + second = np.full([size], 42, order="F", dtype=np.float64) + mask = np.full([size], 0, order="F", dtype=np.int32) + mask2 = np.full([size], 0, order="F", dtype=np.int32) + res = np.full([size], 40, order="F", dtype=np.float64) + + sdfg(input1=first, input2=second, mask=mask, mask2=mask2, res=res) + assert res[0] == 0 + + mask[:] = 1 + sdfg(input1=first, input2=second, mask=mask, mask2=mask2, res=res) + assert res[0] == 13 if __name__ == "__main__": @@ -281,3 +515,9 @@ def test_fortran_frontend_merge_array_shift(): test_fortran_frontend_merge_comparison_arrays() test_fortran_frontend_merge_comparison_arrays_offset() test_fortran_frontend_merge_array_shift() + test_fortran_frontend_merge_nonarray() + test_fortran_frontend_merge_recursive() + test_fortran_frontend_merge_recursive() + test_fortran_frontend_merge_scalar() + test_fortran_frontend_merge_scalar2() + test_fortran_frontend_merge_scalar3() diff --git a/tests/fortran/intrinsic_minmaxval_test.py 
b/tests/fortran/intrinsic_minmaxval_test.py
index 6a32237d37..5c0cb2cca6 100644
--- a/tests/fortran/intrinsic_minmaxval_test.py
+++ b/tests/fortran/intrinsic_minmaxval_test.py
@@ -3,6 +3,7 @@
 import numpy as np
 
 from dace.frontend.fortran import ast_transforms, fortran_parser
+from tests.fortran.fortran_test_helper import SourceCodeBuilder, create_singular_sdfg_from_string
 
 def test_fortran_frontend_minval_double():
     """
@@ -244,9 +245,59 @@ def test_fortran_frontend_maxval_int():
     # It should be the dace max for integer
     assert res[3] == np.iinfo(np.int32).min
 
+def test_fortran_frontend_minval_struct():
+    sources, main = SourceCodeBuilder().add_file("""
+MODULE test_types
+    IMPLICIT NONE
+    TYPE array_container
+        INTEGER, DIMENSION(7) :: data
+    END TYPE array_container
+END MODULE
+
+MODULE test_minval
+    USE test_types
+    IMPLICIT NONE
+
+    CONTAINS
+
+    SUBROUTINE minval_test_func(input, res)
+        TYPE(array_container) :: container
+        INTEGER, DIMENSION(7) :: input
+        INTEGER, DIMENSION(3) :: res
+
+        container%data = input
+
+        CALL minval_test_func_internal(container, res)
+    END SUBROUTINE
+
+    SUBROUTINE minval_test_func_internal(container, res)
+        TYPE(array_container), INTENT(IN) :: container
+        INTEGER, DIMENSION(3) :: res
+
+        res(1) = MAXVAL(container%data)
+        res(2) = MAXVAL(container%data(:))
+        res(3) = MAXVAL(container%data(3:6))
+    END SUBROUTINE
+END MODULE
+""", 'main').check_with_gfortran().get()
+    sdfg = create_singular_sdfg_from_string(sources, 'test_minval.minval_test_func')
+    #sdfg.simplify(verbose=True)
+    sdfg.compile()
+
+    size = 7
+    input = np.full([size], 0, order="F", dtype=np.int32)
+    for i in range(size):
+        input[i] = i + 1
+    res = np.full([3], 42, order="F", dtype=np.int32)
+    # data holds 1..7, so MAXVAL over full array, full slice and the (3:6) subrange
+    sdfg(input=input, res=res)
+    assert np.allclose(res, [7, 7, 6])
+
+
 if __name__ == "__main__":
-    test_fortran_frontend_minval_double()
-    test_fortran_frontend_minval_int()
-    test_fortran_frontend_maxval_double()
-    test_fortran_frontend_maxval_int()
+    #test_fortran_frontend_minval_double()
def test_fortran_frontend_long_tasklet():
    """
    Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct.
    """
    fortran_code = """
    PROGRAM long_tasklet_test
    implicit none


    type test_type
        integer :: indices(5)
        integer :: start
        integer :: end
    end type

    double precision d(5)
    double precision, dimension(5) :: arr
    double precision, dimension(50:54) :: arr3
    CALL long_tasklet_test_function(d)
    end

    SUBROUTINE long_tasklet_test_function(d)
    double precision d(5)
    double precision, dimension(50:54) :: arr4
    double precision, dimension(5) :: arr
    type(test_type) :: ind

    arr(:)=2.0
    ind%indices(:)=1
    d(2)=5.5
    d(1)=arr(1)*arr(ind%indices(1))!+arr(2,2,2)*arr(ind%indices(2,2,2),2,2)!+arr(3,3,3)*arr(ind%indices(3,3,3),3,3)

    END SUBROUTINE long_tasklet_test_function
    """
    sources = {"long_tasklet_test": fortran_code}
    sdfg = fortran_parser.create_sdfg_from_string(fortran_code, "long_tasklet_test", True, sources=sources)
    sdfg.simplify(verbose=True)

    # d(2) is set to 5.5 directly; d(1) = arr(1) * arr(ind%indices(1)) = 2 * 2.
    result = np.full([5], 42, order="F", dtype=np.float64)
    sdfg(d=result)
    assert result[1] == 5.5
    assert result[0] == 4
def test_fortran_frontend_missing_func():
    """
    Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation.
    """
    fortran_code = """
    PROGRAM missing_test
    implicit none


    REAL :: d(5,5)

    CALL missing_test_function(d)
    end


    SUBROUTINE missing_test_function(d)
    REAL d(5,5)
    REAL z(5)

    CALL init_zero_contiguous_dp(z, 5, opt_acc_async=.TRUE.,lacc=.FALSE.)
    d(2,1) = 5.5 + z(1)

    END SUBROUTINE missing_test_function

    SUBROUTINE init_contiguous_dp(var, n, v, opt_acc_async, lacc)
    INTEGER, INTENT(in) :: n
    REAL, INTENT(out) :: var(n)
    REAL, INTENT(in) :: v
    LOGICAL, INTENT(in), OPTIONAL :: opt_acc_async
    LOGICAL, INTENT(in), OPTIONAL :: lacc

    INTEGER :: i
    LOGICAL :: lzacc

    CALL set_acc_host_or_device(lzacc, lacc)

    DO i = 1, n
        var(i) = v
    END DO

    CALL acc_wait_if_requested(1, opt_acc_async)
    END SUBROUTINE init_contiguous_dp

    SUBROUTINE init_zero_contiguous_dp(var, n, opt_acc_async, lacc)
    INTEGER, INTENT(in) :: n
    REAL, INTENT(out) :: var(n)
    LOGICAL, INTENT(IN), OPTIONAL :: opt_acc_async
    LOGICAL, INTENT(IN), OPTIONAL :: lacc


    CALL init_contiguous_dp(var, n, 0.0, opt_acc_async, lacc)
    var(1)=var(1)+1.0

    END SUBROUTINE init_zero_contiguous_dp


    SUBROUTINE set_acc_host_or_device(lzacc, lacc)
    LOGICAL, INTENT(out) :: lzacc
    LOGICAL, INTENT(in), OPTIONAL :: lacc

    lzacc = .FALSE.

    END SUBROUTINE set_acc_host_or_device

    SUBROUTINE acc_wait_if_requested(acc_async_queue, opt_acc_async)
    INTEGER, INTENT(IN) :: acc_async_queue
    LOGICAL, INTENT(IN), OPTIONAL :: opt_acc_async


    END SUBROUTINE acc_wait_if_requested
    """
    sources = {"missing_test": fortran_code}
    sdfg = fortran_parser.create_sdfg_from_string(fortran_code, "missing_test", True, sources=sources)
    sdfg.simplify(verbose=True)

    # z is zeroed and then z(1) incremented by 1.0, so d(2,1) = 5.5 + 1.0.
    res = np.full([5, 5], 42, order="F", dtype=np.float32)
    sdfg(d=res)
    assert res[0, 0] == 42
    assert res[1, 0] == 6.5
    assert res[2, 0] == 42


def test_fortran_frontend_missing_extraction():
    """
    Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation.
    """
    fortran_code = """
    PROGRAM missing_extraction_test
    implicit none


    REAL :: d(5,5)

    CALL missing_extraction_test_function(d)
    end


    SUBROUTINE missing_extraction_test_function(d)
    REAL d(5,5)
    REAL z(5)
    integer :: jk = 5
    integer :: nrdmax_jg = 3
    DO jk = MAX(0,nrdmax_jg-2), 2
    d(jk,jk) = 17
    ENDDO
    d(2,1) = 5.5

    END SUBROUTINE missing_extraction_test_function

    """
    sources = {"missing_extraction_test": fortran_code}
    sdfg = fortran_parser.create_sdfg_from_string(fortran_code, "missing_extraction_test", sources=sources)
    sdfg.simplify(verbose=True)

    # The loop runs jk = MAX(0, 1)..2, so d(1,1) and d(2,2) become 17,
    # then d(2,1) is overwritten with 5.5.
    res = np.full([5, 5], 42, order="F", dtype=np.float32)
    sdfg(d=res)
    assert res[0, 0] == 17
    assert res[1, 0] == 5.5
    assert res[2, 0] == 42


if __name__ == "__main__":
    test_fortran_frontend_missing_func()
    test_fortran_frontend_missing_extraction()
def construct_internal_ast(sources: Dict[str, str], entry_points: List[str]):
    """Parse `sources` (which must include 'main.f90') into the internal Fortran AST.

    Entry points are dotted paths such as 'module.subroutine'; they are split
    into tuples as expected by `ParseConfig`.
    """
    assert 'main.f90' in sources
    dotted = [tuple(ep.split('.')) for ep in entry_points]
    cfg = ParseConfig(sources['main.f90'], sources, [], entry_points=dotted)
    # `create_internal_ast` returns the (InternalFortranAst, program) pair.
    return create_internal_ast(cfg)


def construct_sdfg(iast: InternalFortranAst, prog: FNode, entry_points: List[str]):
    """Build one SDFG per entry point, keyed by the last dotted-path component."""
    ep_map = {}
    for ep in entry_points:
        parts = ep.split('.')
        ep_map[parts[-1]] = parts
    cfg = SDFGConfig(ep_map)
    return create_sdfg_from_internal_ast(iast, prog, cfg)


def test_minimal():
    """
    A simple program to just verify that we can produce compilable SDFGs.
    """
    sources, main = SourceCodeBuilder().add_file("""
program main
  implicit none
  double precision d(4)
  d(2) = 5.5
end program main
subroutine fun(d)
  implicit none
  double precision, intent(inout) :: d(4)
  d(2) = 4.2
end subroutine fun
""").check_with_gfortran().get()
    # Construct SDFGs for both the program and the free subroutine.
    eps = ['main', 'fun']
    iast, prog = construct_internal_ast(sources, eps)
    gmap = construct_sdfg(iast, prog, eps)

    # Verify: both entry points produced an SDFG and 'main' compiles.
    assert gmap.keys() == {'main', 'fun'}
    gmap['main'].compile()
    # We will do nothing else here, since it's just a sanity check test.
+ """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine fun(d) + implicit none + double precision, intent(inout) :: d(4) + d(2) = 4.2 +end subroutine fun + +subroutine not_fun(d, val) + implicit none + double precision, intent(in) :: val + double precision, intent(inout) :: d(4) + d(4) = val +end subroutine not_fun +""").check_with_gfortran().get() + # Construct + entry_points = ['fun', 'not_fun'] + iast, prog = construct_internal_ast(sources, entry_points) + gmap = construct_sdfg(iast, prog, entry_points) + + # Verify + assert gmap.keys() == {'fun', 'not_fun'} + d = np.full([4], 0, dtype=np.float64) + + fun = gmap['fun'].compile() + fun(d=d) + assert np.allclose(d, [0, 4.2, 0, 0]) + not_fun = gmap['not_fun'].compile() + not_fun(d=d, val=5.5) + assert np.allclose(d, [0, 4.2, 0, 5.5]) + + +def test_subroutines_from_module(): + """ + A standalone subroutine, with no program or module in sight. + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none +contains + subroutine fun(d) + implicit none + double precision, intent(inout) :: d(4) + d(2) = 4.2 + end subroutine fun + + subroutine not_fun(d, val) + implicit none + double precision, intent(in) :: val + double precision, intent(inout) :: d(4) + d(4) = val + end subroutine not_fun +end module lib +""").add_file(""" +program main + use lib + implicit none +end program main +""").check_with_gfortran().get() + # Construct + entry_points = ['lib.fun', 'lib.not_fun'] + iast, prog = construct_internal_ast(sources, entry_points) + gmap = construct_sdfg(iast, prog, entry_points) + + # Verify + assert gmap.keys() == {'fun', 'not_fun'} + d = np.full([4], 0, dtype=np.float64) + + fun = gmap['fun'].compile() + fun(d=d) + assert np.allclose(d, [0, 4.2, 0, 0]) + not_fun = gmap['not_fun'].compile() + not_fun(d=d, val=5.5) + assert np.allclose(d, [0, 4.2, 0, 5.5]) + + +def test_subroutine_with_local_variable(): + """ + A standalone subroutine, with no program or module in sight. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine fun(d) + implicit none + double precision, intent(inout) :: d(4) + double precision :: e(4) + e(:) = 1.0 + e(2) = 4.2 + d(:) = e(:) +end subroutine fun +""").check_with_gfortran().get() + # Construct + entry_points = ['fun'] + iast, prog = construct_internal_ast(sources, entry_points) + gmap = construct_sdfg(iast, prog, entry_points) + + # Verify + assert gmap.keys() == {'fun'} + d = np.full([4], 0, dtype=np.float64) + + fun = gmap['fun'].compile() + fun(d=d) + assert np.allclose(d, [1, 4.2, 1, 1]) diff --git a/tests/fortran/nested_array_test.py b/tests/fortran/nested_array_test.py new file mode 100644 index 0000000000..6c4cb14535 --- /dev/null +++ b/tests/fortran/nested_array_test.py @@ -0,0 +1,101 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. + +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil +from dace.sdfg.nodes import AccessNode + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +def test_fortran_frontend_nested_array_access(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM access_test + implicit none + double precision d(4) + CALL nested_array_access_test_function(d) + end + + SUBROUTINE nested_array_access_test_function(d) + double precision d(4) + integer test(3,3,3) + integer indices(3,3,3) + indices(1,1,1)=2 + indices(1,1,2)=3 + indices(1,1,3)=1 + test(indices(1,1,1),indices(1,1,2),indices(1,1,3))=2 + d(test(2,3,1))=5.5 + + END SUBROUTINE nested_array_access_test_function + """ + sources={"nested_array_access_test_function": test_string} + sdfg = fortran_parser.create_sdfg_from_string(test_string, "nested_array_access_test",normalize_offsets=True,multiple_sdfgs=False,sources=sources) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=a) + assert (a[0] == 42) + assert (a[1] == 5.5) + assert (a[2] == 42) + + +def test_fortran_frontend_nested_array_access2(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. + """ + test_string = """ + PROGRAM access2_test + implicit none + double precision d(4) + simple_type + integer test1(3,3,3) + integer indices1(3,3,3) + + + CALL nested_array_access2_test_function(d,test1,indices1) + end + + SUBROUTINE nested_array_access2_test_function(d,test1,indices1) + + integer,pointer test1(:,:,:) + integer,pointer indices1(:,:,:) + double precision d(4) + integer,pointer test(:,:,:) + integer,pointer indices(:,:,:) + + test1=>test + indices1=>indices + indices(1,1,1)=2 + indices(1,1,2)=3 + indices(1,1,3)=1 + test(indices(1,1,1),indices(1,1,2),indices(1,1,3))=2 + d(test(2,3,1))=5.5 + + END SUBROUTINE nested_array_access2_test_function + """ + sources={"nested_array_access2_test_function": test_string} + sdfg = fortran_parser.create_sdfg_from_string(test_string, "nested_array_access2_test",normalize_offsets=True,multiple_sdfgs=False,sources=sources) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=a) + assert (a[0] == 42) + assert 
(a[1] == 5.5) + assert (a[2] == 42) + +if __name__ == "__main__": + + test_fortran_frontend_nested_array_access2() + diff --git a/tests/fortran/non-interactive/fortran_int_init_test.py b/tests/fortran/non-interactive/fortran_int_init_test.py new file mode 100644 index 0000000000..7632db6d19 --- /dev/null +++ b/tests/fortran/non-interactive/fortran_int_init_test.py @@ -0,0 +1,66 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. + +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +@pytest.mark.skip(reason="Interactive test (opens SDFG).") +def test_fortran_frontend_int_init(): + """ + Tests that the power intrinsic is correctly parsed and translated to DaCe. 
(should become a*a) + """ + test_string = """ + PROGRAM int_init_test + implicit none + integer d(2) + CALL int_init_test_function(d) + end + + SUBROUTINE int_init_test_function(d) + integer d(2) + d(1)=INT(z'000000ffffffffff',i8) + END SUBROUTINE int_init_test_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "int_init_test",False,False) + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.NestedSDFG): + if node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break + sdfg.parent = None + sdfg.parent_sdfg = None + sdfg.parent_nsdfg_node = None + sdfg.reset_sdfg_list() + sdfg.simplify(verbose=True) + sdfg.view() + sdfg.compile() + # sdfg = fortran_parser.create_sdfg_from_string(test_string, "int_init_test") + # sdfg.simplify(verbose=True) + # d = np.full([2], 42, order="F", dtype=np.int64) + # sdfg(d=d) + # assert (d[0] == 400) + + + +if __name__ == "__main__": + + + + test_fortran_frontend_int_init() + diff --git a/tests/fortran/non-interactive/function_test.py b/tests/fortran/non-interactive/function_test.py new file mode 100644 index 0000000000..ec95555c8f --- /dev/null +++ b/tests/fortran/non-interactive/function_test.py @@ -0,0 +1,409 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +@pytest.mark.skip(reason="Interactive test (opens SDFG).") +def test_fortran_frontend_function_test(): + """ + Tests to check whether Fortran array slices are correctly translates to DaCe views. + """ + test_name = "function_test" + test_string = """ + PROGRAM """ + test_name + """_program +implicit none +INTEGER a +INTEGER lon(10) +INTEGER lat(10) + +a=function_test_function(1,lon,lat,10) + +end + + + INTEGER FUNCTION function_test_function (lonc, lon, lat, n) + INTEGER, INTENT(in) :: n + REAL, INTENT(in) :: lonc + REAL, INTENT(in) :: lon(n), lat(n) + REAL :: pi=3.14 + REAL :: lonl(n), latl(n) + + REAL :: area + + INTEGER :: i,j + + lonl(:) = lon(:) + latl(:) = lat(:) + + DO i = 1, n + lonl(i) = lonl(i) - lonc + IF (lonl(i) < -pi) THEN + lonl(i) = pi+MOD(lonl(i), pi) + ENDIF + IF (lonl(i) > pi) THEN + lonl(i) = -pi+MOD(lonl(i), pi) + ENDIF + ENDDO + + area = 0.0 + DO i = 1, n + j = MOD(i,n)+1 + area = area+lonl(i)*latl(j) + area = area-latl(i)*lonl(j) + ENDDO + + IF (area >= 0.0) THEN + function_test_function = +1 + ELSE + function_test_function = -1 + END IF + + END FUNCTION function_test_function + + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.NestedSDFG): + if 
node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break + sdfg.parent = None + sdfg.parent_sdfg = None + sdfg.parent_nsdfg_node = None + sdfg.reset_sdfg_list() + sdfg.simplify(verbose=True) + sdfg.view() + + +@pytest.mark.skip(reason="Interactive test (opens SDFG).") +def test_fortran_frontend_function_test2(): + """ + Tests to check whether Fortran array slices are correctly translates to DaCe views. + """ + test_name = "function2_test" + test_string = """ + PROGRAM """ + test_name + """_program +implicit none +REAL x(3) +REAL y(3) +REAL z + +z=function2_test_function(x,y) + +end + + + +PURE FUNCTION function2_test_function (p_x, p_y) result (p_arc) + REAL, INTENT(in) :: p_x(3), p_y(3) ! endpoints + + REAL :: p_arc ! length of geodesic arc + + REAL :: z_lx, z_ly ! length of vector p_x and p_y + REAL :: z_cc ! cos of angle between endpoints + + !----------------------------------------------------------------------- + + !z_lx = SQRT(DOT_PRODUCT(p_x,p_x)) + !z_ly = SQRT(DOT_PRODUCT(p_y,p_y)) + + !z_cc = DOT_PRODUCT(p_x, p_y)/(z_lx*z_ly) + + ! in case we get numerically incorrect solutions + + !IF (z_cc > 1._wp ) z_cc = 1.0 + !IF (z_cc < -1._wp ) z_cc = -1.0 + z_cc= p_x(1)*p_y(1)+p_x(2)*p_y(2)+p_x(3)*p_y(3) + p_arc = ACOS(z_cc) + + END FUNCTION function2_test_function + + + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.NestedSDFG): + if node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break + sdfg.parent = None + sdfg.parent_sdfg = None + sdfg.parent_nsdfg_node = None + sdfg.reset_sdfg_list() + sdfg.simplify(verbose=True) + sdfg.view() + + + +def test_fortran_frontend_function_test3(): + """ + Tests to check whether Fortran array slices are correctly translates to DaCe views. 
+ """ + test_name = "function3_test" + test_string = """ +PROGRAM """ + test_name + """_program + + implicit none + + REAL z + + ! cartesian coordinate class + TYPE t_cartesian_coordinates + REAL :: x(3) + END TYPE t_cartesian_coordinates + + ! geographical coordinate class + TYPE t_geographical_coordinates + REAL :: lon + REAL :: lat + END TYPE t_geographical_coordinates + + ! the two coordinates on the tangent plane + TYPE t_tangent_vectors + REAL :: v1 + REAL :: v2 + END TYPE t_tangent_vectors + + ! line class + TYPE t_line + TYPE(t_geographical_coordinates) :: p1(10) + TYPE(t_geographical_coordinates) :: p2 + END TYPE t_line + + TYPE(t_line) :: v + TYPE(t_geographical_coordinates) :: gp1_1 + TYPE(t_geographical_coordinates) :: gp1_2 + TYPE(t_geographical_coordinates) :: gp1_3 + TYPE(t_geographical_coordinates) :: gp1_4 + TYPE(t_geographical_coordinates) :: gp1_5 + TYPE(t_geographical_coordinates) :: gp1_6 + TYPE(t_geographical_coordinates) :: gp1_7 + TYPE(t_geographical_coordinates) :: gp1_8 + TYPE(t_geographical_coordinates) :: gp1_9 + TYPE(t_geographical_coordinates) :: gp1_10 + + gp1_1%lon = 1.0 + gp1_1%lat = 1.0 + gp1_2%lon = 2.0 + gp1_2%lat = 2.0 + gp1_3%lon = 3.0 + gp1_3%lat = 3.0 + gp1_4%lon = 4.0 + gp1_4%lat = 4.0 + gp1_5%lon = 5.0 + gp1_5%lat = 5.0 + gp1_6%lon = 6.0 + gp1_6%lat = 6.0 + gp1_7%lon = 7.0 + gp1_7%lat = 7.0 + gp1_8%lon = 8.0 + gp1_8%lat = 8.0 + gp1_9%lon = 9.0 + gp1_9%lat = 9.0 + gp1_10%lon = 10.0 + gp1_10%lat = 10.0 + + v%p1(1) = gp1_1 + v%p1(2) = gp1_2 + v%p1(3) = gp1_3 + v%p1(4) = gp1_4 + v%p1(5) = gp1_5 + v%p1(6) = gp1_6 + v%p1(7) = gp1_7 + v%p1(8) = gp1_8 + v%p1(9) = gp1_9 + v%p1(10) = gp1_10 + + z = function3_test_function(v) + +END PROGRAM """ + test_name + """_program + +ELEMENTAL FUNCTION function3_test_function (v) result(length) + TYPE(t_line), INTENT(in) :: v + REAL :: length + REAL :: segment + REAL :: dlon + REAL :: dlat + + length = 0 + DO i = 1, 9 + segment = 0 + dlon = 0 + dlat = 0 + dlon = v%p1(i + 1)%lon - v%p1(i)%lon + 
dlat = v%p1(i + 1)%lat - v%p1(i)%lat + segment = dlon * dlon + dlat * dlat + length = length + SQRT(segment) + ENDDO + +END FUNCTION function3_test_function +""" + + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.NestedSDFG): + if node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break + sdfg.parent = None + sdfg.parent_sdfg = None + sdfg.parent_nsdfg_node = None + sdfg.reset_sdfg_list() + sdfg.simplify(verbose=True) + sdfg.view() + sdfg.compile() + + + +@pytest.mark.skip(reason="Interactive test (opens SDFG).") +def test_fortran_frontend_function_test4(): + """ + Test for elemental functions + """ + test_name = "function4_test" + test_string = """ + PROGRAM """ + test_name + """_program +implicit none + +REAL b +REAL v +REAL z(10) +z(:)=4.0 + +b=function4_test_function(v,z) + +end + + + + FUNCTION function4_test_function (v,z) result(length) + REAL, INTENT(in) :: v + REAL z(10) + REAL :: length + + +REAL a(10) +REAL b + + + +a=norm(z) +length=norm(v)+a + + END FUNCTION function4_test_function + + ELEMENTAL FUNCTION norm (v) result(length) + REAL, INTENT(in) :: v + REAL :: length + + + length = v*v + + END FUNCTION norm + + + + + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.NestedSDFG): + if node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break + sdfg.parent = None + sdfg.parent_sdfg = None + sdfg.parent_nsdfg_node = None + sdfg.reset_sdfg_list() + sdfg.simplify(verbose=True) + sdfg.view() + sdfg.compile() + + + +@pytest.mark.skip(reason="Interactive test (opens SDFG).") +def test_fortran_frontend_function_test5(): + """ + Test for elemental functions + """ + test_name = "function5_test" + test_string = """ + PROGRAM """ + test_name + """_program +implicit none + 
+REAL b +REAL v +REAL z(10) +REAL y(10) +INTEGER proc(10) +INTEGER keyval(10) +z(:)=4.0 + +CALL function5_test_function(z,y,10,1,2,proc,keyval,3,0) + +end + + + + SUBROUTINE function5_test_function(in_field, out_field, n, op, loc_op, & + proc_id, keyval, comm, root) + INTEGER, INTENT(in) :: n, op, loc_op + REAL, INTENT(in) :: in_field(n) + REAL, INTENT(out) :: out_field(n) + + INTEGER, OPTIONAL, INTENT(inout) :: proc_id(n) + INTEGER, OPTIONAL, INTENT(inout) :: keyval(n) + INTEGER, OPTIONAL, INTENT(in) :: root + INTEGER, OPTIONAL, INTENT(in) :: comm + + + out_field = in_field + + END SUBROUTINE function5_test_function + + + + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.NestedSDFG): + if node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break + sdfg.parent = None + sdfg.parent_sdfg = None + sdfg.parent_nsdfg_node = None + sdfg.reset_sdfg_list() + sdfg.simplify(verbose=True) + sdfg.view() + sdfg.compile() + +if __name__ == "__main__": + + #test_fortran_frontend_function_test() + #test_fortran_frontend_function_test2() + #test_fortran_frontend_function_test3() + test_fortran_frontend_function_test4() + #test_fortran_frontend_function_test5() + #test_fortran_frontend_view_test_2() + #test_fortran_frontend_view_test_3() diff --git a/tests/fortran/non-interactive/pointers_test.py b/tests/fortran/non-interactive/pointers_test.py new file mode 100644 index 0000000000..3b98595ab9 --- /dev/null +++ b/tests/fortran/non-interactive/pointers_test.py @@ -0,0 +1,81 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +@pytest.mark.skip(reason="Interactive test (opens SDFG).") +def test_fortran_frontend_pointer_test(): + """ + Tests to check whether Fortran array slices are correctly translates to DaCe views. + """ + test_name = "pointer_test" + test_string = """ + PROGRAM """ + test_name + """_program +implicit none +REAL lon(10) +REAL lout(10) +TYPE simple_type + REAL:: w(5,5,5),z(5) + INTEGER:: a +END TYPE simple_type + +lon(:) = 1.0 +CALL pointer_test_function(lon,lout) + +end + + + SUBROUTINE pointer_test_function (lon,lout) + REAL, INTENT(in) :: lon(10) + REAL, INTENT(out) :: lout(10) + TYPE(simple_type) :: s + REAL :: area + REAL, POINTER, CONTIGUOUS :: p_area + INTEGER :: i,j + + s%w(1,1,1)=5.5 + lout(:)=0.0 + p_area => s%w + + lout(1)=p_area(1,1,1)+lon(1) + + + END SUBROUTINE pointer_test_function + + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.NestedSDFG): + if node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break + sdfg.parent = None + sdfg.parent_sdfg = None + sdfg.parent_nsdfg_node = None + sdfg.reset_sdfg_list() + sdfg.validate() + sdfg.simplify(verbose=True) + sdfg.view() + + + + +if __name__ == "__main__": + + 
test_fortran_frontend_pointer_test() diff --git a/tests/fortran/view_test.py b/tests/fortran/non-interactive/view_test.py similarity index 69% rename from tests/fortran/view_test.py rename to tests/fortran/non-interactive/view_test.py index 8c00d47e98..eea4ca1c90 100644 --- a/tests/fortran/view_test.py +++ b/tests/fortran/non-interactive/view_test.py @@ -18,6 +18,7 @@ import dace.frontend.fortran.ast_internal_classes as ast_internal_classes +@pytest.mark.skip(reason="Interactive test (opens SDFG).") def test_fortran_frontend_view_test(): """ Tests to check whether Fortran array slices are correctly translates to DaCe views. @@ -62,7 +63,45 @@ def test_fortran_frontend_view_test(): END SUBROUTINE viewlens """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name) + sdfg2 = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + sdfg2.view() + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,True) + for state in sdfg.nodes(): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + if node.path!="": + print("TEST: "+node.path) + tmp_sdfg = SDFG.from_file(node.path) + node.sdfg = tmp_sdfg + node.sdfg.parent = state + node.sdfg.parent_sdfg = sdfg + node.sdfg.update_sdfg_list([]) + node.sdfg.parent_nsdfg_node = node + node.path="" + for sd in sdfg.all_sdfgs_recursive(): + for state in sd.nodes(): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + if node.path!="": + print("TEST: "+node.path) + tmp_sdfg = SDFG.from_file(node.path) + node.sdfg = tmp_sdfg + node.sdfg.parent = state + node.sdfg.parent_sdfg = sd + node.sdfg.update_sdfg_list([]) + node.sdfg.parent_nsdfg_node = node + node.path="" + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.NestedSDFG): + if node.sdfg is not None: + if 'test_function' in node.sdfg.name: + sdfg = node.sdfg + break + sdfg.parent = None + sdfg.parent_sdfg = None + sdfg.parent_nsdfg_node = None + 
sdfg.reset_sdfg_list() + sdfg.view() sdfg.simplify(verbose=True) a = np.full([10, 11, 12], 42, order="F", dtype=np.float64) b = np.full([1, 1, 2], 42, order="F", dtype=np.float64) @@ -73,6 +112,7 @@ def test_fortran_frontend_view_test(): assert (b[0, 0, 0] == 4620) +@pytest.mark.skip(reason="Interactive test (opens SDFG).") def test_fortran_frontend_view_test_2(): """ Tests to check whether Fortran array slices are correctly translates to DaCe views. This case necessitates multiple views per array in the same context. @@ -117,8 +157,8 @@ def test_fortran_frontend_view_test_2(): END SUBROUTINE viewlens """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name) - sdfg.simplify(verbose=True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,True) + #sdfg.simplify(verbose=True) a = np.full([10, 11, 12], 42, order="F", dtype=np.float64) b = np.full([10, 11, 12], 42, order="F", dtype=np.float64) c = np.full([10, 11, 12], 42, order="F", dtype=np.float64) @@ -129,6 +169,7 @@ def test_fortran_frontend_view_test_2(): assert (c[1, 1, 1] == 84) +@pytest.mark.skip(reason="Interactive test (opens SDFG).") def test_fortran_frontend_view_test_3(): """ Tests to check whether Fortran array slices are correctly translates to DaCe views. This test generates multiple views from the same array in the same context. 
""" @@ -170,8 +211,8 @@ def test_fortran_frontend_view_test_3(): END SUBROUTINE viewlens """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name) - sdfg.simplify(verbose=True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,True) + #sdfg.simplify(verbose=True) a = np.full([10, 11, 12], 42, order="F", dtype=np.float64) b = np.full([10, 11, 12], 42, order="F", dtype=np.float64) @@ -184,5 +225,5 @@ def test_fortran_frontend_view_test_3(): if __name__ == "__main__": test_fortran_frontend_view_test() - test_fortran_frontend_view_test_2() - test_fortran_frontend_view_test_3() + #test_fortran_frontend_view_test_2() + #test_fortran_frontend_view_test_3() diff --git a/tests/fortran/offset_normalizer_test.py b/tests/fortran/offset_normalizer_test.py index b4138c1cac..3f8d9fc5c2 100644 --- a/tests/fortran/offset_normalizer_test.py +++ b/tests/fortran/offset_normalizer_test.py @@ -2,7 +2,7 @@ import numpy as np -from dace.frontend.fortran import ast_transforms, fortran_parser +from dace.frontend.fortran import ast_internal_classes, ast_transforms, fortran_parser def test_fortran_frontend_offset_normalizer_1d(): """ @@ -48,6 +48,60 @@ def test_fortran_frontend_offset_normalizer_1d(): for i in range(0,5): assert a[i] == (50+i)* 2 +def test_fortran_frontend_offset_normalizer_1d_symbol(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + integer :: arrsize + integer :: arrsize2 + double precision :: d(arrsize:arrsize2) + CALL index_test_function(d, arrsize, arrsize2) + end + + SUBROUTINE index_test_function(d, arrsize, arrsize2) + integer :: arrsize + integer :: arrsize2 + double precision :: d(arrsize:arrsize2) + + do i=arrsize,arrsize2 + d(i) = i * 2.0 + end do + + END SUBROUTINE index_test_function + """ + + # Test to verify that offset is normalized correctly + ast, own_ast = fortran_parser.create_ast_from_string(test_string, "index_offset_test", True, True) + + for subroutine in ast.subroutine_definitions: + + loop = subroutine.execution_part.execution[1] + idx_assignment = loop.body.execution[1] + assert isinstance(idx_assignment.rval.rval, ast_internal_classes.Name_Node) + assert idx_assignment.rval.rval.name == "arrsize" + + # Now test to verify it executes correctly + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + from dace.symbolic import evaluate + arrsize=50 + arrsize2=54 + assert len(sdfg.data('d').shape) == 1 + assert evaluate(sdfg.data('d').shape[0], {'arrsize': arrsize, 'arrsize2': arrsize2}) == 5 + + arrsize=50 + arrsize2=54 + a = np.full([arrsize2-arrsize+1], 42, order="F", dtype=np.float64) + sdfg(d=a, arrsize=arrsize, arrsize2=arrsize2) + for i in range(0, arrsize2 - arrsize + 1): + assert a[i] == (50+i)* 2 + def test_fortran_frontend_offset_normalizer_2d(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. @@ -103,6 +157,78 @@ def test_fortran_frontend_offset_normalizer_2d(): for j in range(0,3): assert a[i, j] == (50+i) * 2 + 3 * (7 + j) +def test_fortran_frontend_offset_normalizer_2d_symbol(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + integer :: arrsize + integer :: arrsize2 + integer :: arrsize3 + integer :: arrsize4 + double precision, dimension(arrsize:arrsize2,arrsize3:arrsize4) :: d + CALL index_test_function(d, arrsize, arrsize2, arrsize3, arrsize4) + end + + SUBROUTINE index_test_function(d, arrsize, arrsize2, arrsize3, arrsize4) + integer :: arrsize + integer :: arrsize2 + integer :: arrsize3 + integer :: arrsize4 + double precision, dimension(arrsize:arrsize2,arrsize3:arrsize4) :: d + + do i=arrsize, arrsize2 + do j=arrsize3, arrsize4 + d(i, j) = i * 2.0 + 3 * j + end do + end do + + END SUBROUTINE index_test_function + """ + + # Test to verify that offset is normalized correctly + ast, own_ast = fortran_parser.create_ast_from_string(test_string, "index_offset_test", True, True) + + for subroutine in ast.subroutine_definitions: + + loop = subroutine.execution_part.execution[1] + nested_loop = loop.body.execution[1] + + idx = nested_loop.body.execution[1] + assert idx.lval.name == 'tmp_index_0' + assert isinstance(idx.rval.rval, ast_internal_classes.Name_Node) + assert idx.rval.rval.name == "arrsize" + + idx2 = nested_loop.body.execution[3] + assert idx2.lval.name == 'tmp_index_1' + assert isinstance(idx2.rval.rval, ast_internal_classes.Name_Node) + assert idx2.rval.rval.name == "arrsize3" + + # Now test to verify it executes correctly + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + from dace.symbolic import evaluate + values = { + 'arrsize': 50, + 'arrsize2': 54, + 'arrsize3': 7, + 'arrsize4': 9 + } + assert len(sdfg.data('d').shape) == 2 + assert evaluate(sdfg.data('d').shape[0], values) == 5 + assert evaluate(sdfg.data('d').shape[1], values) == 3 + + a = np.full([5,3], 42, order="F", dtype=np.float64) + sdfg(d=a, **values) + for i in range(0,5): + for j in range(0,3): + assert a[i, j] == (50+i) * 2 + 3 * (7 + j) + def 
test_fortran_frontend_offset_normalizer_2d_arr2loop(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. @@ -157,8 +283,147 @@ def test_fortran_frontend_offset_normalizer_2d_arr2loop(): for j in range(0,3): assert a[i, j] == (50 + i) * 2 +def test_fortran_frontend_offset_normalizer_2d_arr2loop_symbol(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + integer :: arrsize + integer :: arrsize2 + integer :: arrsize3 + integer :: arrsize4 + double precision, dimension(arrsize:arrsize2,arrsize3:arrsize4) :: d + CALL index_test_function(d, arrsize, arrsize2, arrsize3, arrsize4) + end + + SUBROUTINE index_test_function(d, arrsize, arrsize2, arrsize3, arrsize4) + integer :: arrsize + integer :: arrsize2 + integer :: arrsize3 + integer :: arrsize4 + double precision, dimension(arrsize:arrsize2,arrsize3:arrsize4) :: d + + do i=arrsize,arrsize2 + d(i, :) = i * 2.0 + end do + + END SUBROUTINE index_test_function + """ + + # Test to verify that offset is normalized correctly + ast, own_ast = fortran_parser.create_ast_from_string(test_string, "index_offset_test", True, True) + + for subroutine in ast.subroutine_definitions: + + loop = subroutine.execution_part.execution[1] + nested_loop = loop.body.execution[1] + + idx = nested_loop.body.execution[1] + assert idx.lval.name == 'tmp_index_0' + assert isinstance(idx.rval.rval, ast_internal_classes.Name_Node) + assert idx.rval.rval.name == "arrsize" + + idx2 = nested_loop.body.execution[3] + assert idx2.lval.name == 'tmp_index_1' + assert isinstance(idx2.rval.rval, ast_internal_classes.Name_Node) + assert idx2.rval.rval.name == "arrsize3" + + # Now test to verify it executes correctly with no normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + from 
dace.symbolic import evaluate + values = { + 'arrsize': 50, + 'arrsize2': 54, + 'arrsize3': 7, + 'arrsize4': 9 + } + assert len(sdfg.data('d').shape) == 2 + assert evaluate(sdfg.data('d').shape[0], values) == 5 + assert evaluate(sdfg.data('d').shape[1], values) == 3 + + a = np.full([5,3], 42, order="F", dtype=np.float64) + sdfg(d=a, **values) + for i in range(0,5): + for j in range(0,3): + assert a[i, j] == (50 + i) * 2 + +def test_fortran_frontend_offset_normalizer_struct(): + test_string = """ + PROGRAM index_offset_test + implicit none + + TYPE simple_type + double precision :: d(:, :) + integer :: arrsize + integer :: arrsize2 + integer :: arrsize3 + integer :: arrsize4 + END TYPE simple_type + + integer :: arrsize + integer :: arrsize2 + integer :: arrsize3 + integer :: arrsize4 + double precision, dimension(arrsize:arrsize2,arrsize3:arrsize4) :: d + CALL index_test_function(d, arrsize, arrsize2, arrsize3, arrsize4) + end + + SUBROUTINE index_test_function(d, arrsize, arrsize2, arrsize3, arrsize4) + integer :: arrsize + integer :: arrsize2 + integer :: arrsize3 + integer :: arrsize4 + double precision, dimension(arrsize:arrsize2,arrsize3:arrsize4) :: d + type(simple_type) :: struct_data + + !struct_data%arrsize = arrsize + !struct_data%arrsize2 = arrsize2 + !struct_data%arrsize3 = arrsize3 + !struct_data%arrsize4 = arrsize4 + !struct_data%d = d + + !do i=struct_data%arrsize,struct_data%arrsize2 + ! 
struct_data%d(i, 1) = i * 2.0 + !end do + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + from dace.symbolic import evaluate + values = { + 'arrsize': 50, + 'arrsize2': 54, + 'arrsize3': 7, + 'arrsize4': 9 + } + assert len(sdfg.data('d').shape) == 2 + assert evaluate(sdfg.data('d').shape[0], values) == 5 + assert evaluate(sdfg.data('d').shape[1], values) == 3 + + a = np.full([5,3], 42, order="F", dtype=np.float64) + sdfg(d=a, **values) + for i in range(0,5): + for j in range(0,3): + assert a[i, j] == (50 + i) * 2 + if __name__ == "__main__": - test_fortran_frontend_offset_normalizer_1d() - test_fortran_frontend_offset_normalizer_2d() - test_fortran_frontend_offset_normalizer_2d_arr2loop() + #test_fortran_frontend_offset_normalizer_1d() + #test_fortran_frontend_offset_normalizer_2d() + #test_fortran_frontend_offset_normalizer_2d_arr2loop() + #test_fortran_frontend_offset_normalizer_1d_symbol() + #test_fortran_frontend_offset_normalizer_2d_symbol() + #test_fortran_frontend_offset_normalizer_2d_arr2loop_symbol() + test_fortran_frontend_offset_normalizer_struct() diff --git a/tests/fortran/optional_args_test.py b/tests/fortran/optional_args_test.py new file mode 100644 index 0000000000..11abc69994 --- /dev/null +++ b/tests/fortran/optional_args_test.py @@ -0,0 +1,116 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser +from tests.fortran.fortran_test_helper import create_singular_sdfg_from_string, SourceCodeBuilder + +def test_fortran_frontend_optional(): + + sources, main = SourceCodeBuilder().add_file(""" + + MODULE intrinsic_optional_test + INTERFACE + SUBROUTINE intrinsic_optional_test_function2(res, a) + integer, dimension(2) :: res + integer, optional :: a + END SUBROUTINE intrinsic_optional_test_function2 + END INTERFACE + END MODULE + + SUBROUTINE intrinsic_optional_test_function(res, res2, a) + USE intrinsic_optional_test + implicit none + integer, dimension(4) :: res + integer, dimension(4) :: res2 + integer :: a + + CALL intrinsic_optional_test_function2(res, a) + CALL intrinsic_optional_test_function2(res2) + + END SUBROUTINE intrinsic_optional_test_function + + SUBROUTINE intrinsic_optional_test_function2(res, a) + integer, dimension(2) :: res + integer, optional :: a + + res(1) = a + + END SUBROUTINE intrinsic_optional_test_function2 +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'intrinsic_optional_test_function', normalize_offsets=True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + res2 = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res, res2=res2, a=5) + + assert res[0] == 5 + assert res2[0] == 0 + +def test_fortran_frontend_optional_complex(): + + sources, main = SourceCodeBuilder().add_file(""" + + MODULE intrinsic_optional_test + INTERFACE + SUBROUTINE intrinsic_optional_test_function2(res, a, b, c) + integer, dimension(5) :: res + integer, optional :: a + double precision, optional :: b + logical, optional :: c + END SUBROUTINE intrinsic_optional_test_function2 + END INTERFACE + END MODULE + + SUBROUTINE intrinsic_optional_test_function(res, res2, a, b, c) + USE intrinsic_optional_test + implicit none + integer, dimension(5) :: res + integer, 
dimension(5) :: res2 + integer :: a + double precision :: b + logical :: c + + CALL intrinsic_optional_test_function2(res, a, b) + CALL intrinsic_optional_test_function2(res2) + + END SUBROUTINE intrinsic_optional_test_function + + SUBROUTINE intrinsic_optional_test_function2(res, a, b, c) + integer, dimension(5) :: res + integer, optional :: a + double precision, optional :: b + logical, optional :: c + + res(1) = a + res(2) = b + res(3) = c + + END SUBROUTINE intrinsic_optional_test_function2 +""", 'main').check_with_gfortran().get() + sdfg = create_singular_sdfg_from_string(sources, 'intrinsic_optional_test_function', normalize_offsets=True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + res2 = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res, res2=res2, a=5, b=7, c=1) + + assert res[0] == 5 + assert res[1] == 7 + assert res[2] == 0 + + assert res2[0] == 0 + assert res2[1] == 0 + assert res2[2] == 0 + + +if __name__ == "__main__": + + test_fortran_frontend_optional() + test_fortran_frontend_optional_complex() diff --git a/tests/fortran/parent_test.py b/tests/fortran/parent_test.py index b1d08eaf37..1f66c81311 100644 --- a/tests/fortran/parent_test.py +++ b/tests/fortran/parent_test.py @@ -1,35 +1,35 @@ # Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. - -from dace.frontend.fortran import fortran_parser - -import dace.frontend.fortran.ast_transforms as ast_transforms import dace.frontend.fortran.ast_internal_classes as ast_internal_classes +import dace.frontend.fortran.ast_transforms as ast_transforms +from dace.frontend.fortran.ast_internal_classes import Program_Node +from dace.frontend.fortran.fortran_parser import ParseConfig, create_internal_ast +from tests.fortran.fortran_test_helper import SourceCodeBuilder def test_fortran_frontend_parent(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
""" - test_string = """ - PROGRAM access_test - implicit none - double precision d(4) - d(1)=0 - CALL array_access_test_function(d) - end - - SUBROUTINE array_access_test_function(d) - double precision d(4) - - d(2)=5.5 - - END SUBROUTINE array_access_test_function - """ - ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test") + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + double precision d(4) + d(1) = 0 + call fun(d) +end program main + +subroutine fun(d) + double precision d(4) + d(2) = 5.5 +end subroutine fun +""", 'main').check_with_gfortran().get() + cfg = ParseConfig(main=sources['main.f90'], sources=sources) + _, ast = create_internal_ast(cfg) ast_transforms.ParentScopeAssigner().visit(ast) - assert ast.parent is None - assert ast.main_program.parent == None + assert not ast.parent + assert isinstance(ast, Program_Node) + assert ast.main_program is not None main_program = ast.main_program # Both executed lines @@ -42,50 +42,55 @@ def test_fortran_frontend_parent(): assert arg.parent == main_program for subroutine in ast.subroutine_definitions: - - assert subroutine.parent == None + assert not subroutine.parent assert subroutine.execution_part.parent == subroutine for execution in subroutine.execution_part.execution: assert execution.parent == subroutine + def test_fortran_frontend_module(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. """ - test_string = """ - module test_module - implicit none - ! 
good enough approximation - integer, parameter :: pi = 4 - end module test_module - - PROGRAM access_test - implicit none - double precision d(4) - d(1)=0 - CALL array_access_test_function(d) - end - - SUBROUTINE array_access_test_function(d) - double precision d(4) - - d(2)=5.5 - - END SUBROUTINE array_access_test_function - """ - ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test") + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + ! good enough approximation + integer, parameter :: pi = 4 +end module lib +""").add_file(""" +program main + implicit none + double precision d(4) + d(1) = 0 + call fun(d) +end program main + +subroutine fun(d) + use lib, only: pi + implicit none + double precision d(4) + d(2) = pi +end subroutine fun +""", 'main').check_with_gfortran().get() + cfg = ParseConfig(main=sources['main.f90'], sources=sources) + _, ast = create_internal_ast(cfg) ast_transforms.ParentScopeAssigner().visit(ast) - assert ast.parent is None - assert ast.main_program.parent == None + assert not ast.parent + assert isinstance(ast, Program_Node) + assert not ast.main_program.parent + assert len(ast.modules) == 1 module = ast.modules[0] - assert module.parent == None - specification = module.specification_part.specifications[0] + assert not module.parent + + assert module.specification_part is not None + assert len(module.specification_part.symbols) == 1 + specification = module.specification_part.symbols[0] assert specification.parent == module if __name__ == "__main__": - test_fortran_frontend_parent() test_fortran_frontend_module() diff --git a/tests/fortran/pointer_removal_test.py b/tests/fortran/pointer_removal_test.py new file mode 100644 index 0000000000..3e2705c4b3 --- /dev/null +++ b/tests/fortran/pointer_removal_test.py @@ -0,0 +1,208 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + +from dace.transformation.passes.lift_struct_views import LiftStructViews +from dace.transformation import pass_pipeline as ppl + +def test_fortran_frontend_ptr_assignment_removal(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. + """ + test_string = """ + PROGRAM type_in_call_test + implicit none + + TYPE simple_type + REAL :: w(5,5,5), z(5) + INTEGER :: a + REAL :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL type_in_call_test_function(d) + end + + SUBROUTINE type_in_call_test_function(d) + REAL d(5,5) + TYPE(simple_type) :: s + INTEGER,POINTER :: tmp + tmp=>s%a + + tmp = 13 + d(2,1) = max(1.0, tmp) + END SUBROUTINE type_in_call_test_function + """ + sources={} + sources["type_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 13) + assert (a[2, 0] == 42) + + +def test_fortran_frontend_ptr_assignment_removal_array(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_in_call_test + implicit none + + TYPE simple_type + REAL :: w(5,5,5), z(5) + INTEGER :: a + REAL :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL type_in_call_test_function(d) + end + + SUBROUTINE type_in_call_test_function(d) + REAL d(5,5) + TYPE(simple_type) :: s + REAL,POINTER :: tmp(:,:,:) + tmp=>s%w + + tmp(1,1,1) = 11.0 + d(2,1) = max(1.0, tmp(1,1,1)) + END SUBROUTINE type_in_call_test_function + """ + sources={} + sources["type_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources,normalize_offsets=True) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + +def test_fortran_frontend_ptr_assignment_removal_array_assumed(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. + """ + test_string = """ + PROGRAM type_in_call_test + implicit none + + TYPE simple_type + REAL :: w(5,5,5), z(5) + INTEGER :: a + REAL :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL type_in_call_test_function(d) + end + + SUBROUTINE type_in_call_test_function(d) + REAL d(5,5) + TYPE(simple_type) :: s + REAL,POINTER :: tmp(:,:,:) + tmp=>s%w + + tmp(1,1,1) = 11.0 + d(2,1) = max(1.0, tmp(1,1,1)) + + CALL type_in_call_test_function2(tmp) + d(3,1) = max(1.0, tmp(2,1,1)) + + END SUBROUTINE type_in_call_test_function + + SUBROUTINE type_in_call_test_function2(tmp) + REAL,POINTER :: tmp(:,:,:) + + tmp(2,1,1) = 1410 + END SUBROUTINE type_in_call_test_function2 + """ + sources={} + sources["type_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + print(a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 1410) + 
+def test_fortran_frontend_ptr_assignment_removal_array_nested(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. + """ + test_string = """ + PROGRAM type_in_call_test + implicit none + + TYPE simple_type4 + REAL :: w(5,5,5) + END TYPE simple_type4 + + TYPE simple_type3 + type(simple_type4):: val3 + END TYPE simple_type3 + + TYPE simple_type2 + type(simple_type3):: val + REAL :: w(5,5,5) + END TYPE simple_type2 + + TYPE simple_type + type(simple_type2) :: val1 + INTEGER :: a + REAL :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL type_in_call_test_function(d) + end + + SUBROUTINE type_in_call_test_function(d) + REAL d(5,5) + TYPE(simple_type) :: s + REAL,POINTER :: tmp(:,:,:) + !tmp=>s%val1%val%w + tmp=>s%val1%w + + tmp(1,1,1) = 11.0 + d(2,1) = tmp(1,1,1) + END SUBROUTINE type_in_call_test_function + """ + sources={} + sources["type_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + +if __name__ == "__main__": + # pointers to non-array fields are broken + #test_fortran_frontend_ptr_assignment_removal() + test_fortran_frontend_ptr_assignment_removal_array() + # broken - no idea why + #test_fortran_frontend_ptr_assignment_removal_array_assumed() + # also broken - bug in codegen + #test_fortran_frontend_ptr_assignment_removal_array_nested() diff --git a/tests/fortran/prune_test.py b/tests/fortran/prune_test.py new file mode 100644 index 0000000000..46585d2825 --- /dev/null +++ b/tests/fortran/prune_test.py @@ -0,0 +1,147 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil +from dace.sdfg.nodes import AccessNode + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + +def test_fortran_frontend_prune_simple(): + test_string = """ + PROGRAM init_test + implicit none + double precision d(4) + double precision dx(4) + CALL test_function(d, dx) + end + + SUBROUTINE test_function(d, dx) + + double precision dx(4) + double precision d(4) + + d(2) = d(1) + 3.14 + + END SUBROUTINE test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init_test", False) + print('a', flush=True) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + print(a) + sdfg(d=a,outside_init=0) + print(a) + assert (a[0] == 42) + assert (a[1] == 42 + 3.14) + assert (a[2] == 42) + + +def test_fortran_frontend_prune_complex(): + # Test we can detect recursively unused arguments + # Test we can change names and it does not affect pruning + # Test we can use two different ignored args in the same function + test_string = """ + PROGRAM init_test + implicit none + double precision d(4) + double precision dx(1) + double precision dy(1) + CALL test_function(dy, d, dx) + end + + SUBROUTINE test_function(dy, d, dx) + + double precision dx(4) + double precision d(1) + double precision dy(1) + + d(2) = d(1) + 3.14 + + CALL test_function_another(d, dx) + CALL test_function_another(d, dy) + + END 
SUBROUTINE test_function + + SUBROUTINE test_function_another(dx, dz) + + double precision dx(4) + double precision dz(1) + + dx(3) = dx(3) - 1 + + END SUBROUTINE test_function_another + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init_test", True) + print('a', flush=True) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + print(a) + sdfg(d=a,outside_init=0) + print(a) + assert (a[0] == 42) + assert (a[1] == 42 + 3.14) + assert (a[2] == 40) + +def test_fortran_frontend_prune_actual_param(): + # Test we do not remove a variable that is passed along + # but not used in the function. + test_string = """ + PROGRAM init_test + implicit none + double precision d(4) + double precision dx(1) + double precision dy(1) + CALL test_function(dy, d, dx) + end + + SUBROUTINE test_function(dy, d, dx) + + double precision d(4) + double precision dx(1) + double precision dy(1) + + CALL test_function_another(d, dx) + CALL test_function_another(d, dy) + + END SUBROUTINE test_function + + SUBROUTINE test_function_another(dx, dz) + + double precision dx(4) + double precision dz(1) + + dx(3) = dx(3) - 1 + + END SUBROUTINE test_function_another + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init_test", True) + print('a', flush=True) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + print(a) + sdfg(d=a,outside_init=0) + print(a) + assert (a[0] == 42) + assert (a[1] == 42) + assert (a[2] == 40) + +if __name__ == "__main__": + + test_fortran_frontend_prune_simple() + test_fortran_frontend_prune_complex() + test_fortran_frontend_prune_actual_param() diff --git a/tests/fortran/prune_unused_children_test.py b/tests/fortran/prune_unused_children_test.py new file mode 100644 index 0000000000..1e7d921930 --- /dev/null +++ b/tests/fortran/prune_unused_children_test.py @@ -0,0 +1,769 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+from typing import Dict, List + +from fparser.common.readfortran import FortranStringReader +from fparser.two.Fortran2003 import Program +from fparser.two.parser import ParserFactory +from fparser.two.utils import walk + +from dace.frontend.fortran.fortran_parser import recursive_ast_improver +from dace.frontend.fortran.ast_desugaring import ENTRY_POINT_OBJECT_TYPES, find_name_of_node, prune_unused_objects +from tests.fortran.fortran_test_helper import SourceCodeBuilder + + +def parse_and_improve(sources: Dict[str, str]): + parser = ParserFactory().create(std="f2008") + assert 'main.f90' in sources + reader = FortranStringReader(sources['main.f90']) + ast = parser(reader) + ast = recursive_ast_improver(ast, sources, [], parser) + assert isinstance(ast, Program) + return ast + + +def find_entrypoint_objects_named(ast: Program, name: str) -> List[ENTRY_POINT_OBJECT_TYPES]: + objs: List[ENTRY_POINT_OBJECT_TYPES] = [] + for n in walk(ast, ENTRY_POINT_OBJECT_TYPES): + assert isinstance(n, ENTRY_POINT_OBJECT_TYPES) + if not isinstance(n.parent, Program): + continue + if find_name_of_node(n) == name: + objs.append(n) + return objs + + +def prune_from_main(ast: Program) -> Program: + return prune_unused_objects(ast, find_entrypoint_objects_named(ast, 'main')) + + +def test_minimal_no_pruning(): + """ + NOTE: We have a very similar test in `recursive_ast_improver_test.py`. + A minimal program that does not have any modules. So, `recompute_children()` should be a noop here. + """ + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + double precision d(4) + d(2) = 5.5 +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + # Since there was no module, it should be the exact same AST as the corresponding test in + # `recursive_ast_improver_test.py`. 
+ got = ast.tofortran() + want = """ +PROGRAM main + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 +END PROGRAM main + """.strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_toplevel_subroutine_no_pruning(): + """ + A simple program with not much to "recursively improve", but this time the subroutine is defined outside and called + from the main program. + """ + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + double precision d(4) + call fun(d) +end program main + +subroutine fun(d) + implicit none + double precision d(4) + d(2) = 5.5 +end subroutine fun +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + got = ast.tofortran() + want = """ +PROGRAM main + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) +END PROGRAM main +SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 +END SUBROUTINE fun +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_program_contains_subroutine_no_pruning(): + """ + A simple program with not much to "recursively improve", but this time the subroutine is defined outside and called + from the main program. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + double precision d(4) + call fun(d) +contains + subroutine fun(d) + implicit none + double precision d(4) + d(2) = 5.5 + end subroutine fun +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + got = ast.tofortran() + want = """ +PROGRAM main + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) + CONTAINS + SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 + END SUBROUTINE fun +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_standalone_subroutine_no_pruning(): + """ + A standalone subroutine, with no program or module in sight. + """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine main(d) + implicit none + double precision d(4) + d(2) = 5.5 +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + got = ast.tofortran() + want = """ +SUBROUTINE main(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_toplevel_subroutine_uses_another_module_no_pruning(): + """ + A simple program with not much to "recursively improve", but this time the subroutine is defined outside and called + from the main program. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + double precision :: val = 5.5 +end module lib +""").add_file(""" +program main + implicit none + double precision d(4) + call fun(d) +end program main + +subroutine fun(d) + use lib + implicit none + double precision d(4) + d(2) = val +end subroutine fun +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + DOUBLE PRECISION :: val = 5.5 +END MODULE lib +PROGRAM main + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) +END PROGRAM main +SUBROUTINE fun(d) + USE lib + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = val +END SUBROUTINE fun +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_uses_module_which_uses_module_no_pruning(): + """ + NOTE: We have a very similar test in `recursive_ast_improver_test.py`. + A simple program that uses modules, which in turn uses another module. The main program uses the module and calls + the subroutine. So, we should have "recursively improved" the AST by parsing that module and constructing the + dependency graph. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib +contains + subroutine fun(d) + implicit none + double precision d(4) + d(2) = 5.5 + end subroutine fun +end module lib +""").add_file(""" +module lib_indirect + use lib +contains + subroutine fun_indirect(d) + implicit none + double precision d(4) + call fun(d) + end subroutine fun_indirect +end module lib_indirect +""").add_file(""" +program main + use lib_indirect, only: fun_indirect + implicit none + double precision d(4) + call fun_indirect(d) +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + got = ast.tofortran() + want = """ +MODULE lib + CONTAINS + SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 + END SUBROUTINE fun +END MODULE lib +MODULE lib_indirect + USE lib + CONTAINS + SUBROUTINE fun_indirect(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) + END SUBROUTINE fun_indirect +END MODULE lib_indirect +PROGRAM main + USE lib_indirect, ONLY: fun_indirect + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun_indirect(d) +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_module_contains_interface_block_no_pruning(): + """ + NOTE: We have a very similar test in `recursive_ast_improver_test.py`. + The same simple program, but this time the subroutine is defined inside the main program that calls it. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none +contains + real function fun() + implicit none + fun = 5.5 + end function fun +end module lib +""").add_file(""" +module lib_indirect + use lib, only: fun + implicit none + interface xi + module procedure fun + end interface xi + +contains + real function fun2() + implicit none + fun2 = 4.2 + end function fun2 +end module lib_indirect +""").add_file(""" +program main + use lib_indirect, only : fun, fun2 + implicit none + + double precision d(4) + d(2) = fun() +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + CONTAINS + REAL FUNCTION fun() + IMPLICIT NONE + fun = 5.5 + END FUNCTION fun +END MODULE lib +MODULE lib_indirect + USE lib, ONLY: fun + IMPLICIT NONE + INTERFACE xi + MODULE PROCEDURE fun + END INTERFACE xi + CONTAINS + REAL FUNCTION fun2() + IMPLICIT NONE + fun2 = 4.2 + END FUNCTION fun2 +END MODULE lib_indirect +PROGRAM main + USE lib_indirect, ONLY: fun, fun2 + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = fun() +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_uses_module_but_prunes_unused_defs(): + """ + A simple program, but this time the subroutine is defined in a module, that also has some unused subroutine. + The main program uses the module and calls the subroutine. So, we should have "recursively improved" the AST by + parsing that module and constructing the dependency graph. Then after simplification, that unused subroutine should + be gone from the dependency graph. + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib +contains + subroutine fun(d) + implicit none + double precision d(4) + d(2) = 5.5 + end subroutine fun + subroutine not_fun(d) ! 
`main` only uses `fun`, so this should be dropped after simplification + implicit none + double precision d(4) + d(2) = 4.2 + end subroutine not_fun + integer function real_fun() ! `main` only uses `fun`, so this should be dropped after simplification + implicit none + real_fun = 4.7 + end function real_fun +end module lib +""").add_file(""" +program main + use lib, only: fun + implicit none + double precision d(4) + call fun(d) +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + # `not_fun` and `real_fun` should be gone! + got = ast.tofortran() + want = """ +MODULE lib + CONTAINS + SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 + END SUBROUTINE fun +END MODULE lib +PROGRAM main + USE lib, ONLY: fun + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_module_contains_used_and_unused_types_prunes_unused_defs(): + """ + Module has type definition that the program does not use, so it gets pruned. + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + + type used_type + real :: w(5, 5, 5), z(5) + integer :: a + real :: name + end type used_type + + type dead_type + real :: w(5, 5, 5), z(5) + integer :: a + real :: name + end type dead_type +end module lib +""").add_file(""" +program main + use lib, only : used_type + implicit none + + real :: d(5, 5) + call type_test_function(d) + +contains + + subroutine type_test_function(d) + real d(5, 5) + type(used_type) :: s + s%w(1, 1, 1) = 5.5 + d(2, 1) = 5.5 + s%w(1, 1, 1) + end subroutine type_test_function +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + # `not_fun` and `real_fun` should be gone! 
+ got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + TYPE :: used_type + REAL :: w(5, 5, 5), z(5) + INTEGER :: a + REAL :: name + END TYPE used_type +END MODULE lib +PROGRAM main + USE lib, ONLY: used_type + IMPLICIT NONE + REAL :: d(5, 5) + CALL type_test_function(d) + CONTAINS + SUBROUTINE type_test_function(d) + REAL :: d(5, 5) + TYPE(used_type) :: s + s % w(1, 1, 1) = 5.5 + d(2, 1) = 5.5 + s % w(1, 1, 1) + END SUBROUTINE type_test_function +END PROGRAM main + """.strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_module_contains_used_and_unused_variables_doesnt_prune_variables(): + """ + Module has unused variables. But we don't prune variables. + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + integer, parameter :: used = 1 + real, parameter :: unused = 4.2 +end module lib +""").add_file(""" +program main + use lib, only: used + implicit none + + real :: d(5, 5) + call type_test_function(d) + +contains + + subroutine type_test_function(d) + real d(5, 5) + d(2, 1) = used + end subroutine type_test_function +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + # `not_fun` and `real_fun` should be gone! + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + INTEGER, PARAMETER :: used = 1 + REAL, PARAMETER :: unused = 4.2 +END MODULE lib +PROGRAM main + USE lib, ONLY: used + IMPLICIT NONE + REAL :: d(5, 5) + CALL type_test_function(d) + CONTAINS + SUBROUTINE type_test_function(d) + REAL :: d(5, 5) + d(2, 1) = used + END SUBROUTINE type_test_function +END PROGRAM main + """.strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_module_contains_used_and_unused_variables_with_use_all_prunes_unused(): + """ + Module has unused variables that are pulled in with "use-all". 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + integer, parameter :: used = 1 + real, parameter :: unused = 4.2 +end module lib +""").add_file(""" +program main + use lib + implicit none + + real :: d(5, 5) + call type_test_function(d) + +contains + + subroutine type_test_function(d) + real d(5, 5) + d(2, 1) = used + end subroutine type_test_function +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + # `not_fun` and `real_fun` should be gone! + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + INTEGER, PARAMETER :: used = 1 + REAL, PARAMETER :: unused = 4.2 +END MODULE lib +PROGRAM main + USE lib + IMPLICIT NONE + REAL :: d(5, 5) + CALL type_test_function(d) + CONTAINS + SUBROUTINE type_test_function(d) + REAL :: d(5, 5) + d(2, 1) = used + END SUBROUTINE type_test_function +END PROGRAM main + """.strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_use_statement_multiple_doesnt_prune_variables(): + """ + We have multiple uses of the same module. + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + integer, parameter :: a = 1 + real, parameter :: b = 4.2 + real, parameter :: c = -7.1 +end module lib +""").add_file(""" +program main + use lib, only: a + use lib, only: b + implicit none + + real :: d(5, 5) + call type_test_function(d) + +contains + + subroutine type_test_function(d) + real d(5, 5) + d(1, 1) = a + d(1, 1) = b + end subroutine type_test_function +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + # `not_fun` and `real_fun` should be gone! 
+ got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + INTEGER, PARAMETER :: a = 1 + REAL, PARAMETER :: b = 4.2 + REAL, PARAMETER :: c = - 7.1 +END MODULE lib +PROGRAM main + USE lib, ONLY: a + USE lib, ONLY: b + IMPLICIT NONE + REAL :: d(5, 5) + CALL type_test_function(d) + CONTAINS + SUBROUTINE type_test_function(d) + REAL :: d(5, 5) + d(1, 1) = a + d(1, 1) = b + END SUBROUTINE type_test_function +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_use_statement_multiple_with_useall_prunes_unused(): + """ + We have multiple uses of the same module. One of them is a "use-all". + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + integer, parameter :: a = 1 + real, parameter :: b = 4.2 + real, parameter :: c = -7.1 +end module lib +""").add_file(""" +program main + use lib + use lib, only: a + implicit none + + real :: d(5, 5) + call type_test_function(d) + +contains + + subroutine type_test_function(d) + real d(5, 5) + d(1, 1) = a + d(1, 1) = b + end subroutine type_test_function +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + # `not_fun` and `real_fun` should be gone! + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + INTEGER, PARAMETER :: a = 1 + REAL, PARAMETER :: b = 4.2 + REAL, PARAMETER :: c = - 7.1 +END MODULE lib +PROGRAM main + USE lib + USE lib, ONLY: a + IMPLICIT NONE + REAL :: d(5, 5) + CALL type_test_function(d) + CONTAINS + SUBROUTINE type_test_function(d) + REAL :: d(5, 5) + d(1, 1) = a + d(1, 1) = b + END SUBROUTINE type_test_function +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_subroutine_contains_function_no_pruning(): + """ + A function is defined inside a subroutine that calls it. A main program uses the top-level subroutine. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none +contains + subroutine fun(d) + implicit none + double precision d(4) + d(2) = fun2() + + contains + real function fun2() + implicit none + fun2 = 5.5 + end function fun2 + end subroutine fun +end module lib +""").add_file(""" +program main + use lib, only: fun + implicit none + + double precision d(4) + call fun(d) +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = prune_from_main(ast) + + # `not_fun` and `real_fun` should be gone! + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + CONTAINS + SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = fun2() + CONTAINS + REAL FUNCTION fun2() + IMPLICIT NONE + fun2 = 5.5 + END FUNCTION fun2 + END SUBROUTINE fun +END MODULE lib +PROGRAM main + USE lib, ONLY: fun + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() diff --git a/tests/fortran/ranges_test.py b/tests/fortran/ranges_test.py new file mode 100644 index 0000000000..39363f7412 --- /dev/null +++ b/tests/fortran/ranges_test.py @@ -0,0 +1,535 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import ast_transforms, fortran_parser + +""" +We test for the following patterns: +* Range 'ALL' +* selecting one element by constant +* selecting one element by variable +* selecting a subset (proper range) through constants +* selecting a subset (proper range) through variables +* ECRAD patterns (WiP) + flux_dn(:,1:i_cloud_top) = flux_dn_clear(:,1:i_cloud_top) +* Extended ECRAD pattern with different loop starting positions. 
+* Arrays with offsets +* Assignment with arrays that have no range expression on the right +""" + +def test_fortran_frontend_multiple_ranges_all(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM multiple_ranges + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7) :: res + CALL multiple_ranges_function(input1, input2, res) + end + + SUBROUTINE multiple_ranges_function(input1, input2, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7) :: res + + res(:) = input1(:) - input2(:) + + END SUBROUTINE multiple_ranges_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_function", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + input1 = np.full([size], 0, order="F", dtype=np.float64) + input2 = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + input1[i] = i + 1 + input2[i] = i + res = np.full([7], 42, order="F", dtype=np.float64) + sdfg(input1=input1, input2=input2, res=res) + print(res) + for val in res: + assert val == 1.0 + +def test_fortran_frontend_multiple_ranges_selection(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_selection + implicit none + double precision, dimension(7,2) :: input1 + double precision, dimension(7) :: res + CALL multiple_ranges_selection_function(input1, res) + end + + SUBROUTINE multiple_ranges_selection_function(input1, res) + double precision, dimension(7,2) :: input1 + double precision, dimension(7) :: res + + res(:) = input1(:, 1) - input1(:, 2) + + END SUBROUTINE multiple_ranges_selection_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_selection", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + size2 = 2 + input1 = np.full([size, size2], 0, order="F", dtype=np.float64) + for i in range(size): + input1[i][0] = i + 1 + input1[i][1] = 0 + res = np.full([7], 42, order="F", dtype=np.float64) + sdfg(input1=input1, res=res, outside_init=False) + for idx, val in enumerate(res): + assert val == idx + 1.0 + +def test_fortran_frontend_multiple_ranges_selection_var(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_selection + implicit none + double precision, dimension(7,2) :: input1 + double precision, dimension(7) :: res + integer :: pos1 + integer :: pos2 + CALL multiple_ranges_selection_function(input1, res, pos1, pos2) + end + + SUBROUTINE multiple_ranges_selection_function(input1, res, pos1, pos2) + double precision, dimension(7,2) :: input1 + double precision, dimension(7) :: res + integer :: pos1 + integer :: pos2 + + res(:) = input1(:, pos1) - input1(:, pos2) + + END SUBROUTINE multiple_ranges_selection_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_selection", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + size2 = 2 + input1 = np.full([size, size2], 0, order="F", dtype=np.float64) + for i in range(size): + input1[i][1] = i + 1 + input1[i][0] = 0 + res = np.full([7], 42, order="F", dtype=np.float64) + sdfg(input1=input1, res=res, pos1=2, pos2=1, outside_init=False) + for idx, val in enumerate(res): + assert val == idx + 1.0 + + sdfg(input1=input1, res=res, pos1=1, pos2=2, outside_init=False) + for idx, val in enumerate(res): + assert -val == idx + 1.0 + +def test_fortran_frontend_multiple_ranges_subset(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_subset + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(3) :: res + CALL multiple_ranges_subset_function(input1, res) + end + + SUBROUTINE multiple_ranges_subset_function(input1, res) + double precision, dimension(7) :: input1 + double precision, dimension(3) :: res + + res(:) = input1(1:3) - input1(4:6) + + END SUBROUTINE multiple_ranges_subset_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_subset", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + sdfg.save('test.sdfg') + + size = 7 + input1 = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + input1[i] = i + 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(input1=input1, res=res, outside_init=False) + for idx, val in enumerate(res): + assert val == -3.0 + +def test_fortran_frontend_multiple_ranges_subset_var(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_subset_var + implicit none + double precision, dimension(9) :: input1 + double precision, dimension(3) :: res + integer, dimension(4) :: pos + CALL multiple_ranges_subset_var_function(input1, res, pos) + end + + SUBROUTINE multiple_ranges_subset_var_function(input1, res, pos) + double precision, dimension(9) :: input1 + double precision, dimension(3) :: res + integer, dimension(4) :: pos + + res(:) = input1(pos(1):pos(2)) - input1(pos(3):pos(4)) + + END SUBROUTINE multiple_ranges_subset_var_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_subset_var", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 9 + input1 = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + input1[i] = 2 ** i + + pos = np.full([4], 0, order="F", dtype=np.int32) + pos[0] = 2 + pos[1] = 4 + pos[2] = 6 + pos[3] = 8 + + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(input1=input1, pos=pos, res=res, outside_init=False) + + for i in range(len(res)): + assert res[i] == input1[pos[0] - 1 + i] - input1[pos[2] - 1 + i] + +def test_fortran_frontend_multiple_ranges_ecrad_pattern(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_ecrad + implicit none + double precision, dimension(7, 7) :: input1 + double precision, dimension(7, 7) :: res + integer, dimension(2) :: pos + CALL multiple_ranges_ecrad_function(input1, res, pos) + end + + SUBROUTINE multiple_ranges_ecrad_function(input1, res, pos) + double precision, dimension(7, 7) :: input1 + double precision, dimension(7, 7) :: res + integer, dimension(2) :: pos + + res(:, pos(1):pos(2)) = input1(:, pos(1):pos(2)) + + END SUBROUTINE multiple_ranges_ecrad_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_ecrad", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + input1 = np.full([size, size], 0, order="F", dtype=np.float64) + for i in range(size): + for j in range(size): + input1[i, j] = i + 2 ** j + + pos = np.full([2], 0, order="F", dtype=np.int32) + pos[0] = 2 + pos[1] = 5 + + res = np.full([size, size], 42, order="F", dtype=np.float64) + sdfg(input1=input1, pos=pos, res=res, outside_init=False) + + for i in range(size): + for j in range(pos[0], pos[1] + 1): + + print(i , j, res[i - 1, j - 1], input1[i - 1, j - 1]) + assert res[i - 1, j - 1] == input1[i - 1, j - 1] + +def test_fortran_frontend_multiple_ranges_ecrad_pattern_complex(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_ecrad + implicit none + double precision, dimension(7, 7) :: input1 + double precision, dimension(7, 7) :: res + integer, dimension(6) :: pos + CALL multiple_ranges_ecrad_function(input1, res, pos) + end + + SUBROUTINE multiple_ranges_ecrad_function(input1, res, pos) + double precision, dimension(7, 7) :: input1 + double precision, dimension(7, 7) :: res + integer, dimension(6) :: pos + + res(:, pos(1):pos(2)) = input1(:, pos(3):pos(4)) + input1(:, pos(5):pos(6)) + + END SUBROUTINE multiple_ranges_ecrad_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_ecrad", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + input1 = np.full([size, size], 0, order="F", dtype=np.float64) + for i in range(size): + for j in range(size): + input1[i, j] = i + 2 ** j + + pos = np.full([6], 0, order="F", dtype=np.int32) + pos[0] = 2 + pos[1] = 5 + pos[2] = 1 + pos[3] = 4 + pos[4] = 4 + pos[5] = 7 + + res = np.full([size, size], 42, order="F", dtype=np.float64) + sdfg(input1=input1, pos=pos, res=res, outside_init=False) + + iter_1 = pos[0] + iter_2 = pos[2] + iter_3 = pos[4] + length = pos[1] - pos[0] + 1 + + for i in range(size): + for j in range(length): + assert res[i - 1, iter_1 + j - 1] == input1[i - 1, iter_2 + j - 1] + input1[i - 1, iter_3 + j - 1] + +def test_fortran_frontend_multiple_ranges_ecrad_pattern_complex_offsets(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_ecrad_offset + implicit none + double precision, dimension(7, 21:27) :: input1 + double precision, dimension(7, 31:37) :: res + integer, dimension(6) :: pos + CALL multiple_ranges_ecrad_function(input1, res, pos) + end + + SUBROUTINE multiple_ranges_ecrad_function(input1, res, pos) + double precision, dimension(7, 21:27) :: input1 + double precision, dimension(7, 31:37) :: res + integer, dimension(6) :: pos + + res(:, pos(1):pos(2)) = input1(:, pos(3):pos(4)) + input1(:, pos(5):pos(6)) + + END SUBROUTINE multiple_ranges_ecrad_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_ecrad_offset", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + input1 = np.full([size, size], 0, order="F", dtype=np.float64) + for i in range(size): + for j in range(size): + input1[i, j] = i + 2 ** j + + pos = np.full([6], 0, order="F", dtype=np.int32) + pos[0] = 2 + 30 + pos[1] = 5 + 30 + pos[2] = 1 + 20 + pos[3] = 4 + 20 + pos[4] = 4 + 20 + pos[5] = 7 + 20 + + res = np.full([size, size], 42, order="F", dtype=np.float64) + sdfg(input1=input1, pos=pos, res=res, outside_init=False) + + iter_1 = pos[0] - 30 + iter_2 = pos[2] - 20 + iter_3 = pos[4] - 20 + length = pos[1] - pos[0] + 1 + + for i in range(size): + for j in range(length): + assert res[i - 1, iter_1 + j - 1] == input1[i - 1, iter_2 + j - 1] + input1[i - 1, iter_3 + j - 1] + +def test_fortran_frontend_array_assignment(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_ecrad + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7, 7) :: res + integer, dimension(2) :: pos + CALL multiple_ranges_ecrad_function(input1, input2, res, pos) + end + + SUBROUTINE multiple_ranges_ecrad_function(input1, input2, res, pos) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7, 7) :: res + integer, dimension(2) :: pos + integer :: nlev + + nlev = input1(1) + + ! write 5 to column 2 + res(:, pos(1)) = nlev + + ! write input1 values to column 3 + res(:, pos(1) + 1) = input1 + + res(:, pos(1) + 2) = input1 + input2 + + res(:, pos(1) + 3) = input1 + input2(:) + + END SUBROUTINE multiple_ranges_ecrad_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_ecrad", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + input1 = np.full([size], 0, order="F", dtype=np.float64) + input2 = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + input1[i] = i + 5 + input2[i] = i + 6 + + pos = np.full([2], 0, order="F", dtype=np.int32) + pos[0] = 2 + pos[1] = 5 + + res = np.full([size, size], 42, order="F", dtype=np.float64) + sdfg(input1=input1, input2=input2, pos=pos, res=res, outside_init=False) + + for i in range(size): + assert res[i, 1] == input1[0] + assert res[i, 2] == input1[i] + assert res[i, 3] == input1[i] + input2[i] + assert res[i, 4] == input1[i] + input2[i] + +def test_fortran_frontend_multiple_ranges_ecrad_bug(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_ecrad_bug + implicit none + double precision, dimension(7, 7) :: input1 + double precision, dimension(7, 7) :: res + integer, dimension(4) :: pos + CALL multiple_ranges_ecrad_function(input1, res, pos) + end + + SUBROUTINE multiple_ranges_ecrad_function(input1, res, pos) + double precision, dimension(7, 7) :: input1 + double precision, dimension(7, 7) :: res + integer, dimension(4) :: pos + integer :: nval + + nval = pos(1) + + res(nval, pos(1):pos(2)) = input1(nval, pos(3):pos(4)) + + END SUBROUTINE multiple_ranges_ecrad_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_ecrad_bug", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + input1 = np.full([size, size], 0, order="F", dtype=np.float64) + for i in range(size): + for j in range(size): + input1[i, j] = i + 2 ** j + + pos = np.full([4], 0, order="F", dtype=np.int32) + pos[0] = 2 + pos[1] = 5 + pos[2] = 1 + pos[3] = 4 + + res = np.full([size, size], 42, order="F", dtype=np.float64) + sdfg(input1=input1, pos=pos, res=res, outside_init=False) + + iter_1 = pos[0] + iter_2 = pos[2] + length = pos[1] - pos[0] + 1 + + i = pos[0] - 1 + for j in range(length): + + assert res[i, iter_1 - 1] == input1[i, iter_2 - 1] + iter_1 += 1 + iter_2 += 1 + +def test_fortran_frontend_ranges_array_bug(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM multiple_ranges_ecrad_bug + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: res + CALL multiple_ranges_ecrad_function(input1, res) + end + + SUBROUTINE multiple_ranges_ecrad_function(input1, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: res + + res(:) = input1(2) * input1(:) + + END SUBROUTINE multiple_ranges_ecrad_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "multiple_ranges_ecrad_bug", True) + #sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + input1 = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + input1[i] = i + 2 + + res = np.full([size], 42, order="F", dtype=np.float64) + sdfg(input1=input1, res=res, outside_init=False) + + assert np.all(res == input1 * input1[1]) + + +if __name__ == "__main__": + + test_fortran_frontend_multiple_ranges_all() + test_fortran_frontend_multiple_ranges_selection() + test_fortran_frontend_multiple_ranges_selection_var() + test_fortran_frontend_multiple_ranges_subset() + test_fortran_frontend_multiple_ranges_subset_var() + test_fortran_frontend_multiple_ranges_ecrad_pattern() + test_fortran_frontend_multiple_ranges_ecrad_pattern_complex() + test_fortran_frontend_multiple_ranges_ecrad_pattern_complex_offsets() + test_fortran_frontend_array_assignment() + test_fortran_frontend_multiple_ranges_ecrad_bug() + test_fortran_frontend_ranges_array_bug() diff --git a/tests/fortran/recursive_ast_improver_test.py b/tests/fortran/recursive_ast_improver_test.py new file mode 100644 index 0000000000..ef9fbdf5bc --- /dev/null +++ b/tests/fortran/recursive_ast_improver_test.py @@ -0,0 +1,731 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+from typing import Dict + +from fparser.common.readfortran import FortranStringReader +from fparser.two.Fortran2003 import Program +from fparser.two.parser import ParserFactory + +from dace.frontend.fortran.fortran_parser import recursive_ast_improver +from dace.frontend.fortran.ast_desugaring import deconstruct_procedure_calls +from tests.fortran.fortran_test_helper import SourceCodeBuilder + + +def parse_and_improve(sources: Dict[str, str]): + parser = ParserFactory().create(std="f2008") + assert 'main.f90' in sources + reader = FortranStringReader(sources['main.f90']) + ast = parser(reader) + ast = recursive_ast_improver(ast, sources, [], parser) + ast = deconstruct_procedure_calls(ast) + assert isinstance(ast, Program) + + return ast + + +def test_minimal(): + """ + A minimal program with not much to "recursively improve". + """ + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + double precision d(4) + d(2) = 5.5 +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +PROGRAM main + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_toplevel_subroutine(): + """ + A simple program with not much to "recursively improve", but this time the subroutine is defined outside and called + from the main program. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + double precision d(4) + call fun(d) +end program main + +subroutine fun(d) + implicit none + double precision d(4) + d(2) = 5.5 +end subroutine fun +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +PROGRAM main + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) +END PROGRAM main +SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 +END SUBROUTINE fun +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_standalone_subroutine(): + """ + A standalone subroutine, with no program or module in sight. + """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine fun(d) + implicit none + double precision d(4) + d(2) = 5.5 +end subroutine fun +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 +END SUBROUTINE fun +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_program_contains_subroutine(): + """ + The same simple program, but this time the subroutine is defined inside the main program that calls it. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + double precision d(4) + call fun(d) +contains + subroutine fun(d) + implicit none + double precision d(4) + d(2) = fun2() + end subroutine fun + real function fun2() + implicit none + fun2 = 5.5 + end function fun2 +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +PROGRAM main + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) + CONTAINS + SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = fun2() + END SUBROUTINE fun + REAL FUNCTION fun2() + IMPLICIT NONE + fun2 = 5.5 + END FUNCTION fun2 +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_subroutine_contains_function(): + """ + A function is defined inside a subroutine that calls it. There is no main program. + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none +contains + subroutine fun(d) + implicit none + double precision d(4) + d(2) = fun2() + + contains + real function fun2() + implicit none + fun2 = 5.5 + end function fun2 + end subroutine fun +end module lib +""").add_file(""" +program main + use lib, only: fun + implicit none + + double precision d(4) + call fun(d) +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + CONTAINS + SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = fun2() + CONTAINS + REAL FUNCTION fun2() + IMPLICIT NONE + fun2 = 5.5 + END FUNCTION fun2 + END SUBROUTINE fun +END MODULE lib +PROGRAM main + USE lib, ONLY: fun + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_program_contains_interface_block(): + """ + The program contains interface blocks. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + + ! We can have an interface with no name + interface + real function fun() + implicit none + end function fun + end interface + + ! We can even have multiple interfaces with no name + interface + real function fun2() + implicit none + end function fun2 + end interface + + double precision d(4) + d(2) = fun() +end program main + +real function fun() + implicit none + fun = 5.5 +end function fun +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +PROGRAM main + IMPLICIT NONE + INTERFACE + REAL FUNCTION fun() + IMPLICIT NONE + END FUNCTION fun + END INTERFACE + INTERFACE + REAL FUNCTION fun2() + IMPLICIT NONE + END FUNCTION fun2 + END INTERFACE + DOUBLE PRECISION :: d(4) + d(2) = fun() +END PROGRAM main +REAL FUNCTION fun() + IMPLICIT NONE + fun = 5.5 +END FUNCTION fun +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_program_contains_interface_block_with_useall(): + """ + A module contains interface block, that relies on an implementation provided by a top-level definitions. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + interface + real function fun() + implicit none + end function fun + end interface +contains + real function fun2() + fun2 = fun() + end function fun2 +end module lib +""").add_file(""" +program main + use lib + use lib, only: fun2 + implicit none + + double precision d(4) + d(2) = fun2() +end program main + +real function fun() + implicit none + fun = 5.5 +end function fun +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + INTERFACE + REAL FUNCTION fun() + IMPLICIT NONE + END FUNCTION fun + END INTERFACE + CONTAINS + REAL FUNCTION fun2() + fun2 = fun() + END FUNCTION fun2 +END MODULE lib +PROGRAM main + USE lib + USE lib, ONLY: fun2 + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = fun2() +END PROGRAM main +REAL FUNCTION fun() + IMPLICIT NONE + fun = 5.5 +END FUNCTION fun +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_uses_module(): + """ + A simple program, but this time the subroutine is defined in a module. The main program uses the module and calls + the subroutine. So, we should have "recursively improved" the AST by parsing that module and constructing the + dependency graph. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib +contains + subroutine fun(d) + implicit none + double precision d(4) + d(2) = 5.5 + end subroutine fun +end module lib +""").add_file(""" +program main + use lib + implicit none + double precision d(4) + call fun(d) +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +MODULE lib + CONTAINS + SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 + END SUBROUTINE fun +END MODULE lib +PROGRAM main + USE lib + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_uses_module_which_uses_module(): + """ + A simple program, but this time the subroutine is defined in a module. The main program uses the module and calls + the subroutine. So, we should have "recursively improved" the AST by parsing that module and constructing the + dependency graph. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib +contains + subroutine fun(d) + implicit none + double precision d(4) + d(2) = 5.5 + end subroutine fun +end module lib +""").add_file(""" +module lib_indirect + use lib +contains + subroutine fun_indirect(d) + implicit none + double precision d(4) + call fun(d) + end subroutine fun_indirect +end module lib_indirect +""").add_file(""" +program main + use lib_indirect, only: fun_indirect + implicit none + double precision d(4) + call fun_indirect(d) +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +MODULE lib + CONTAINS + SUBROUTINE fun(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = 5.5 + END SUBROUTINE fun +END MODULE lib +MODULE lib_indirect + USE lib + CONTAINS + SUBROUTINE fun_indirect(d) + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun(d) + END SUBROUTINE fun_indirect +END MODULE lib_indirect +PROGRAM main + USE lib_indirect, ONLY: fun_indirect + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + CALL fun_indirect(d) +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_interface_block_contains_module_procedure(): + """ + The same simple program, but this time the subroutine is defined inside the main program that calls it. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib +implicit none +contains + real function fun() + implicit none + fun = 5.5 + end function fun +end module lib +""").add_file(""" +program main + use lib + implicit none + + interface xi + module procedure fun + end interface xi + + double precision d(4) + d(2) = fun() +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + CONTAINS + REAL FUNCTION fun() + IMPLICIT NONE + fun = 5.5 + END FUNCTION fun +END MODULE lib +PROGRAM main + USE lib + IMPLICIT NONE + INTERFACE xi + MODULE PROCEDURE fun + END INTERFACE xi + DOUBLE PRECISION :: d(4) + d(2) = fun() +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_module_contains_interface_block(): + """ + The same simple program, but this time the subroutine is defined inside the main program that calls it. + """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none +contains + real function fun() + implicit none + fun = 5.5 + end function fun +end module lib +""").add_file(""" +module lib_indirect + use lib, only: fun + implicit none + interface xi + module procedure fun + end interface xi + +contains + real function fun2() + implicit none + fun2 = 4.2 + end function fun2 +end module lib_indirect +""").add_file(""" +program main + use lib_indirect, only : fun, fun2 + implicit none + + double precision d(4) + d(2) = fun() +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + CONTAINS + REAL FUNCTION fun() + IMPLICIT NONE + fun = 5.5 + END FUNCTION fun +END MODULE lib +MODULE lib_indirect + USE lib, ONLY: fun + IMPLICIT NONE + INTERFACE xi + MODULE PROCEDURE fun + END INTERFACE xi + CONTAINS + REAL FUNCTION fun2() + IMPLICIT NONE + fun2 = 4.2 + END FUNCTION fun2 
+END MODULE lib_indirect +PROGRAM main + USE lib_indirect, ONLY: fun, fun2 + IMPLICIT NONE + DOUBLE PRECISION :: d(4) + d(2) = fun() +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_program_contains_type(): + """ + A function is defined inside a subroutine that calls it. A main program uses the top-level subroutine. + """ + sources, main = SourceCodeBuilder().add_file(""" +program main + implicit none + type simple_type + real :: w(5, 5, 5), z(5) + integer :: a + real :: name + end type simple_type + + real :: d(5, 5) + call type_test_function(d) + +contains + + subroutine type_test_function(d) + real d(5, 5) + type(simple_type) :: s + s%w(1, 1, 1) = 5.5 + d(2, 1) = 5.5 + s%w(1, 1, 1) + end subroutine type_test_function +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +PROGRAM main + IMPLICIT NONE + TYPE :: simple_type + REAL :: w(5, 5, 5), z(5) + INTEGER :: a + REAL :: name + END TYPE simple_type + REAL :: d(5, 5) + CALL type_test_function(d) + CONTAINS + SUBROUTINE type_test_function(d) + REAL :: d(5, 5) + TYPE(simple_type) :: s + s % w(1, 1, 1) = 5.5 + d(2, 1) = 5.5 + s % w(1, 1, 1) + END SUBROUTINE type_test_function +END PROGRAM main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_floaters_are_brought_in(): + """ + The same simple program, but this time the subroutine is defined inside the main program that calls it. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +subroutine fun(z) + implicit none + real, intent(out) :: z + z = 5.5 +end subroutine fun +""", 'floater').add_file(""" +program main + implicit none + + interface + subroutine fun(z) + implicit none + real, intent(out) :: z + end subroutine fun + end interface + + real d(4) + call fun(d(2)) +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +PROGRAM main + IMPLICIT NONE + INTERFACE + SUBROUTINE fun(z) + IMPLICIT NONE + REAL, INTENT(OUT) :: z + END SUBROUTINE fun + END INTERFACE + REAL :: d(4) + CALL fun(d(2)) +END PROGRAM main +SUBROUTINE fun(z) + IMPLICIT NONE + REAL, INTENT(OUT) :: z + z = 5.5 +END SUBROUTINE fun +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + +def test_floaters_can_bring_in_more_modules(): + """ + The same simple program, but this time the subroutine is defined inside the main program that calls it. 
+ """ + sources, main = SourceCodeBuilder().add_file(""" +module lib + implicit none + real, parameter :: zzz = 5.5 +end module lib +subroutine fun(z) + use lib + implicit none + real, intent(out) :: z + z = zzz +end subroutine fun +""", 'floater').add_file(""" +program main + implicit none + + interface + subroutine fun(z) + implicit none + real, intent(out) :: z + end subroutine fun + end interface + + real d(4) + call fun(d(2)) +end program main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + + got = ast.tofortran() + want = """ +MODULE lib + IMPLICIT NONE + REAL, PARAMETER :: zzz = 5.5 +END MODULE lib +PROGRAM main + IMPLICIT NONE + INTERFACE + SUBROUTINE fun(z) + IMPLICIT NONE + REAL, INTENT(OUT) :: z + END SUBROUTINE fun + END INTERFACE + REAL :: d(4) + CALL fun(d(2)) +END PROGRAM main +SUBROUTINE fun(z) + USE lib + IMPLICIT NONE + REAL, INTENT(OUT) :: z + z = zzz +END SUBROUTINE fun +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() diff --git a/tests/fortran/rename_test.py b/tests/fortran/rename_test.py new file mode 100644 index 0000000000..aa1576cc5b --- /dev/null +++ b/tests/fortran/rename_test.py @@ -0,0 +1,70 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil +from dace.sdfg.nodes import AccessNode + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +def test_fortran_frontend_rename(): + """ + Tests that the Fortran frontend can parse complex initializations. + """ + test_string = """ + PROGRAM rename_test + implicit none + USE rename_test_module_subroutine, ONLY: rename_test_function + double precision d(4) + CALL rename_test_function(d) + end + + + """ + sources={} + sources["rename_test"]=test_string + sources["rename_test_module_subroutine.f90"]=""" + MODULE rename_test_module_subroutine + CONTAINS + SUBROUTINE rename_test_function(d) + USE rename_test_module, ONLY: ik4=>i4 + integer(ik4) :: i + + i=4 + d(2)=5.5 +i + + END SUBROUTINE rename_test_function + END MODULE rename_test_module_subroutine + """ + sources["rename_test_module.f90"]=""" + MODULE rename_test_module + IMPLICIT NONE + INTEGER, PARAMETER :: pi4 = 9 + INTEGER, PARAMETER :: i4 = SELECTED_INT_KIND(pi4) + END MODULE rename_test_module + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "rename_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=a) + assert (a[0] == 42) + assert (a[1] == 9.5) + assert (a[2] == 42) + + + +if __name__ == "__main__": + + test_fortran_frontend_rename() diff --git a/tests/fortran/scope_arrays_test.py 
b/tests/fortran/scope_arrays_test.py index 0eb0cf44b2..5dd5b806a8 100644 --- a/tests/fortran/scope_arrays_test.py +++ b/tests/fortran/scope_arrays_test.py @@ -30,7 +30,7 @@ def test_fortran_frontend_parent(): ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test") ast_transforms.ParentScopeAssigner().visit(ast) - visitor = ast_transforms.ScopeVarsDeclarations() + visitor = ast_transforms.ScopeVarsDeclarations(ast) visitor.visit(ast) for var in ['d', 'arr', 'arr3']: diff --git a/tests/fortran/struct_test.py b/tests/fortran/struct_test.py new file mode 100644 index 0000000000..93606f1964 --- /dev/null +++ b/tests/fortran/struct_test.py @@ -0,0 +1,115 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + +def test_fortran_struct(): + test_string = """ + PROGRAM struct_test_range + implicit none + + type test_type + integer :: start + integer :: end + end type + + integer, dimension(6) :: res + integer :: startidx + integer :: endidx + CALL struct_test_range_test_function(res, startidx, endidx) + end + + SUBROUTINE struct_test_range_test_function(res, startidx, endidx) + integer, dimension(6) :: res + integer :: startidx + integer :: endidx + type(test_type) :: indices + + indices%start=startidx + indices%end=endidx + + CALL struct_test_range2_test_function(res, indices) + + END SUBROUTINE struct_test_range_test_function + + SUBROUTINE struct_test_range2_test_function(res, idx) + integer, dimension(6) :: res + type(test_type) :: idx + + res(idx%start:idx%end) = 42 + + END SUBROUTINE struct_test_range2_test_function + """ + sources={} + sdfg = fortran_parser.create_sdfg_from_string(test_string, "res", False, sources=sources) + sdfg.save('before.sdfg') + sdfg.simplify(verbose=True) + sdfg.save('after.sdfg') + sdfg.compile() + + size = 6 + res = np.full([size], 42, order="F", dtype=np.int32) + res[:] = 0 + sdfg(res=res, 
start=2, end=5) + print(res) + +def test_fortran_struct_lhs(): + test_string = """ + PROGRAM struct_test_range + implicit none + + type test_type + integer, dimension(6) :: res + integer :: start + integer :: end + end type + + type test_type2 + type(test_type) :: var + end type + + integer, dimension(6) :: res + integer :: start + integer :: end + CALL struct_test_range_test_function(res, start, end) + end + + SUBROUTINE struct_test_range_test_function(res, start, end) + integer, dimension(6) :: res + integer :: start + integer :: end + type(test_type) :: indices + type(test_type2) :: val + + indices = test_type(res, start, end) + val = test_type2(indices) + + CALL struct_test_range2_test_function(val) + + END SUBROUTINE struct_test_range_test_function + + SUBROUTINE struct_test_range2_test_function(idx) + type(test_type2) :: idx + + idx%var%res(idx%var%start:idx%var%end) = 42 + + END SUBROUTINE struct_test_range2_test_function + """ + sources={} + sdfg = fortran_parser.create_sdfg_from_string(test_string, "res", False, sources=sources) + sdfg.save('before.sdfg') + sdfg.simplify(verbose=True) + sdfg.save('after.sdfg') + sdfg.compile() + + size = 6 + res = np.full([size], 42, order="F", dtype=np.int32) + res[:] = 0 + sdfg(res=res, start=2, end=5) + print(res) + +if __name__ == "__main__": + test_fortran_struct() + test_fortran_struct_lhs() diff --git a/tests/fortran/tasklet_test.py b/tests/fortran/tasklet_test.py new file mode 100644 index 0000000000..263c49b922 --- /dev/null +++ b/tests/fortran/tasklet_test.py @@ -0,0 +1,47 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + +def test_fortran_frontend_tasklet(): + test_string = """ + PROGRAM tasklet + implicit none + real, dimension(2) :: d + real, dimension(2) :: res + CALL tasklet_test_function(d,res) + end + + SUBROUTINE tasklet_test_function(d,res) + real, dimension(2) :: d + real, dimension(2) :: res + real :: temp + + + integer :: i + i=1 + temp = 88 + d(1)=d(1)*2 + temp = MIN(d(i), temp) + res(1) = temp + 10 + + END SUBROUTINE tasklet_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "tasklet", normalize_offsets=True) + sdfg.view() + sdfg.simplify(verbose=True) + + sdfg.compile() + + input = np.full([2], 42, order="F", dtype=np.float32) + res = np.full([2], 42, order="F", dtype=np.float32) + sdfg(d=input, res=res) + assert np.allclose(res, [94, 42]) + + +if __name__ == "__main__": + + test_fortran_frontend_tasklet() diff --git a/tests/fortran/type_array_test.py b/tests/fortran/type_array_test.py new file mode 100644 index 0000000000..e54846c096 --- /dev/null +++ b/tests/fortran/type_array_test.py @@ -0,0 +1,224 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + +from dace.transformation.passes.lift_struct_views import LiftStructViews +from dace.transformation import pass_pipeline as ppl + +def test_fortran_frontend_type_array(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. + """ + test_string = """ + PROGRAM type_array_test + implicit none + + + TYPE simple_type + REAL,POINTER :: w(5,5) + END TYPE simple_type + + TYPE simple_type2 + type(simple_type) :: pprog(10) + END TYPE simple_type2 + + REAL :: d(5,5) + CALL type_array_test_function(d) + print *, d(1,1) + end + + SUBROUTINE type_array_test_function(d) + REAL :: d(5,5) + TYPE(simple_type2) :: p_prog + + CALL type_array_test_f2(p_prog%pprog(1)) + d(1,1) = p_prog%pprog(1)%w(1,1) + END SUBROUTINE type_array_test_function + + SUBROUTINE type_array_test_f2(stuff) + TYPE(simple_type) :: stuff + CALL deepest(stuff%w) + + END SUBROUTINE type_array_test_f2 + + SUBROUTINE deepest(my_arr) + REAL :: my_arr(:,:) + + my_arr(1,1) = 42 + END SUBROUTINE deepest + + """ + sources={} + sources["type_array_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_array_test",sources=sources, normalize_offsets=True) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + 
print(a) + +def test_fortran_frontend_type2_array(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. + """ + test_string = """ + PROGRAM type2_array_test + implicit none + + + TYPE simple_type + REAL,ALLOCATABLE :: w(:,:) + END TYPE simple_type + + TYPE simple_type2 + type(simple_type) :: pprog + END TYPE simple_type2 + + TYPE(simple_type2) :: p + REAL :: d(5,5) + CALL type2_array_test_function(d,p) + print *, d(1,1) + end + + SUBROUTINE type2_array_test_function(d,p_prog) + REAL :: d(5,5) + TYPE(simple_type2) :: p_prog + + CALL type2_array_test_f2(d,p_prog) + + END SUBROUTINE type2_array_test_function + + SUBROUTINE type2_array_test_f2(d,stuff) + TYPE(simple_type2) :: stuff + REAL :: d(5,5) + CALL deepest(stuff,d) + + END SUBROUTINE type2_array_test_f2 + + SUBROUTINE deepest(my_arr,d) + REAL :: d(5,5) + TYPE(simple_type2) :: my_arr + REAL, DIMENSION(:,:), POINTER, CONTIGUOUS :: my_arr2 + + + my_arr2=>my_arr%pprog%w + + d(1,1)=my_arr2(1,1) + END SUBROUTINE deepest + + """ + sources={} + sources["type2_array_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type2_array_test",sources=sources, normalize_offsets=True) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + print(a) + +def test_fortran_frontend_type3_array(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type3_array_test + implicit none + + + TYPE simple_type + REAL,ALLOCATABLE :: w(:,:) + END TYPE simple_type + + + TYPE bla_type + REAL,ALLOCATABLE :: a + END TYPE bla_type + + TYPE metrics_type + REAL,ALLOCATABLE :: b + END TYPE metrics_type + + TYPE simple_type2 + type(simple_type) :: pprog + type(bla_type) :: diag + type(metrics_type):: metrics + END TYPE simple_type2 + + TYPE(simple_type2) :: p + REAL :: d(5,5) + CALL type3_array_test_function(d,p) + print *, d(1,1) + end + + SUBROUTINE type3_array_test_function(d,p_prog) + REAL :: d(5,5) + TYPE(simple_type2) :: p_prog + integer :: istep + istep=1 + + DO istep=1,2 + if (istep==1) then + CALL type2_array_test_f2(d,p_prog%pprog, p_prog%diag, p_prog%metrics,istep) + else + CALL type2_array_test_f2(d,p_prog%pprog, p_prog%diag, p_prog%metrics,istep) + endif + ENDDO + + END SUBROUTINE type3_array_test_function + + SUBROUTINE type2_array_test_f2(d,stuff,diag,metrics,istep) + TYPE(simple_type) :: stuff + TYPE(bla_type) :: diag + TYPE(metrics_type) :: metrics + INTEGER :: istep + REAL :: d(5,5) + diag%a=1 + metrics%b=2 + d(1,1)=stuff%w(1,1)+diag%a+metrics%b + if (istep==1) then + CALL deepest(stuff,d) + endif + + END SUBROUTINE type2_array_test_f2 + + SUBROUTINE deepest(my_arr,d) + REAL :: d(5,5) + TYPE(simple_type) :: my_arr + REAL, DIMENSION(:,:), POINTER, CONTIGUOUS :: my_arr2 + + + my_arr2=>my_arr%w + + d(1,1)=my_arr2(1,1) + END SUBROUTINE deepest + + """ + sources={} + sources["type3_array_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type3_array_test",sources=sources, normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + #a = np.full([5, 5], 42, order="F", dtype=np.float32) + #sdfg(d=a) + #print(a) + + + + +if __name__ == "__main__": + + #test_fortran_frontend_type_array() + test_fortran_frontend_type3_array() diff --git a/tests/fortran/type_test.py b/tests/fortran/type_test.py new file mode 100644 index 0000000000..6cdc0c06b4 
--- /dev/null +++ b/tests/fortran/type_test.py @@ -0,0 +1,658 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + +from dace.transformation.passes.lift_struct_views import LiftStructViews +from dace.transformation import pass_pipeline as ppl + + +def test_fortran_frontend_basic_type(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. + """ + test_string = """ + PROGRAM type_test + implicit none + + TYPE simple_type + REAL :: w(5,5,5), z(5) + INTEGER :: a + REAL :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL type_test_function(d) + end + + SUBROUTINE type_test_function(d) + REAL d(5,5) + TYPE(simple_type) :: s + s%w(1,1,1) = 5.5 + d(2,1) = 5.5 + s%w(1,1,1) + END SUBROUTINE type_test_function + """ + sources={} + sources["type_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + + + +def test_fortran_frontend_basic_type2(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type2_test + implicit none + + TYPE simple_type + REAL:: w(5,5,5),z(5) + INTEGER:: a + END TYPE simple_type + + TYPE comlex_type + TYPE(simple_type):: s + REAL:: b + END TYPE comlex_type + + TYPE meta_type + TYPE(comlex_type):: cc + REAL:: omega + END TYPE meta_type + + REAL :: d(5,5) + CALL type2_test_function(d) + end + + SUBROUTINE type2_test_function(d) + REAL d(5,5) + TYPE(simple_type) :: s(3) + TYPE(comlex_type) :: c + TYPE(meta_type) :: m + + c%b=1.0 + c%s%w(1,1,1)=5.5 + m%cc%s%a=17 + s(1)%w(1,1,1)=5.5+c%b + d(2,1)=c%s%w(1,1,1)+s(1)%w(1,1,1) + + END SUBROUTINE type2_test_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type2_test") + sdfg.validate() + sdfg.simplify(verbose=True) + a = np.full([4, 5], 42, order="F", dtype=np.float64) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + + +def test_fortran_frontend_type_symbol(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_symbol_test + implicit none + + TYPE simple_type + REAL:: z(5) + INTEGER:: a + END TYPE simple_type + + + REAL :: d(5,5) + CALL type_symbol_test_function(d) + end + + SUBROUTINE type_symbol_test_function(d) + TYPE(simple_type) :: st + REAL :: d(5,5) + st%a=10 + CALL internal_function(d,st) + + END SUBROUTINE type_symbol_test_function + + + SUBROUTINE internal_function(d,st) + REAL d(5,5) + TYPE(simple_type) :: st + REAL bob(st%a) + bob(1)=5.5 + d(2,1)=2*bob(1) + + END SUBROUTINE internal_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_symbol_test",sources={"type_symbol_test":test_string}) + sdfg.validate() + sdfg.simplify(verbose=True) + a = np.full([4, 5], 42, order="F", dtype=np.float64) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + +def test_fortran_frontend_type_pardecl(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_pardecl_test + implicit none + + TYPE simple_type + REAL:: z(5,5,5) + INTEGER:: a + END TYPE simple_type + + + REAL :: d(5,5) + CALL type_pardecl_test_function(d) + end + + SUBROUTINE type_pardecl_test_function(d) + TYPE(simple_type) :: st + REAL :: d(5,5) + st%a=10 + CALL internal_function(d,st) + + END SUBROUTINE type_pardecl_test_function + + + SUBROUTINE internal_function(d,st) + REAL d(5,5) + TYPE(simple_type) :: st + REAL bob(st%a) + INTEGER, PARAMETER :: n=5 + REAL BOB2(n) + bob(1)=5.5 + bob2(1)=5.5 + st%z(1,:,2:3)=bob(1) + d(2,1)=bob(1)+bob2 + + END SUBROUTINE internal_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_pardecl_test",sources={"type_pardecl_test":test_string}) + sdfg.validate() + sdfg.simplify(verbose=True) + a = np.full([4, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + +def test_fortran_frontend_type_struct(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_struct_test + implicit none + + TYPE simple_type + REAL:: z(5,5,5) + INTEGER:: a + REAL :: unkown(:) + !INTEGER :: unkown_size + END TYPE simple_type + + + REAL :: d(5,5) + CALL type_struct_test_function(d) + end + + SUBROUTINE type_struct_test_function(d) + TYPE(simple_type) :: st + REAL :: d(5,5) + st%a=10 + CALL internal_function(d,st) + + END SUBROUTINE type_struct_test_function + + + SUBROUTINE internal_function(d,st) + st.a.shape=[st.a_size] + REAL d(5,5) + TYPE(simple_type) :: st + REAL bob(st%a) + INTEGER, PARAMETER :: n=5 + REAL BOB2(n) + bob(1)=5.5 + bob2(1)=5.5 + st%z(1,:,2:3)=bob(1) + d(2,1)=bob(1)+bob2(1) + + END SUBROUTINE internal_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_struct_test",sources={"type_struct_test":test_string}) + sdfg.validate() + sdfg.simplify(verbose=True) + a = np.full([4, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + + +def test_fortran_frontend_circular_type(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_test + implicit none + + + type a_t + real :: w(5,5,5) + type(b_t), pointer :: b + end type a_t + + type b_t + type(a_t) :: a + integer :: x + end type b_t + + type c_t + type(d_t),pointer :: ab + integer :: xz + end type c_t + + type d_t + type(c_t) :: ac + integer :: xy + end type d_t + + REAL :: d(5,5) + + CALL circular_type_test_function(d) + end + + SUBROUTINE circular_type_test_function(d) + REAL d(5,5) + TYPE(a_t) :: s + TYPE(b_t) :: b(3) + + s%w(1,1,1)=5.5 + !s%b=>b(1) + !s%b%a=>s + b(1)%x=1 + d(2,1)=5.5+s%w(1,1,1) + + END SUBROUTINE circular_type_test_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_test") + sdfg.simplify(verbose=True) + a = np.full([4, 5], 42, order="F", dtype=np.float64) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + + + +def test_fortran_frontend_type_in_call(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. + """ + test_string = """ + PROGRAM type_in_call_test + implicit none + + TYPE simple_type + REAL :: w(5,5,5), z(5) + INTEGER :: a + REAL :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL type_in_call_test_function(d) + end + + SUBROUTINE type_in_call_test_function(d) + REAL d(5,5) + TYPE(simple_type) :: s + REAL,POINTER :: tmp(:,:,:) + tmp=>s%w + tmp(1,1,1) = 11.0 + d(2,1) = max(1.0, tmp(1,1,1)) + END SUBROUTINE type_in_call_test_function + """ + sources={} + sources["type_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + +def test_fortran_frontend_type_array(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_in_call_test + implicit none + + TYPE simple_type3 + INTEGER :: a + END TYPE simple_type3 + + TYPE simple_type2 + type(simple_type3) :: w(7:12,8:13) + END TYPE simple_type2 + + TYPE simple_type + type(simple_type2) :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL type_in_call_test_function(d) + end + + SUBROUTINE type_in_call_test_function(d) + REAL :: d(5,5) + TYPE(simple_type) :: s + + CALL type_in_call_test_function2(s) + d(1,1) = s%name%w(8,10)%a + END SUBROUTINE type_in_call_test_function + + SUBROUTINE type_in_call_test_function2(s) + TYPE(simple_type) :: s + + s%name%w(8,10)%a = 42 + END SUBROUTINE type_in_call_test_function2 + """ + sources={} + sources["type_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources, normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.save('test.sdfg') + sdfg.compile() + + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + print(a) + +def test_fortran_frontend_type_array2(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_in_call_test + implicit none + + TYPE simple_type3 + INTEGER :: a + END TYPE simple_type3 + + TYPE simple_type2 + type(simple_type3) :: w(7:12,8:13) + integer :: wx(7:12,8:13) + END TYPE simple_type2 + + TYPE simple_type + type(simple_type2) :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL type_in_call_test_function(d) + end + + SUBROUTINE type_in_call_test_function(d) + REAL :: d(5,5) + integer :: x(3,3,3) + TYPE(simple_type) :: s + + CALL type_in_call_test_function2(s,x) + !d(1,1) = s%name%w(8, x(3,3,3))%a + d(1,2) = s%name%wx(8, x(3,3,3)) + END SUBROUTINE type_in_call_test_function + + SUBROUTINE type_in_call_test_function2(s,x) + TYPE(simple_type) :: s + integer :: x(3,3,3) + + x(3,3,3) = 10 + !s%name%w(8,x(3,3,3))%a = 42 + s%name%wx(8,x(3,3,3)) = 43 + END SUBROUTINE type_in_call_test_function2 + """ + sources={} + sources["type_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources, normalize_offsets=True) + sdfg.save("before.sdfg") + sdfg.simplify(verbose=True) + sdfg.save("after.sdfg") + sdfg.compile() + + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + print(a) + +def test_fortran_frontend_type_pointer(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_pointer_test + implicit none + + TYPE simple_type + REAL :: w(5,5,5), z(5) + INTEGER :: a + REAL :: name + END TYPE simple_type + + REAL :: d(5,5) + CALL type_pointer_test_function(d) + end + + SUBROUTINE type_pointer_test_function(d) + REAL d(5,5) + TYPE(simple_type) :: s + REAL, DIMENSION(:,:,:), POINTER :: tmp + tmp=>s%w + tmp(1,1,1) = 11.0 + d(2,1) = max(1.0, tmp(1,1,1)) + END SUBROUTINE type_pointer_test_function + """ + sources={} + sources["type_pointer_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_pointer_test",sources=sources) + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + + +def test_fortran_frontend_type_arg(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. + """ + test_string = """ + PROGRAM type_arg_test + implicit none + + + TYPE simple_type + REAL, POINTER, CONTIGUOUS :: w(:,:) + END TYPE simple_type + + TYPE simple_type2 + type(simple_type), allocatable :: pprog(:) + END TYPE simple_type2 + + REAL :: d(5,5) + CALL type_arg_test_function(d) + print *, d(1,1) + end + + SUBROUTINE type_arg_test_function(d) + REAL :: d(5,5) + TYPE(simple_type2) :: p_prog + + CALL type_arg_test_f2(p_prog%pprog(1)) + d(1,1) = p_prog%pprog(1)%w(1,1) + END SUBROUTINE type_arg_test_function + + SUBROUTINE type_arg_test_f2(stuff) + TYPE(simple_type) :: stuff + CALL deepest(stuff%w) + + END SUBROUTINE type_arg_test_f2 + + SUBROUTINE deepest(my_arr) + REAL :: my_arr(:,:) + + my_arr(1,1) = 42 + END SUBROUTINE deepest + + """ + sources={} + sources["type_arg_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_arg_test",sources=sources, normalize_offsets=True) + sdfg.view() + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + 
print(a) + + + +def test_fortran_frontend_type_arg2(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. + """ + test_string = """ + PROGRAM type_arg2_test + implicit none + + + TYPE simple_type + REAL :: w(5,5) + END TYPE simple_type + + TYPE simple_type2 + type(simple_type) :: pprog(10) + END TYPE simple_type2 + + REAL :: d(5,5) + CALL type_arg2_test_function(d) + print *, d(1,1) + end + + SUBROUTINE type_arg2_test_function(d) + REAL :: d(5,5) + TYPE(simple_type2) :: p_prog + integer :: i + i=1 + !p_prog%pprog(1)%w(1,1) = 5.5 + CALL deepest(p_prog%pprog(i)%w,d) + + END SUBROUTINE type_arg2_test_function + + SUBROUTINE deepest(my_arr,d) + REAL :: my_arr(:,:) + REAL :: d(5,5) + my_arr(1,1) = 5.5 + d(1,1) = my_arr(1,1) + END SUBROUTINE deepest + + """ + sources={} + sources["type_arg2_test"]=test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_arg2_test",sources=sources, normalize_offsets=True) + sdfg.save("before.sdfg") + sdfg.simplify(verbose=True) + a = np.full([5, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + print(a) + + + +def test_fortran_frontend_type_view(): + """ + Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
+ """ + test_string = """ + PROGRAM type_view_test + implicit none + + TYPE simple_type + REAL:: z(:,:) + INTEGER:: a + END TYPE simple_type + TYPE(simple_type) :: st + + REAL :: d(5,5) + CALL type_view_test_function(d,st) + end + + SUBROUTINE type_view_test_function(d,st) + TYPE(simple_type) :: st + REAL :: d(5,5) + st%z(1,1)=5.5 + CALL internal_function(d,st%z) + + END SUBROUTINE type_view_test_function + + + SUBROUTINE internal_function(d,sta) + REAL d(5,5) + REAL sta(:,:) + d(2,1)=2*sta(1,1) + + END SUBROUTINE internal_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_view_test",sources={"type_view_test":test_string},normalize_offsets=True) + sdfg.validate() + sdfg.simplify(verbose=True) + a = np.full([4, 5], 42, order="F", dtype=np.float32) + sdfg(d=a) + assert (a[0, 0] == 42) + assert (a[1, 0] == 11) + assert (a[2, 0] == 42) + + +if __name__ == "__main__": + #test_fortran_frontend_basic_type() + #test_fortran_frontend_basic_type2() + #test_fortran_frontend_type_symbol() + #test_fortran_frontend_type_pardecl() + #test_fortran_frontend_type_struct() + #test_fortran_frontend_circular_type() + #test_fortran_frontend_type_in_call() + #test_fortran_frontend_type_array() + #test_fortran_frontend_type_array2() + #test_fortran_frontend_type_pointer() + #test_fortran_frontend_type_arg() + #test_fortran_frontend_type_view() + test_fortran_frontend_type_arg2() diff --git a/tests/fortran/while_test.py b/tests/fortran/while_test.py new file mode 100644 index 0000000000..96a43efef7 --- /dev/null +++ b/tests/fortran/while_test.py @@ -0,0 +1,45 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + +def test_fortran_frontend_while(): + test_string = """ + PROGRAM while + implicit none + real, dimension(2) :: d + real, dimension(2) :: res + CALL while_test_function(d,res) + end + + SUBROUTINE while_test_function(d,res) + real, dimension(2) :: d + real, dimension(2) :: res + + + integer :: i + i=0 + res(1)=d(1)*2 + do while (i<10) + res(1)=res(1)+1 + i=i+1 + end do + + END SUBROUTINE while_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "while", normalize_offsets=True) + sdfg.simplify(verbose=True) + sdfg.compile() + + input = np.full([2], 42, order="F", dtype=np.float32) + res = np.full([2], 42, order="F", dtype=np.float32) + sdfg(d=input, res=res) + assert np.allclose(res, [94, 42]) + + +if __name__ == "__main__": + + test_fortran_frontend_while() From 6b4fa3a8929c7ad2a8ecd8548678d37aaf8dfa55 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 20 Dec 2024 10:16:48 +0100 Subject: [PATCH 03/12] Formatting and Python 3.7 typing conforming --- dace/frontend/fortran/ast_components.py | 136 +-- dace/frontend/fortran/ast_desugaring.py | 158 ++-- dace/frontend/fortran/ast_internal_classes.py | 283 +++--- dace/frontend/fortran/ast_transforms.py | 860 +++++++++--------- dace/frontend/fortran/ast_utils.py | 38 +- dace/frontend/fortran/fortran_parser.py | 545 ++++++----- .../fortran/icon_config_propagation.py | 17 +- dace/frontend/fortran/intrinsics.py | 397 ++++---- tests/fortran/prune_unused_children_test.py | 9 +- 9 files changed, 1291 insertions(+), 1152 deletions(-) diff --git a/dace/frontend/fortran/ast_components.py b/dace/frontend/fortran/ast_components.py index 294f490d39..49a28a3650 100644 --- a/dace/frontend/fortran/ast_components.py +++ b/dace/frontend/fortran/ast_components.py @@ -546,7 +546,8 @@ def function_reference(self, node: Function_Reference): line = get_line(node) return ast_internal_classes.Call_Expr_Node(name=name, 
args=args.args if args else [], - type="VOID", subroutine=False, + type="VOID", + subroutine=False, line_number=line) def end_associate_stmt(self, node: FASTNode): @@ -621,7 +622,8 @@ def derived_type_def(self, node: FASTNode): else: new_placeholder_offsets[k] = self.placeholders_offsets[k] self.placeholders_offsets = new_placeholder_offsets - return ast_internal_classes.Derived_Type_Def_Node(name=name, component_part=component_part, + return ast_internal_classes.Derived_Type_Def_Node(name=name, + component_part=component_part, procedure_part=procedure_part) def derived_type_stmt(self, node: FASTNode): @@ -653,7 +655,7 @@ def write_stmt(self, node: FASTNode): # if node.children[0] is not None: # children = self.create_children(node.children[0]) # if node.children[1] is not None: - # children = self.create_children(node.children[1]) + # children = self.create_children(node.children[1]) line = get_line(node) return ast_internal_classes.Write_Stmt_Node(args=node.string, line_number=line) @@ -714,7 +716,6 @@ def subroutine_subprogram(self, node: FASTNode): type=return_type, line_number=name.line_number, elemental=name.elemental, - ) def end_program_stmt(self, node: FASTNode): @@ -756,10 +757,9 @@ def function_subprogram(self, node: Function_Subprogram): return_type: str = name.type if name.type == 'VOID': assert specification_part - var_decls: List[Var_Decl_Node] = [v - for c in specification_part.specifications if - isinstance(c, Decl_Stmt_Node) - for v in c.vardecl] + var_decls: List[Var_Decl_Node] = [ + v for c in specification_part.specifications if isinstance(c, Decl_Stmt_Node) for v in c.vardecl + ] return_type = singular(v.type for v in var_decls if v.name == return_var.name) return ast_internal_classes.Function_Subprogram_Node( @@ -785,8 +785,12 @@ def function_stmt(self, node: Function_Stmt): ret = get_child(children, ast_internal_classes.Suffix_Node) ret_args = args.args if args else [] - return ast_internal_classes.Function_Stmt_Node( - name=name, 
args=ret_args, line_number=node.item.span, ret=ret, elemental=elemental, type=ret) + return ast_internal_classes.Function_Stmt_Node(name=name, + args=ret_args, + line_number=node.item.span, + ret=ret, + elemental=elemental, + type=ret) def subroutine_stmt(self, node: FASTNode): # print(self.name_list) @@ -801,7 +805,9 @@ def subroutine_stmt(self, node: FASTNode): ret_args = [] else: ret_args = args.args - return ast_internal_classes.Subroutine_Stmt_Node(name=name, args=ret_args, line_number=node.item.span, + return ast_internal_classes.Subroutine_Stmt_Node(name=name, + args=ret_args, + line_number=node.item.span, elemental=elemental) def ac_value_list(self, node: FASTNode): @@ -815,7 +821,9 @@ def power_expr(self, node: FASTNode): # child 1 is "**" return ast_internal_classes.Call_Expr_Node(name=ast_internal_classes.Name_Node(name="pow"), args=[children[0], children[2]], - line_number=line, type="REAL", subroutine=False) + line_number=line, + type="REAL", + subroutine=False) def array_constructor(self, node: FASTNode): children = self.create_children(node) @@ -1012,18 +1020,15 @@ def parse_shape_specification(self, dim: f03.Explicit_Shape_Spec, size: List[FAS expr = self.create_ast(dim_expr[0]) offset.append(expr) - fortran_size = ast_internal_classes.BinOp_Node( - lval=self.create_ast(dim_expr[1]), - rval=self.create_ast(dim_expr[0]), - op="-", - type="INTEGER" - ) - size.append(ast_internal_classes.BinOp_Node( - lval=fortran_size, - rval=ast_internal_classes.Int_Literal_Node(value=str(1)), - op="+", - type="INTEGER") - ) + fortran_size = ast_internal_classes.BinOp_Node(lval=self.create_ast(dim_expr[1]), + rval=self.create_ast(dim_expr[0]), + op="-", + type="INTEGER") + size.append( + ast_internal_classes.BinOp_Node(lval=fortran_size, + rval=ast_internal_classes.Int_Literal_Node(value=str(1)), + op="+", + type="INTEGER")) else: raise TypeError("Array dimension must be at most two expressions") @@ -1234,8 +1239,8 @@ def type_declaration_stmt(self, node: FASTNode): 
attr_size = [attr_size] * len(names) attr_offset = [attr_offset] * len(names) else: - attr_size, assumed_vardecls, attr_offset = self.assumed_array_shape(dimension_spec[0], names, - node.item.span) + attr_size, assumed_vardecls, attr_offset = self.assumed_array_shape( + dimension_spec[0], names, node.item.span) if attr_size is None: raise RuntimeError("Couldn't parse the dimension attribute specification!") @@ -1266,8 +1271,8 @@ def type_declaration_stmt(self, node: FASTNode): attr_size = [attr_size] * len(names) attr_offset = [attr_offset] * len(names) else: - attr_size, assumed_vardecls, attr_offset = self.assumed_array_shape(dimension_spec[0], names, - node.item.span) + attr_size, assumed_vardecls, attr_offset = self.assumed_array_shape( + dimension_spec[0], names, node.item.span) if attr_size is None: raise RuntimeError("Couldn't parse the dimension attribute specification!") @@ -1491,10 +1496,15 @@ def level_2_expr(self, node: FASTNode): if hasattr(children[0], "type"): type = children[0].type if len(children) == 3: - return ast_internal_classes.BinOp_Node(lval=children[0], op=children[1], rval=children[2], line_number=line, + return ast_internal_classes.BinOp_Node(lval=children[0], + op=children[1], + rval=children[2], + line_number=line, type=type) else: - return ast_internal_classes.UnOp_Node(lval=children[1], op=children[0], line_number=line, + return ast_internal_classes.UnOp_Node(lval=children[1], + op=children[0], + line_number=line, type=children[1].type) def assignment_stmt(self, node: Assignment_Stmt): @@ -1502,10 +1512,15 @@ def assignment_stmt(self, node: Assignment_Stmt): line = get_line(node) if len(children) == 3: - return ast_internal_classes.BinOp_Node(lval=children[0], op=children[1], rval=children[2], line_number=line, + return ast_internal_classes.BinOp_Node(lval=children[0], + op=children[1], + rval=children[2], + line_number=line, type=children[0].type) else: - return ast_internal_classes.UnOp_Node(lval=children[1], op=children[0], 
line_number=line, + return ast_internal_classes.UnOp_Node(lval=children[1], + op=children[0], + line_number=line, type=children[1].type) def pointer_assignment_stmt(self, node: FASTNode): @@ -1535,8 +1550,12 @@ def where_construct(self, node: FASTNode): elifs_cond.append(children[current]) elifs_body.append(children[current + 1]) current += 2 - return ast_internal_classes.Where_Construct_Node(body=body, cond=cond, body_else=body_else, - elifs_cond=elifs_cond, elifs_body=elifs_cond, line_number=line) + return ast_internal_classes.Where_Construct_Node(body=body, + cond=cond, + body_else=body_else, + elifs_cond=elifs_cond, + elifs_body=elifs_cond, + line_number=line) def where_construct_stmt(self, node: FASTNode): children = self.create_children(node) @@ -1647,10 +1666,14 @@ def case_construct(self, node: FASTNode): line = get_line(node) if line is None: line = "Unknown:TODO" - cond = ast_internal_classes.BinOp_Node(op=cond_end.op[0], lval=cond_start, rval=cond_end.cond[0], + cond = ast_internal_classes.BinOp_Node(op=cond_end.op[0], + lval=cond_start, + rval=cond_end.cond[0], line_number=line) for j in range(1, len(cond_end.op)): - cond_add = ast_internal_classes.BinOp_Node(op=cond_end.op[j], lval=cond_start, rval=cond_end.cond[j], + cond_add = ast_internal_classes.BinOp_Node(op=cond_end.op[j], + lval=cond_start, + rval=cond_end.cond[j], line_number=line) cond = ast_internal_classes.BinOp_Node(op=".OR.", lval=cond, rval=cond_add, line_number=line) @@ -1662,7 +1685,9 @@ def case_construct(self, node: FASTNode): if isinstance(i, ast_internal_classes.Case_Cond_Node): cond = ast_internal_classes.BinOp_Node(op=i.op[0], lval=cond_start, rval=i.cond[0], line_number=line) for j in range(1, len(i.op)): - cond_add = ast_internal_classes.BinOp_Node(op=i.op[j], lval=cond_start, rval=i.cond[j], + cond_add = ast_internal_classes.BinOp_Node(op=i.op[j], + lval=cond_start, + rval=i.cond[j], line_number=line) cond = ast_internal_classes.BinOp_Node(op=".OR.", lval=cond, rval=cond_add, 
line_number=line) @@ -1817,15 +1842,21 @@ def call_stmt(self, node: FASTNode): # line_number = 42 # else: # line_number = node.item.span - return ast_internal_classes.Call_Expr_Node(name=name, args=ret_args, type="VOID", subroutine=True, + return ast_internal_classes.Call_Expr_Node(name=name, + args=ret_args, + type="VOID", + subroutine=True, line_number=line_number) def return_stmt(self, node: FASTNode): return None def stop_stmt(self, node: FASTNode): - return ast_internal_classes.Call_Expr_Node(name=ast_internal_classes.Name_Node(name="__dace_exit"), args=[], - type="VOID", subroutine=False, line_number=node.item.span) + return ast_internal_classes.Call_Expr_Node(name=ast_internal_classes.Name_Node(name="__dace_exit"), + args=[], + type="VOID", + subroutine=False, + line_number=node.item.span) def dummy_arg_list(self, node: FASTNode): children = self.create_children(node) @@ -1843,8 +1874,7 @@ def part_ref(self, node: FASTNode): line = get_line(node) name = get_child(children, ast_internal_classes.Name_Node) args = get_child(children, ast_internal_classes.Section_Subscript_List_Node) - return ast_internal_classes.Array_Subscript_Node(name=name, type="VOID", indices=args.list, - line_number=line) + return ast_internal_classes.Array_Subscript_Node(name=name, type="VOID", indices=args.list, line_number=line) def loop_control(self, node: FASTNode): children = self.create_children(node) @@ -1864,7 +1894,9 @@ def loop_control(self, node: FASTNode): init_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op="=", rval=loop_start, type="INTEGER") if isinstance(loop_step, ast_internal_classes.UnOp_Node): if loop_step.op == "-": - cond_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op=">=", rval=loop_end, + cond_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, + op=">=", + rval=loop_end, type="INTEGER") else: cond_expr = ast_internal_classes.BinOp_Node(lval=iteration_variable, op="<=", rval=loop_end, type="INTEGER") @@ -1885,16 
+1917,16 @@ def block_nonlabel_do_construct(self, node: FASTNode): if do is None: while_true_header = get_child(children, ast_internal_classes.While_True_Control) if while_true_header is not None: - return ast_internal_classes.While_Stmt_Node(name=while_true_header.name, - body=ast_internal_classes.Execution_Part_Node( - execution=body), - line_number=while_true_header.line_number) + return ast_internal_classes.While_Stmt_Node( + name=while_true_header.name, + body=ast_internal_classes.Execution_Part_Node(execution=body), + line_number=while_true_header.line_number) while_header = get_child(children, ast_internal_classes.While_Control) if while_header is not None: - return ast_internal_classes.While_Stmt_Node(cond=while_header.cond, - body=ast_internal_classes.Execution_Part_Node( - execution=body), - line_number=while_header.line_number) + return ast_internal_classes.While_Stmt_Node( + cond=while_header.cond, + body=ast_internal_classes.Execution_Part_Node(execution=body), + line_number=while_header.line_number) return ast_internal_classes.For_Stmt_Node(init=do.init, cond=do.cond, iter=do.iter, @@ -1923,7 +1955,7 @@ def specification_part(self, node: FASTNode): tmp = [self.create_ast(i) for i in node.children] typedecls = [ i for i in tmp if isinstance(i, ast_internal_classes.Type_Decl_Node) - or isinstance(i, ast_internal_classes.Derived_Type_Def_Node) + or isinstance(i, ast_internal_classes.Derived_Type_Def_Node) ] symbols = [] iblocks = [] diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py index 60ab6c9034..7d8b3d58e1 100644 --- a/dace/frontend/fortran/ast_desugaring.py +++ b/dace/frontend/fortran/ast_desugaring.py @@ -27,25 +27,27 @@ from dace.frontend.fortran.ast_utils import singular, children_of_type, atmost_one +ENTRY_POINT_OBJECT = (Main_Program, Subroutine_Subprogram, Function_Subprogram) ENTRY_POINT_OBJECT_TYPES = Union[Main_Program, Subroutine_Subprogram, Function_Subprogram] -SCOPE_OBJECT_TYPES = Union[ - 
Main_Program, Module, Function_Subprogram, Subroutine_Subprogram, Derived_Type_Def, Interface_Block, - Subroutine_Body, Function_Body] -NAMED_STMTS_OF_INTEREST_TYPES = Union[ - Program_Stmt, Module_Stmt, Function_Stmt, Subroutine_Stmt, Derived_Type_Stmt, Component_Decl, Entity_Decl, - Specific_Binding, Generic_Binding, Interface_Stmt] +SCOPE_OBJECT = (Main_Program, Module, Function_Subprogram, Subroutine_Subprogram, Derived_Type_Def, Interface_Block, + Subroutine_Body, Function_Body) +SCOPE_OBJECT_TYPES = Union[Main_Program, Module, Function_Subprogram, Subroutine_Subprogram, Derived_Type_Def, + Interface_Block, Subroutine_Body, Function_Body] +NAMED_STMTS_OF_INTEREST = (Program_Stmt, Module_Stmt, Function_Stmt, Subroutine_Stmt, Derived_Type_Stmt, Component_Decl, + Entity_Decl, Specific_Binding, Generic_Binding, Interface_Stmt) +NAMED_STMTS_OF_INTEREST_TYPES = Union[Program_Stmt, Module_Stmt, Function_Stmt, Subroutine_Stmt, Derived_Type_Stmt, + Component_Decl, Entity_Decl, Specific_Binding, Generic_Binding, Interface_Stmt] SPEC = Tuple[str, ...] SPEC_TABLE = Dict[SPEC, NAMED_STMTS_OF_INTEREST_TYPES] class TYPE_SPEC: + NO_ATTRS = '' - def __init__(self, - spec: Union[str, SPEC], - attrs: str = NO_ATTRS): + def __init__(self, spec: Union[str, SPEC], attrs: str = NO_ATTRS): if isinstance(spec, str): - spec = (spec,) + spec = (spec, ) self.spec: SPEC = spec self.shape: Tuple[str, ...] = self._parse_shape(attrs) self.optional: bool = 'OPTIONAL' in attrs @@ -93,9 +95,9 @@ def find_name_of_stmt(node: NAMED_STMTS_OF_INTEREST_TYPES) -> Optional[str]: def find_name_of_node(node: Base) -> Optional[str]: """Find the name of the general node if it has one. 
For anonymous blocks, return `None`.""" - if isinstance(node, NAMED_STMTS_OF_INTEREST_TYPES): + if isinstance(node, NAMED_STMTS_OF_INTEREST): return find_name_of_stmt(node) - stmt = atmost_one(children_of_type(node, NAMED_STMTS_OF_INTEREST_TYPES)) + stmt = atmost_one(children_of_type(node, NAMED_STMTS_OF_INTEREST)) if not stmt: return None return find_name_of_stmt(stmt) @@ -103,7 +105,7 @@ def find_name_of_node(node: Base) -> Optional[str]: def find_scope_ancestor(node: Base) -> Optional[SCOPE_OBJECT_TYPES]: anc = node.parent - while anc and not isinstance(anc, SCOPE_OBJECT_TYPES): + while anc and not isinstance(anc, SCOPE_OBJECT): anc = anc.parent return anc @@ -112,18 +114,18 @@ def find_named_ancestor(node: Base) -> Optional[NAMED_STMTS_OF_INTEREST_TYPES]: anc = find_scope_ancestor(node) if not anc: return None - return atmost_one(children_of_type(anc, NAMED_STMTS_OF_INTEREST_TYPES)) + return atmost_one(children_of_type(anc, NAMED_STMTS_OF_INTEREST)) def lineage(anc: Base, des: Base) -> Optional[Tuple[Base, ...]]: if anc == des: - return (anc,) + return (anc, ) if not des.parent: return None lin = lineage(anc, des.parent) if not lin: return None - return lin + (des,) + return lin + (des, ) def search_scope_spec(node: Base) -> Optional[SPEC]: @@ -134,10 +136,8 @@ def search_scope_spec(node: Base) -> Optional[SPEC]: assert lin par = node.parent # TODO: How many other such cases can there be? - if (isinstance(scope, Derived_Type_Def) - and any( - isinstance(x, (Explicit_Shape_Spec, Component_Initialization, Kind_Selector, Char_Selector)) - for x in lin)): + if (isinstance(scope, Derived_Type_Def) and any( + isinstance(x, (Explicit_Shape_Spec, Component_Initialization, Kind_Selector, Char_Selector)) for x in lin)): # We're using `node` to describe a shape, an initialization etc. inside a type def. So, `node`` must have been # defined earlier. 
return search_scope_spec(scope) @@ -146,7 +146,7 @@ def search_scope_spec(node: Base) -> Optional[SPEC]: if kw == node: # We're describing a keyword, which is not really an identifiable object. return None - stmt = singular(children_of_type(scope, NAMED_STMTS_OF_INTEREST_TYPES)) + stmt = singular(children_of_type(scope, NAMED_STMTS_OF_INTEREST)) if not find_name_of_stmt(stmt): # If this is an anonymous object, the scope has to be outside. return search_scope_spec(scope.parent) @@ -160,16 +160,17 @@ def find_scope_spec(node: Base) -> SPEC: def ident_spec(node: NAMED_STMTS_OF_INTEREST_TYPES) -> SPEC: + def _ident_spec(_node: NAMED_STMTS_OF_INTEREST_TYPES) -> SPEC: """ Constuct a list of identifier strings that can uniquely determine it through the entire AST. """ - ident_base = (find_name_of_stmt(_node),) + ident_base = (find_name_of_stmt(_node), ) # Find the next named ancestor. anc = find_named_ancestor(_node.parent) if not anc: return ident_base - assert isinstance(anc, NAMED_STMTS_OF_INTEREST_TYPES) + assert isinstance(anc, NAMED_STMTS_OF_INTEREST) return _ident_spec(anc) + ident_base spec = _ident_spec(node) @@ -212,13 +213,13 @@ def search_local_alias_spec(node: Name) -> Optional[SPEC]: kw, _ = par.children if kw == node: return None - return scope_spec + (name,) + return scope_spec + (name, ) def search_real_local_alias_spec_from_spec(loc: SPEC, alias_map: SPEC_TABLE) -> Optional[SPEC]: while len(loc) > 1 and loc not in alias_map: # The name is not immediately available in the current scope, but may be it is in the parent's scope. - loc = loc[:-2] + (loc[-1],) + loc = loc[:-2] + (loc[-1], ) return loc if loc in alias_map else None @@ -234,8 +235,8 @@ def identifier_specs(ast: Program) -> SPEC_TABLE: Maps each identifier of interest in `ast` to its associated node that defines it. 
""" ident_map: SPEC_TABLE = {} - for stmt in walk(ast, NAMED_STMTS_OF_INTEREST_TYPES): - assert isinstance(stmt, NAMED_STMTS_OF_INTEREST_TYPES) + for stmt in walk(ast, NAMED_STMTS_OF_INTEREST): + assert isinstance(stmt, NAMED_STMTS_OF_INTEREST) if isinstance(stmt, Interface_Stmt) and not find_name_of_stmt(stmt): # There can be anonymous blocks, e.g., interface blocks, which cannot be identified. continue @@ -254,10 +255,10 @@ def alias_specs(ast: Program): for stmt in walk(ast, Use_Stmt): mod_name = singular(children_of_type(stmt, Name)).string - mod_spec = (mod_name,) + mod_spec = (mod_name, ) scope_spec = find_scope_spec(stmt) - use_spec = scope_spec + (mod_name,) + use_spec = scope_spec + (mod_name, ) assert mod_spec in ident_map # The module's name cannot be used as an identifier in this scope anymore, so just point to the module. @@ -282,7 +283,7 @@ def alias_specs(ast: Program): elif isinstance(c, Rename): _, src, tgt = c.children src, tgt = src.string, tgt.string - src_spec, tgt_spec = scope_spec + (src,), mod_spec + (tgt,) + src_spec, tgt_spec = scope_spec + (src, ), mod_spec + (tgt, ) # `tgt_spec` must have already been resolved if we have sorted the modules properly. 
assert tgt_spec in alias_map, f"{src_spec} => {tgt_spec}" alias_map[src_spec] = alias_map[tgt_spec] @@ -292,7 +293,7 @@ def alias_specs(ast: Program): def search_real_ident_spec(ident: str, in_spec: SPEC, alias_map: SPEC_TABLE) -> Optional[SPEC]: - k = in_spec + (ident,) + k = in_spec + (ident, ) if k in alias_map: return ident_spec(alias_map[k]) if not in_spec: @@ -308,8 +309,7 @@ def find_real_ident_spec(ident: str, in_spec: SPEC, alias_map: SPEC_TABLE) -> SP def _find_type_decl_node(node: Entity_Decl): anc = node.parent - while anc and not atmost_one( - children_of_type(anc, (Intrinsic_Type_Spec, Declaration_Type_Spec))): + while anc and not atmost_one(children_of_type(anc, (Intrinsic_Type_Spec, Declaration_Type_Spec))): anc = anc.parent return anc @@ -318,7 +318,7 @@ def _eval_selected_int_kind(p: np.int32) -> int: # Copied logic from `replace_int_kind()` elsewhere in the project. # avoid int overflow in numpy 2.0 p = int(p) - kind = int(math.ceil((math.log2(10 ** p) + 1) / 8)) + kind = int(math.ceil((math.log2(10**p) + 1) / 8)) assert kind <= 8 if kind <= 2: return kind @@ -428,7 +428,7 @@ def _eval_int_literal(x: Union[Signed_Int_Literal_Constant, Int_Literal_Constant elif kind in {'1', '2', '4', '8'}: kind = np.int32(kind) else: - kind_spec = search_real_local_alias_spec_from_spec(find_scope_spec(x) + (kind,), alias_map) + kind_spec = search_real_local_alias_spec_from_spec(find_scope_spec(x) + (kind, ), alias_map) if kind_spec: kind_decl = alias_map[kind_spec] kind_node, _, _, _ = kind_decl.children @@ -455,7 +455,7 @@ def _eval_real_literal(x: Union[Signed_Real_Literal_Constant, Real_Literal_Const else: kind = 4 else: - kind_spec = search_real_local_alias_spec_from_spec(find_scope_spec(x) + (kind,), alias_map) + kind_spec = search_real_local_alias_spec_from_spec(find_scope_spec(x) + (kind, ), alias_map) if kind_spec: kind_decl = alias_map[kind_spec] kind_node, _, _, _ = kind_decl.children @@ -489,19 +489,19 @@ def _const_eval_basic_type(expr: Base, 
alias_map: SPEC_TABLE) -> Optional[NUMPY_ _, iexpr = init.children val = _const_eval_basic_type(iexpr, alias_map) assert val is not None - if typ.spec == ('INTEGER1',): + if typ.spec == ('INTEGER1', ): val = np.int8(val) - elif typ.spec == ('INTEGER2',): + elif typ.spec == ('INTEGER2', ): val = np.int16(val) - elif typ.spec == ('INTEGER4',) or typ.spec == ('INTEGER',): + elif typ.spec == ('INTEGER4', ) or typ.spec == ('INTEGER', ): val = np.int32(val) - elif typ.spec == ('INTEGER8',): + elif typ.spec == ('INTEGER8', ): val = np.int64(val) - elif typ.spec == ('REAL4',) or typ.spec == ('REAL',): + elif typ.spec == ('REAL4', ) or typ.spec == ('REAL', ): val = np.float32(val) - elif typ.spec == ('REAL8',): + elif typ.spec == ('REAL8', ): val = np.float64(val) - elif typ.spec == ('LOGICAL',): + elif typ.spec == ('LOGICAL', ): val = np.bool_(val) else: raise ValueError(f"{expr}/{typ.spec} is not a basic type") @@ -607,7 +607,7 @@ def find_type_of_entity(node: Entity_Decl, alias_map: SPEC_TABLE) -> Optional[TY # TODO: How should we handle character lengths? Just treat it as an extra dimension? 
if isinstance(kind, Length_Selector): assert typ_name == 'CHARACTER' - extra_dim = (':',) + extra_dim = (':', ) elif isinstance(kind, Kind_Selector): assert typ_name in {'INTEGER', 'REAL', 'LOGICAL'} _, kind, _ = kind.children @@ -618,7 +618,7 @@ def find_type_of_entity(node: Entity_Decl, alias_map: SPEC_TABLE) -> Optional[TY typ_name = f"{typ_name}4" elif typ_name in {'DOUBLE PRECISION'}: typ_name = f"REAL8" - spec = (typ_name,) + spec = (typ_name, ) elif isinstance(typ, Declaration_Type_Spec): _, typ_name = typ.children spec = find_real_ident_spec(typ_name.string, ident_spec(node), alias_map) @@ -750,11 +750,11 @@ def generic_specs(ast: Program) -> Dict[SPEC, Tuple[SPEC, ...]]: plist = [] scope_spec = find_scope_spec(gb) - genc_spec = scope_spec + (bname.string,) + genc_spec = scope_spec + (bname.string, ) proc_specs = [] for pname in plist: - pspec = scope_spec + (pname.string,) + pspec = scope_spec + (pname.string, ) proc_specs.append(pspec) # TODO: Is this assumption true? @@ -774,7 +774,7 @@ def interface_specs(ast: Program, alias_map: SPEC_TABLE) -> Dict[SPEC, Tuple[SPE continue ib = ifs.parent scope_spec = find_scope_spec(ib) - ifspec = scope_spec + (name,) + ifspec = scope_spec + (name, ) # Get the spec of all the callable things in this block that may end up as a resolution for this interface. fns: List[str] = [] @@ -811,7 +811,7 @@ def interface_specs(ast: Program, alias_map: SPEC_TABLE) -> Dict[SPEC, Tuple[SPE cscope = cscope[:-1] fn_spec = find_real_ident_spec(fn_name, cscope, alias_map) assert ifspec != fn_spec - iface_map[ifspec] = (fn_spec,) + iface_map[ifspec] = (fn_spec, ) return iface_map @@ -928,7 +928,7 @@ def correct_for_function_calls(ast: Program): if not Intrinsic_Name.match(name): # There is no way this is an intrinsic call. continue - fref_spec = scope_spec + (name,) + fref_spec = scope_spec + (name, ) if fref_spec in alias_map: # This is already an alias, so intrinsic object is shadowed. 
continue @@ -1023,6 +1023,7 @@ def _compute_argument_signature(args, scope_spec: SPEC, alias_map: SPEC_TABLE) - args_sig = [] for c in args.children: + def _deduct_type(x) -> TYPE_SPEC: if isinstance(x, (Real_Literal_Constant, Signed_Real_Literal_Constant)): return TYPE_SPEC('REAL') @@ -1123,7 +1124,7 @@ def _deduct_type(x) -> TYPE_SPEC: items = items.children # TODO: We are assuming there is an element. What if there isn't? t = _deduct_type(items[0]) - t.shape += (':',) + t.shape += (':', ) return t else: # TODO: Figure out the actual type. @@ -1139,7 +1140,7 @@ def _deduct_type(x) -> TYPE_SPEC: def _compute_candidate_argument_signature(args, cand_spec: SPEC, alias_map: SPEC_TABLE) -> Tuple[TYPE_SPEC, ...]: cand_args_sig: List[TYPE_SPEC] = [] for ca in args: - ca_decl = alias_map[cand_spec + (ca.string,)] + ca_decl = alias_map[cand_spec + (ca.string, )] ca_type = find_type_of_entity(ca_decl, alias_map) ca_type.keyword = ca.string assert ca_type, f"got: {ca} / {type(ca)}" @@ -1235,7 +1236,7 @@ def deconstruct_interface_calls(ast: Program) -> Program: # At this point, we must have replaced all the interface calls with concrete calls. for use in walk(ast, Use_Stmt): mod_name = singular(children_of_type(use, Name)).string - mod_spec = (mod_name,) + mod_spec = (mod_name, ) olist = atmost_one(children_of_type(use, Only_List)) if not olist: # There is nothing directly referring to the interface. @@ -1272,7 +1273,7 @@ def deconstruct_interface_calls(ast: Program) -> Program: survivors = [] for c in alist.children: assert isinstance(c, Name) - c_spec = scope_spec + (c.string,) + c_spec = scope_spec + (c.string, ) assert c_spec in alias_map if not isinstance(alias_map[c_spec], Interface_Stmt): # Leave the non-interface usages alone. @@ -1298,7 +1299,7 @@ def deconstruct_interface_calls(ast: Program) -> Program: return ast -MATCH_ALL = TYPE_SPEC(('*',), '') # TODO: Hacky; `_does_type_signature_match()` will match anything with this. 
+MATCH_ALL = TYPE_SPEC(('*', ), '') # TODO: Hacky; `_does_type_signature_match()` will match anything with this. def _does_part_matches(g: TYPE_SPEC, c: TYPE_SPEC) -> bool: @@ -1414,7 +1415,7 @@ def deconstruct_procedure_calls(ast: Program) -> Program: args_sig: Tuple[TYPE_SPEC, ...] = _compute_argument_signature(args, scope_spec, alias_map) all_cand_sigs: List[Tuple[SPEC, Tuple[TYPE_SPEC, ...]]] = [] - bspec = dref_type.spec + (bname.string,) + bspec = dref_type.spec + (bname.string, ) if bspec in genc_map and genc_map[bspec]: for cand in genc_map[bspec]: cand_stmt = alias_map[proc_map[cand]] @@ -1477,13 +1478,14 @@ def prune_unused_objects(ast: Program, keepers: List[SPEC]) -> Program: """ Precondition: All the indirections have been taken out of the program. """ - PRUNABLE_OBJECT_TYPES = Union[Main_Program, Subroutine_Subprogram, Function_Subprogram, Derived_Type_Def] + + PRUNABLE_OBJECTS = (Main_Program, Subroutine_Subprogram, Function_Subprogram, Derived_Type_Def) ident_map = identifier_specs(ast) alias_map = alias_specs(ast) survivors: Set[SPEC] = set() keepers = [alias_map[k].parent for k in keepers] - assert all(isinstance(k, PRUNABLE_OBJECT_TYPES) for k in keepers) + assert all(isinstance(k, PRUNABLE_OBJECTS) for k in keepers) def _keep_from(node: Base): for nm in walk(node, Name): @@ -1498,7 +1500,7 @@ def _keep_from(node: Base): continue survivors.add(anc) anc_node = alias_map[anc].parent - if isinstance(anc_node, PRUNABLE_OBJECT_TYPES): + if isinstance(anc_node, PRUNABLE_OBJECTS): _keep_from(anc_node) to_keep = search_real_ident_spec(nm.string, sc_spec, alias_map) @@ -1507,7 +1509,7 @@ def _keep_from(node: Base): continue survivors.add(to_keep) keep_node = alias_map[to_keep].parent - if isinstance(keep_node, PRUNABLE_OBJECT_TYPES): + if isinstance(keep_node, PRUNABLE_OBJECTS): _keep_from(keep_node) for k in keepers: @@ -1517,7 +1519,7 @@ def _keep_from(node: Base): killed: Set[SPEC] = set() for ns in list(sorted(set(ident_map.keys()) - survivors)): 
ns_node = ident_map[ns].parent - if not isinstance(ns_node, PRUNABLE_OBJECT_TYPES): + if not isinstance(ns_node, PRUNABLE_OBJECTS): continue for i in range(len(ns) - 1): anc_spec = ns[:i + 1] @@ -1648,7 +1650,7 @@ def assign_globally_unique_subprogram_names(ast: Program, keepers: Set[SPEC]) -> # PHASE 1.a: Remove all the places where any function is imported. for use in walk(ast, Use_Stmt): mod_name = singular(children_of_type(use, Name)).string - mod_spec = (mod_name,) + mod_spec = (mod_name, ) olist = atmost_one(children_of_type(use, Only_List)) if not olist: continue @@ -1689,7 +1691,7 @@ def assign_globally_unique_subprogram_names(ast: Program, keepers: Set[SPEC]) -> # We have chosen to not rename it. continue uname = uident_map[fspec] - ufspec = fspec[:-1] + (uname,) + ufspec = fspec[:-1] + (uname, ) name.string = uname # Find the nearest execution and its correpsonding specification parts. @@ -1787,7 +1789,7 @@ def assign_globally_unique_variable_names(ast: Program, keepers: Set[str]) -> Pr # PHASE 1.a: Remove all the places where any variable is imported. for use in walk(ast, Use_Stmt): mod_name = singular(children_of_type(use, Name)).string - mod_spec = (mod_name,) + mod_spec = (mod_name, ) olist = atmost_one(children_of_type(use, Only_List)) if not olist: continue @@ -1865,7 +1867,7 @@ def assign_globally_unique_variable_names(ast: Program, keepers: Set[str]) -> Pr continue assert len(vspec) == 2 mod, _ = vspec - if not isinstance(alias_map[(mod,)], Module_Stmt): + if not isinstance(alias_map[(mod, )], Module_Stmt): # We can only import modules. continue @@ -1899,7 +1901,7 @@ def assign_globally_unique_variable_names(ast: Program, keepers: Set[str]) -> Pr continue assert len(kind_spec) == 2 mod, _ = kind_spec - if not isinstance(alias_map[(mod,)], Module_Stmt): + if not isinstance(alias_map[(mod, )], Module_Stmt): # We can only import modules. continue @@ -1972,10 +1974,12 @@ def consolidate_uses(ast: Program) -> Program: # Build new use statements. 
nuses: List[Use_Stmt] = [ Use_Stmt(f"use {k}") if k in all_use else Use_Stmt(f"use {k}, only: {', '.join(use_map[k])}") - for k in use_map.keys() | all_use] + for k in use_map.keys() | all_use + ] reuses: List[Use_Stmt] = [ - Use_Stmt(f"use {k}, only: {', '.join(r for r in use_map[k] if '=>' in r)}") - for k in use_map.keys() if any('=>' in r for r in use_map[k])] + Use_Stmt(f"use {k}, only: {', '.join(r for r in use_map[k] if '=>' in r)}") for k in use_map.keys() + if any('=>' in r for r in use_map[k]) + ] # Remove the old ones, and prepend the new ones. sp.content = nuses + reuses + [c for c in sp.children if not isinstance(c, Use_Stmt)] _reparent_children(sp) @@ -2011,7 +2015,7 @@ def _prune_branches_in_ifblock(ib: If_Construct, alias_map: SPEC_TABLE): isinstance(cut_cond, Else_If_Stmt) cut_cond, _ = cut_cond.children remove_children(ib, ib.children[1:(cut + 1)]) - set_children(ifthen, (cut_cond,)) + set_children(ifthen, (cut_cond, )) _prune_branches_in_ifblock(ib, alias_map) @@ -2022,9 +2026,8 @@ def prune_branches(ast: Program) -> Program: return ast -LITERAL_TYPES = Union[ - Real_Literal_Constant, Signed_Real_Literal_Constant, Int_Literal_Constant, Signed_Int_Literal_Constant, - Logical_Literal_Constant] +LITERAL_TYPES = Union[Real_Literal_Constant, Signed_Real_Literal_Constant, Int_Literal_Constant, + Signed_Int_Literal_Constant, Logical_Literal_Constant] def numpy_type_to_literal(val: NUMPY_TYPES) -> Union[LITERAL_TYPES]: @@ -2052,9 +2055,8 @@ def numpy_type_to_literal(val: NUMPY_TYPES) -> Union[LITERAL_TYPES]: def const_eval_nodes(ast: Program) -> Program: - EXPRESSION_TYPES = Union[ - LITERAL_TYPES, Expr, Add_Operand, Mult_Operand, Level_2_Expr, Level_3_Expr, Level_4_Expr, Level_5_Expr, - Intrinsic_Function_Reference] + EXPRESSION_TYPES = Union[LITERAL_TYPES, Expr, Add_Operand, Mult_Operand, Level_2_Expr, Level_3_Expr, Level_4_Expr, + Level_5_Expr, Intrinsic_Function_Reference] alias_map = alias_specs(ast) @@ -2083,8 +2085,8 @@ def _const_eval_node(n: 
Base) -> bool: _, kind, _ = knode.children _const_eval_node(kind) - NON_EXPRESSION_TYPES = Union[ - Explicit_Shape_Spec, Loop_Control, Call_Stmt, Function_Reference, Initialization, Component_Initialization] + NON_EXPRESSION_TYPES = Union[Explicit_Shape_Spec, Loop_Control, Call_Stmt, Function_Reference, Initialization, + Component_Initialization] for node in reversed(walk(ast, NON_EXPRESSION_TYPES)): for nm in reversed(walk(node, Name)): _const_eval_node(nm) diff --git a/dace/frontend/fortran/ast_internal_classes.py b/dace/frontend/fortran/ast_internal_classes.py index 2797e05d9d..54475892b6 100644 --- a/dace/frontend/fortran/ast_internal_classes.py +++ b/dace/frontend/fortran/ast_internal_classes.py @@ -1,26 +1,21 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. from typing import List, Optional, Tuple, Union, Dict, Any - # The node class is the base class for all nodes in the AST. It provides attributes including the line number and fields. # Attributes are not used when walking the tree, but are useful for debugging and for code generation. # The fields attribute is a list of the names of the attributes that are children of the node. class FNode(object): + def __init__(self, line_number: int = -1, **kwargs): # real signature unknown self.line_number = line_number - self.parent: Union[ - None, - Subroutine_Subprogram_Node, - Function_Subprogram_Node, - Main_Program_Node, - Module_Node - ] = None + self.parent: Union[None, Subroutine_Subprogram_Node, Function_Subprogram_Node, Main_Program_Node, + Module_Node] = None for k, v in kwargs.items(): setattr(self, k, v) - _attributes: Tuple[str, ...] = ("line_number",) + _attributes: Tuple[str, ...] = ("line_number", ) _fields: Tuple[str, ...] 
= () def __eq__(self, o: object) -> bool: @@ -35,6 +30,7 @@ def __eq__(self, o: object) -> bool: class Program_Node(FNode): + def __init__(self, main_program: 'Main_Program_Node', function_definitions: List, @@ -65,6 +61,7 @@ def __init__(self, class BinOp_Node(FNode): + def __init__(self, op: str, lval: FNode, rval: FNode, type: str = 'VOID', **kwargs): super().__init__(**kwargs) assert rval @@ -83,7 +80,7 @@ class UnOp_Node(FNode): 'postfix', 'type', ) - _fields = ('lval',) + _fields = ('lval', ) class Exit_Node(FNode): @@ -92,18 +89,15 @@ class Exit_Node(FNode): class Main_Program_Node(FNode): - _attributes = ("name",) + _attributes = ("name", ) _fields = ("execution_part", "specification_part") class Module_Node(FNode): - def __init__(self, - name: 'Name_Node', - specification_part: 'Specification_Part_Node', + + def __init__(self, name: 'Name_Node', specification_part: 'Specification_Part_Node', subroutine_definitions: List['Subroutine_Subprogram_Node'], - function_definitions: List['Function_Subprogram_Node'], - interface_blocks: Dict, - **kwargs): + function_definitions: List['Function_Subprogram_Node'], interface_blocks: Dict, **kwargs): super().__init__(**kwargs) self.name = name self.specification_part = specification_part @@ -111,20 +105,14 @@ def __init__(self, self.function_definitions = function_definitions self.interface_blocks = interface_blocks - _attributes = ('name',) - _fields = ( - 'specification_part', - 'subroutine_definitions', - 'function_definitions', - 'interface_blocks' - ) + _attributes = ('name', ) + _fields = ('specification_part', 'subroutine_definitions', 'function_definitions', 'interface_blocks') class Module_Subprogram_Part_Node(FNode): - def __init__(self, - subroutine_definitions: List['Subroutine_Subprogram_Node'], - function_definitions: List['Function_Subprogram_Node'], - **kwargs): + + def __init__(self, subroutine_definitions: List['Subroutine_Subprogram_Node'], + function_definitions: List['Function_Subprogram_Node'], 
**kwargs): super().__init__(**kwargs) self.subroutine_definitions = subroutine_definitions self.function_definitions = function_definitions @@ -137,10 +125,9 @@ def __init__(self, class Internal_Subprogram_Part_Node(FNode): - def __init__(self, - subroutine_definitions: List['Subroutine_Subprogram_Node'], - function_definitions: List['Function_Subprogram_Node'], - **kwargs): + + def __init__(self, subroutine_definitions: List['Subroutine_Subprogram_Node'], + function_definitions: List['Function_Subprogram_Node'], **kwargs): super().__init__(**kwargs) self.subroutine_definitions = subroutine_definitions self.function_definitions = function_definitions @@ -160,15 +147,9 @@ class Actual_Arg_Spec_Node(FNode): class Function_Subprogram_Node(FNode): - def __init__(self, - name: 'Name_Node', - args: List, - ret: 'Name_Node', - specification_part: 'Specification_Part_Node', - execution_part: 'Execution_Part_Node', - type: str, - elemental: bool, - **kwargs): + + def __init__(self, name: 'Name_Node', args: List, ret: 'Name_Node', specification_part: 'Specification_Part_Node', + execution_part: 'Execution_Part_Node', type: str, elemental: bool, **kwargs): super().__init__(**kwargs) self.name = name self.type = type @@ -187,6 +168,7 @@ def __init__(self, class Subroutine_Subprogram_Node(FNode): + def __init__(self, name: 'Name_Node', args: List, @@ -218,10 +200,8 @@ def __init__(self, class Interface_Block_Node(FNode): - _attributes = ('name',) - _fields = ( - 'subroutines', - ) + _attributes = ('name', ) + _fields = ('subroutines', ) class Interface_Stmt_Node(FNode): @@ -231,30 +211,31 @@ class Interface_Stmt_Node(FNode): class Procedure_Name_List_Node(FNode): _attributes = () - _fields = ('subroutines',) + _fields = ('subroutines', ) class Procedure_Statement_Node(FNode): _attributes = () - _fields = ('namelists',) + _fields = ('namelists', ) class Module_Stmt_Node(FNode): _attributes = () - _fields = ('functions',) + _fields = ('functions', ) class 
Program_Stmt_Node(FNode): - _attributes = ('name',) + _attributes = ('name', ) _fields = () class Subroutine_Stmt_Node(FNode): - _attributes = ('name',) - _fields = ('args',) + _attributes = ('name', ) + _fields = ('args', ) class Function_Stmt_Node(FNode): + def __init__(self, name: 'Name_Node', args: List[FNode], ret: Optional['Suffix_Node'], elemental: bool, type: str, **kwargs): super().__init__(**kwargs) @@ -265,10 +246,14 @@ def __init__(self, name: 'Name_Node', args: List[FNode], ret: Optional['Suffix_N self.type = type _attributes = ('name', 'elemental', 'type') - _fields = ('args', 'ret',) + _fields = ( + 'args', + 'ret', + ) class Prefix_Node(FNode): + def __init__(self, type: str, elemental: bool, recursive: bool, pure: bool, **kwargs): super().__init__(**kwargs) self.type = type @@ -276,70 +261,105 @@ def __init__(self, type: str, elemental: bool, recursive: bool, pure: bool, **kw self.recursive = recursive self.pure = pure - _attributes = ('elemental', 'recursive', 'pure',) + _attributes = ( + 'elemental', + 'recursive', + 'pure', + ) _fields = () class Name_Node(FNode): + def __init__(self, name: str, type: str = 'VOID', **kwargs): super().__init__(**kwargs) self.name = name self.type = type - _attributes = ('name', 'type',) + _attributes = ( + 'name', + 'type', + ) _fields = () class Name_Range_Node(FNode): - _attributes = ('name', 'type', 'arrname', 'pos',) + _attributes = ( + 'name', + 'type', + 'arrname', + 'pos', + ) _fields = () class Where_Construct_Node(FNode): _attributes = () - _fields = ('main_body', 'main_cond', 'else_body', 'elifs_body', 'elifs_cond',) + _fields = ( + 'main_body', + 'main_cond', + 'else_body', + 'elifs_body', + 'elifs_cond', + ) class Type_Name_Node(FNode): - _attributes = ('name', 'type',) + _attributes = ( + 'name', + 'type', + ) _fields = () class Generic_Binding_Node(FNode): _attributes = () - _fields = ('name', 'binding',) + _fields = ( + 'name', + 'binding', + ) class Specification_Part_Node(FNode): - _fields = 
('specifications', 'symbols', 'interface_blocks', 'typedecls', 'enums',) + _fields = ( + 'specifications', + 'symbols', + 'interface_blocks', + 'typedecls', + 'enums', + ) class Stop_Stmt_Node(FNode): - _attributes = ('code',) + _attributes = ('code', ) class Error_Stmt_Node(FNode): - _fields = ('error',) + _fields = ('error', ) class Execution_Part_Node(FNode): - _fields = ('execution',) + _fields = ('execution', ) class Statement_Node(FNode): - _attributes = ('col_offset',) + _attributes = ('col_offset', ) _fields = () class Array_Subscript_Node(FNode): + def __init__(self, name: Name_Node, type: str, indices: List[FNode], **kwargs): super().__init__(**kwargs) self.name = name self.type = type self.indices = indices - _attributes = ('type',) - _fields = ('name', 'indices',) + _attributes = ('type', ) + _fields = ( + 'name', + 'indices', + ) class Type_Decl_Node(Statement_Node): @@ -352,17 +372,17 @@ class Type_Decl_Node(Statement_Node): class Allocate_Shape_Spec_Node(FNode): _attributes = () - _fields = ('sizes',) + _fields = ('sizes', ) class Allocate_Shape_Spec_List(FNode): _attributes = () - _fields = ('shape_list',) + _fields = ('shape_list', ) class Allocation_Node(FNode): - _attributes = ('name',) - _fields = ('shape',) + _attributes = ('name', ) + _fields = ('shape', ) class Continue_Node(FNode): @@ -372,7 +392,7 @@ class Continue_Node(FNode): class Allocate_Stmt_Node(FNode): _attributes = () - _fields = ('allocation_list',) + _fields = ('allocation_list', ) class Symbol_Decl_Node(Statement_Node): @@ -403,11 +423,18 @@ class Symbol_Array_Decl_Node(Statement_Node): class Var_Decl_Node(Statement_Node): - def __init__(self, name: str, type: str, - alloc: Optional[bool] = None, optional: Optional[bool] = None, - sizes: Optional[List] = None, offsets: Optional[List] = None, - init: Optional[FNode] = None, actual_offsets: Optional[List] = None, - typeref: Optional[Any] = None, kind: Optional[Any] = None, + + def __init__(self, + name: str, + type: str, + alloc: 
Optional[bool] = None, + optional: Optional[bool] = None, + sizes: Optional[List] = None, + offsets: Optional[List] = None, + init: Optional[FNode] = None, + actual_offsets: Optional[List] = None, + typeref: Optional[Any] = None, + kind: Optional[Any] = None, **kwargs): super().__init__(**kwargs) self.name = name @@ -426,28 +453,29 @@ def __init__(self, name: str, type: str, class Arg_List_Node(FNode): - _fields = ('args',) + _fields = ('args', ) class Component_Spec_List_Node(FNode): - _fields = ('args',) + _fields = ('args', ) class Allocate_Object_List_Node(FNode): - _fields = ('list',) + _fields = ('list', ) class Deallocate_Stmt_Node(FNode): - _fields = ('list',) + _fields = ('list', ) class Decl_Stmt_Node(Statement_Node): + def __init__(self, vardecl: List[Var_Decl_Node], **kwargs): super().__init__(**kwargs) self.vardecl = vardecl _attributes = () - _fields = ('vardecl',) + _fields = ('vardecl', ) class VarType: @@ -459,6 +487,7 @@ class Void(VarType): class Literal(FNode): + def __init__(self, value: str, type: str, **kwargs): super().__init__(**kwargs) self.value = value @@ -469,40 +498,47 @@ def __init__(self, value: str, type: str, **kwargs): class Int_Literal_Node(Literal): + def __init__(self, value: str, type='INTEGER', **kwargs): super().__init__(value, type, **kwargs) class Real_Literal_Node(Literal): + def __init__(self, value: str, type='REAL', **kwargs): super().__init__(value, type, **kwargs) class Double_Literal_Node(Literal): + def __init__(self, value: str, type='DOUBLE', **kwargs): super().__init__(value, type, **kwargs) class Bool_Literal_Node(Literal): + def __init__(self, value: str, type='LOGICAL', **kwargs): super().__init__(value, type, **kwargs) class Char_Literal_Node(Literal): + def __init__(self, value: str, type='CHAR', **kwargs): super().__init__(value, type, **kwargs) class Suffix_Node(FNode): + def __init__(self, name: 'Name_Node', **kwargs): super().__init__(**kwargs) self.name = name _attributes = () - _fields = ('name',) + 
_fields = ('name', ) class Call_Expr_Node(FNode): + def __init__(self, name: 'Name_Node', args: List[FNode], subroutine: bool, type: str, **kwargs): super().__init__(**kwargs) self.name = name @@ -515,23 +551,23 @@ def __init__(self, name: 'Name_Node', args: List[FNode], subroutine: bool, type: class Derived_Type_Stmt_Node(FNode): - _attributes = ('name',) - _fields = ('args',) + _attributes = ('name', ) + _fields = ('args', ) class Derived_Type_Def_Node(FNode): - _attributes = ('name',) + _attributes = ('name', ) _fields = ('component_part', 'procedure_part') class Component_Part_Node(FNode): _attributes = () - _fields = ('component_def_stmts',) + _fields = ('component_def_stmts', ) class Data_Component_Def_Stmt_Node(FNode): _attributes = () - _fields = ('vars',) + _fields = ('vars', ) class Data_Ref_Node(FNode): @@ -541,16 +577,16 @@ class Data_Ref_Node(FNode): class Array_Constructor_Node(FNode): _attributes = () - _fields = ('value_list',) + _fields = ('value_list', ) class Ac_Value_List_Node(FNode): _attributes = () - _fields = ('value_list',) + _fields = ('value_list', ) class Section_Subscript_List_Node(FNode): - _fields = ('list',) + _fields = ('list', ) class Pointer_Assignment_Stmt_Node(FNode): @@ -594,7 +630,7 @@ class Defer_Shape_Node(FNode): class Component_Initialization_Node(FNode): _attributes = () - _fields = ('init',) + _fields = ('init', ) class Case_Cond_Node(FNode): @@ -614,27 +650,27 @@ class Procedure_Separator_Node(FNode): class Pointer_Object_List_Node(FNode): _attributes = () - _fields = ('list',) + _fields = ('list', ) class Read_Stmt_Node(FNode): _attributes = () - _fields = ('args',) + _fields = ('args', ) class Close_Stmt_Node(FNode): _attributes = () - _fields = ('args',) + _fields = ('args', ) class Open_Stmt_Node(FNode): _attributes = () - _fields = ('args',) + _fields = ('args', ) class Associate_Stmt_Node(FNode): _attributes = () - _fields = ('args',) + _fields = ('args', ) class Associate_Construct_Node(FNode): @@ -644,7 +680,7 
@@ class Associate_Construct_Node(FNode): class Association_List_Node(FNode): _attributes = () - _fields = ('list',) + _fields = ('list', ) class Association_Node(FNode): @@ -653,38 +689,38 @@ class Association_Node(FNode): class Connect_Spec_Node(FNode): - _attributes = ('type',) - _fields = ('args',) + _attributes = ('type', ) + _fields = ('args', ) class Close_Spec_Node(FNode): - _attributes = ('type',) - _fields = ('args',) + _attributes = ('type', ) + _fields = ('args', ) class Close_Spec_List_Node(FNode): _attributes = () - _fields = ('list',) + _fields = ('list', ) class IO_Control_Spec_Node(FNode): - _attributes = ('type',) - _fields = ('args',) + _attributes = ('type', ) + _fields = ('args', ) class IO_Control_Spec_List_Node(FNode): _attributes = () - _fields = ('list',) + _fields = ('list', ) class Connect_Spec_List_Node(FNode): _attributes = () - _fields = ('list',) + _fields = ('list', ) class Nullify_Stmt_Node(FNode): _attributes = () - _fields = ('list',) + _fields = ('list', ) class Namelist_Stmt_Node(FNode): @@ -694,12 +730,12 @@ class Namelist_Stmt_Node(FNode): class Namelist_Group_Object_List_Node(FNode): _attributes = () - _fields = ('list',) + _fields = ('list', ) class Bound_Procedures_Node(FNode): _attributes = () - _fields = ('procedures',) + _fields = ('procedures', ) class Specific_Binding_Node(FNode): @@ -708,6 +744,7 @@ class Specific_Binding_Node(FNode): class Parenthesis_Expr_Node(FNode): + def __init__(self, expr: FNode, **kwargs): super().__init__(**kwargs) assert hasattr(expr, 'type') @@ -729,16 +766,12 @@ class Nonlabel_Do_Stmt_Node(FNode): class While_True_Control(FNode): _attributes = () - _fields = ( - 'name', - ) + _fields = ('name', ) class While_Control(FNode): _attributes = () - _fields = ( - 'cond', - ) + _fields = ('cond', ) class While_Stmt_Node(FNode): @@ -760,37 +793,43 @@ class Loop_Control_Node(FNode): class Else_If_Stmt_Node(FNode): _attributes = () - _fields = ('cond',) + _fields = ('cond', ) class 
Only_List_Node(FNode): _attributes = () - _fields = ('names', 'renames',) + _fields = ( + 'names', + 'renames', + ) class Rename_Node(FNode): _attributes = () - _fields = ('oldname', 'newname',) + _fields = ( + 'oldname', + 'newname', + ) class ParDecl_Node(FNode): - _attributes = ('type',) - _fields = ('range',) + _attributes = ('type', ) + _fields = ('range', ) class Structure_Constructor_Node(FNode): - _attributes = ('type',) + _attributes = ('type', ) _fields = ('name', 'args') class Use_Stmt_Node(FNode): _attributes = ('name', 'list_all') - _fields = ('list',) + _fields = ('list', ) class Write_Stmt_Node(FNode): _attributes = () - _fields = ('args',) + _fields = ('args', ) class Break_Node(FNode): diff --git a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index db5769a977..8a69e9e2a8 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -42,7 +42,9 @@ def is_struct(self, type_name: str): def get_definition(self, type_name: str): return self.structures[type_name] - def find_definition(self, scope_vars, node: ast_internal_classes.Data_Ref_Node, + def find_definition(self, + scope_vars, + node: ast_internal_classes.Data_Ref_Node, variable_name: Optional[ast_internal_classes.Name_Node] = None): # we assume starting from the top (left-most) data_ref_node @@ -249,11 +251,16 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): if i.name.name == self.current_class.name.name: for j in i.procedure_part.procedures: if j.name.name == node.name.name: - return ast_internal_classes.Call_Expr_Node( - name=ast_internal_classes.Name_Node(name=i.name.name + "_" + node.name.name, - type=node.type, args=node.args, - line_number=node.line_number), args=node.args, - type=node.type, subroutine=node.subroutine, line_number=node.line_number,parent=node.parent) + return ast_internal_classes.Call_Expr_Node(name=ast_internal_classes.Name_Node( + name=i.name.name + "_" + 
node.name.name, + type=node.type, + args=node.args, + line_number=node.line_number), + args=node.args, + type=node.type, + subroutine=node.subroutine, + line_number=node.line_number, + parent=node.parent) return self.generic_visit(node) @@ -299,6 +306,7 @@ def from_node(node: ast_internal_classes.FNode) -> 'FindFunctionAndSubroutines': class FindNames(NodeVisitor): + def __init__(self): self.names: List[str] = [] @@ -312,6 +320,7 @@ def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_ class FindDefinedNames(NodeVisitor): + def __init__(self): self.names: List[str] = [] @@ -370,7 +379,7 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): self.visit(i) elif isinstance(node.lval, ast_internal_classes.Data_Ref_Node): # if isinstance(node.lval.parent_ref, ast_internal_classes.Name_Node): - # self.nodes.append(node.lval.parent_ref) + # self.nodes.append(node.lval.parent_ref) if isinstance(node.lval.parent_ref, ast_internal_classes.Array_Subscript_Node): # self.nodes.append(node.lval.parent_ref.name) for i in node.lval.parent_ref.indices: @@ -501,6 +510,7 @@ def visit_Derived_Type_Def_Node(self, node: ast_internal_classes.Derived_Type_De class StructDependencyLister(NodeVisitor): + def __init__(self, names=None): self.names = names self.structs_used = [] @@ -515,6 +525,7 @@ def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): class StructMemberLister(NodeVisitor): + def __init__(self): self.members = [] self.is_pointer = [] @@ -527,6 +538,7 @@ def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): class FindStructDefs(NodeVisitor): + def __init__(self, name=None): self.name = name self.structs = [] @@ -537,6 +549,7 @@ def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): class FindStructUses(NodeVisitor): + def __init__(self, names=None, target=None): self.names = names self.target = target @@ -573,6 +586,7 @@ def visit_Data_Ref_Node(self, node: 
ast_internal_classes.Data_Ref_Node): class StructPointerChecker(NodeVisitor): + def __init__(self, parent_struct, pointed_struct, pointer_name, structs_lister, struct_dep_graph, analysis): self.parent_struct = [parent_struct] self.pointed_struct = [pointed_struct] @@ -623,6 +637,7 @@ def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine class StructPointerEliminator(NodeTransformer): + def __init__(self, parent_struct, pointed_struct, pointer_name): self.parent_struct = parent_struct self.pointed_struct = pointed_struct @@ -642,9 +657,10 @@ def visit_Derived_Type_Def_Node(self, node: ast_internal_classes.Derived_Type_De else: vardecl.append(k) if vardecl != []: - component_part.component_def_stmts.append(ast_internal_classes.Data_Component_Def_Stmt_Node( - vars=ast_internal_classes.Decl_Stmt_Node(vardecl=vardecl, parent=node.parent), - parent=node.parent)) + component_part.component_def_stmts.append( + ast_internal_classes.Data_Component_Def_Stmt_Node(vars=ast_internal_classes.Decl_Stmt_Node( + vardecl=vardecl, parent=node.parent), + parent=node.parent)) newnode.component_part = component_part return newnode else: @@ -666,8 +682,8 @@ def __init__(self, funcs=None): from dace.frontend.fortran.intrinsics import FortranIntrinsics self.excepted_funcs = [ - "malloc", "pow", "cbrt", "__dace_sign", "tanh", "atan2", - "__dace_epsilon", *FortranIntrinsics.function_names() + "malloc", "pow", "cbrt", "__dace_sign", "tanh", "atan2", "__dace_epsilon", + *FortranIntrinsics.function_names() ] def visit_Structure_Constructor_Node(self, node: ast_internal_classes.Structure_Constructor_Node): @@ -687,9 +703,12 @@ def visit_Structure_Constructor_Node(self, node: ast_internal_classes.Structure_ arg = StructConstructorToFunctionCall(self.funcs).visit(i) processed_args.append(arg) node.args = processed_args - return ast_internal_classes.Call_Expr_Node( - name=ast_internal_classes.Name_Node(name=node.name.name, type="VOID", line_number=node.line_number), - 
args=node.args, line_number=node.line_number, type="VOID",parent=node.parent) + return ast_internal_classes.Call_Expr_Node(name=ast_internal_classes.Name_Node( + name=node.name.name, type="VOID", line_number=node.line_number), + args=node.args, + line_number=node.line_number, + type="VOID", + parent=node.parent) else: return node @@ -709,9 +728,8 @@ def __init__(self, funcs: FindFunctionAndSubroutines, dict=None): from dace.frontend.fortran.intrinsics import FortranIntrinsics self.excepted_funcs = [ - "malloc", "pow", "cbrt", "__dace_sign", "__dace_allocated", "tanh", "atan2", - "__dace_epsilon", "__dace_exit", "surrtpk", "surrtab", "surrtrf", "abor1", - *FortranIntrinsics.function_names() + "malloc", "pow", "cbrt", "__dace_sign", "__dace_allocated", "tanh", "atan2", "__dace_epsilon", + "__dace_exit", "surrtpk", "surrtab", "surrtrf", "abor1", *FortranIntrinsics.function_names() ] # @@ -748,7 +766,8 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): # TODO Deconproc is a special case, we need to handle it differently - this is just s quick workaround if name.startswith( - "__dace_") or name in self.excepted_funcs or found_in_renames or found_in_names or name in self.funcs.iblocks: + "__dace_" + ) or name in self.excepted_funcs or found_in_renames or found_in_names or name in self.funcs.iblocks: processed_args = [] for i in node.args: arg = CallToArray(self.funcs, self.rename_dict).visit(i) @@ -758,7 +777,9 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): indices = [CallToArray(self.funcs, self.rename_dict).visit(i) for i in node.args] # Array subscript cannot be empty. 
assert indices - return ast_internal_classes.Array_Subscript_Node(name=node.name, type=node.type, indices=indices, + return ast_internal_classes.Array_Subscript_Node(name=node.name, + type=node.type, + indices=indices, line_number=node.line_number) @@ -772,7 +793,7 @@ def __init__(self): def visit_For_Stmt_Node(self, node: ast_internal_classes.For_Stmt_Node): return - + def visit_If_Then_Stmt_Node(self, node: ast_internal_classes.If_Stmt_Node): return @@ -784,7 +805,7 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): from dace.frontend.fortran.intrinsics import FortranIntrinsics if not stop and node.name.name not in [ - "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() + "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() ]: for i in node.args: if isinstance(i, (ast_internal_classes.Name_Node, ast_internal_classes.Literal, @@ -817,16 +838,21 @@ def __init__(self, program, count=0): def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): from dace.frontend.fortran.intrinsics import FortranIntrinsics - if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon", - *FortranIntrinsics.call_extraction_exemptions()]: + if node.name.name in [ + "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() + ]: return self.generic_visit(node) #if node.subroutine: # return self.generic_visit(node) if not hasattr(self, "count"): self.count = 0 tmp = self.count - result = ast_internal_classes.Call_Expr_Node(type=node.type, subroutine=node.subroutine, - name=node.name, args=[], line_number=node.line_number, parent=node.parent) + result = ast_internal_classes.Call_Expr_Node(type=node.type, + subroutine=node.subroutine, + name=node.name, + args=[], + line_number=node.line_number, + parent=node.parent) for i, arg in enumerate(node.args): # Ensure we allow to extract function calls from arguments if isinstance(arg, 
(ast_internal_classes.Name_Node, ast_internal_classes.Literal, @@ -869,20 +895,15 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No node.parent.specification_part.specifications.append( ast_internal_classes.Decl_Stmt_Node(vardecl=[ ast_internal_classes.Var_Decl_Node( - name="tmp_arg_" + str(temp), - type=var_type, - sizes=None, - init=None - ) - ]) - ) + name="tmp_arg_" + str(temp), type=var_type, sizes=None, init=None) + ])) newbody.append( ast_internal_classes.BinOp_Node(op="=", - lval=ast_internal_classes.Name_Node(name="tmp_arg_" + - str(temp), + lval=ast_internal_classes.Name_Node(name="tmp_arg_" + str(temp), type=res[i].type), rval=res[i], - line_number=child.line_number,parent=child.parent)) + line_number=child.line_number, + parent=child.parent)) temp = temp - 1 newbody.append(self.visit(child)) @@ -891,6 +912,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No class FunctionCallTransformer(NodeTransformer): + def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): if isinstance(node.rval, ast_internal_classes.Call_Expr_Node): if hasattr(node.rval, "subroutine"): @@ -904,11 +926,13 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): lval = node.lval args.append(lval) return (ast_internal_classes.Call_Expr_Node(type=node.rval.type, - name=ast_internal_classes.Name_Node( - name=node.rval.name.name + "_srt", type=node.rval.type), + name=ast_internal_classes.Name_Node(name=node.rval.name.name + + "_srt", + type=node.rval.type), args=args, subroutine=True, - line_number=node.line_number, parent=node.parent)) + line_number=node.line_number, + parent=node.parent)) else: return self.generic_visit(node) @@ -955,23 +979,18 @@ def visit_Function_Subprogram_Node(self, node: ast_internal_classes.Function_Sub if not found: - var = ast_internal_classes.Var_Decl_Node( - name=node.name.name + "__ret", - type='VOID' - ) + var = 
ast_internal_classes.Var_Decl_Node(name=node.name.name + "__ret", type='VOID') stmt_node = ast_internal_classes.Decl_Stmt_Node(vardecl=[var], line_number=node.line_number) if node.specification_part is not None: node.specification_part.specifications.append(stmt_node) else: - node.specification_part = ast_internal_classes.Specification_Part_Node( - specifications=[stmt_node], - symbols=None, - interface_blocks=None, - uses=None, - typedecls=None, - enums=None - ) + node.specification_part = ast_internal_classes.Specification_Part_Node(specifications=[stmt_node], + symbols=None, + interface_blocks=None, + uses=None, + typedecls=None, + enums=None) # We should always be able to tell a functions return _variable_ (i.e., not type, which we also should be able # to tell). @@ -979,14 +998,15 @@ def visit_Function_Subprogram_Node(self, node: ast_internal_classes.Function_Sub execution_part = NameReplacer(ret.name, node.name.name + "__ret").visit(node.execution_part) args = node.args args.append(ast_internal_classes.Name_Node(name=node.name.name + "__ret", type=node.type)) - return ast_internal_classes.Subroutine_Subprogram_Node( - name=ast_internal_classes.Name_Node(name=node.name.name + "_srt", type=node.type), - args=args, - specification_part=node.specification_part, - execution_part=execution_part, - subroutine=True, - line_number=node.line_number, - elemental=node.elemental) + return ast_internal_classes.Subroutine_Subprogram_Node(name=ast_internal_classes.Name_Node(name=node.name.name + + "_srt", + type=node.type), + args=args, + specification_part=node.specification_part, + execution_part=execution_part, + subroutine=True, + line_number=node.line_number, + elemental=node.elemental) class CallExtractorNodeLister(NodeVisitor): @@ -994,27 +1014,26 @@ class CallExtractorNodeLister(NodeVisitor): Finds all function calls in the AST node and its children that have to be extracted into independent expressions """ - def __init__(self,root=None): + def __init__(self, 
root=None): self.root = root self.nodes: List[ast_internal_classes.Call_Expr_Node] = [] - def visit_For_Stmt_Node(self, node: ast_internal_classes.For_Stmt_Node): self.generic_visit(node.init) self.generic_visit(node.cond) return - + def visit_If_Stmt_Node(self, node: ast_internal_classes.If_Stmt_Node): self.generic_visit(node.cond) return - + def visit_While_Stmt_Node(self, node: ast_internal_classes.While_Stmt_Node): self.generic_visit(node.cond) return def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): stop = False - if self.root==node: + if self.root == node: return self.generic_visit(node) if isinstance(self.root, ast_internal_classes.BinOp_Node): if node == self.root.rval and isinstance(self.root.lval, ast_internal_classes.Name_Node): @@ -1025,7 +1044,7 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): from dace.frontend.fortran.intrinsics import FortranIntrinsics if not stop and node.name.name not in [ - "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() + "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() ]: self.nodes.append(node) #return self.generic_visit(node) @@ -1043,14 +1062,13 @@ class CallExtractor(NodeTransformer): def __init__(self, count=0): self.count = count - - def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): - + from dace.frontend.fortran.intrinsics import FortranIntrinsics - if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon", - *FortranIntrinsics.call_extraction_exemptions()]: + if node.name.name in [ + "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() + ]: return self.generic_visit(node) if hasattr(node, "subroutine"): if node.subroutine is True: @@ -1111,19 +1129,19 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): # interface_blocks=node.interface_blocks) def visit_Execution_Part_Node(self, node: 
ast_internal_classes.Execution_Part_Node): - + oldbody = node.execution - changes_made=True + changes_made = True while changes_made: - changes_made=False + changes_made = False newbody = [] for child in oldbody: lister = CallExtractorNodeLister(child) lister.visit(child) res = lister.nodes - - if len(res)> 0: - changes_made=True + + if len(res) > 0: + changes_made = True # Variables are counted from 0...end, starting from main node, to all calls nested # in main node arguments. # However, we need to define nested ones first. @@ -1133,42 +1151,51 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No newbody.append( ast_internal_classes.Decl_Stmt_Node(vardecl=[ ast_internal_classes.Var_Decl_Node( - name="tmp_call_" + str(temp), - type=res[i].type, - sizes=None, - init=None - ) + name="tmp_call_" + str(temp), type=res[i].type, sizes=None, init=None) ])) newbody.append( ast_internal_classes.BinOp_Node(op="=", - lval=ast_internal_classes.Name_Node( - name="tmp_call_" + str(temp), type=res[i].type), - rval=res[i], line_number=child.line_number,parent=child.parent)) + lval=ast_internal_classes.Name_Node(name="tmp_call_" + + str(temp), + type=res[i].type), + rval=res[i], + line_number=child.line_number, + parent=child.parent)) temp = temp - 1 if isinstance(child, ast_internal_classes.Call_Expr_Node): new_args = [] for i in child.args: new_args.append(self.visit(i)) - new_child = ast_internal_classes.Call_Expr_Node(type=child.type, subroutine=child.subroutine, - name=child.name, args=new_args, - line_number=child.line_number, parent=child.parent) + new_child = ast_internal_classes.Call_Expr_Node(type=child.type, + subroutine=child.subroutine, + name=child.name, + args=new_args, + line_number=child.line_number, + parent=child.parent) newbody.append(new_child) elif isinstance(child, ast_internal_classes.BinOp_Node): - if isinstance(child.lval,ast_internal_classes.Name_Node) and isinstance (child.rval, ast_internal_classes.Call_Expr_Node): + if 
isinstance(child.lval, ast_internal_classes.Name_Node) and isinstance( + child.rval, ast_internal_classes.Call_Expr_Node): new_args = [] for i in child.rval.args: new_args.append(self.visit(i)) - new_child = ast_internal_classes.Call_Expr_Node(type=child.rval.type, subroutine=child.rval.subroutine, - name=child.rval.name, args=new_args, - line_number=child.rval.line_number, parent=child.rval.parent) - newbody.append(ast_internal_classes.BinOp_Node(op=child.op, - lval=child.lval, - rval=new_child, line_number=child.line_number,parent=child.parent)) + new_child = ast_internal_classes.Call_Expr_Node(type=child.rval.type, + subroutine=child.rval.subroutine, + name=child.rval.name, + args=new_args, + line_number=child.rval.line_number, + parent=child.rval.parent) + newbody.append( + ast_internal_classes.BinOp_Node(op=child.op, + lval=child.lval, + rval=new_child, + line_number=child.line_number, + parent=child.parent)) else: - newbody.append(self.visit(child)) + newbody.append(self.visit(child)) else: newbody.append(self.visit(child)) - oldbody = newbody + oldbody = newbody return ast_internal_classes.Execution_Part_Node(execution=newbody) @@ -1188,10 +1215,8 @@ def __init__(self): def visit(self, node: ast_internal_classes.FNode, parent_node: Optional[ast_internal_classes.FNode] = None): parent_node_types = [ - ast_internal_classes.Subroutine_Subprogram_Node, - ast_internal_classes.Function_Subprogram_Node, - ast_internal_classes.Main_Program_Node, - ast_internal_classes.Module_Node + ast_internal_classes.Subroutine_Subprogram_Node, ast_internal_classes.Function_Subprogram_Node, + ast_internal_classes.Main_Program_Node, ast_internal_classes.Module_Node ] if parent_node is not None and type(parent_node) in parent_node_types: @@ -1383,7 +1408,9 @@ def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_ tmp = tmp + 1 self.count = tmp - return ast_internal_classes.Array_Subscript_Node(name=node.name, type=node.type, indices=newer_indices, + return 
ast_internal_classes.Array_Subscript_Node(name=node.name, + type=node.type, + indices=newer_indices, line_number=node.line_number) def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): @@ -1413,7 +1440,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No init=None, line_number=child.line_number) ], - line_number=child.line_number)) + line_number=child.line_number)) if self.normalize_offsets: # Find the offset of a variable to which we are assigning @@ -1423,8 +1450,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No variable = self.scope_vars.get_var(child.parent, var_name) elif parent_node is not None: struct, variable = self.structures.find_definition( - self.scope_vars, parent_node, j.name - ) + self.scope_vars, parent_node, j.name) var_name = j.name.name else: var_name = j.name.name @@ -1433,8 +1459,8 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No # it can be a symbol - Name_Node - or a value - - if not isinstance(offset, ast_internal_classes.Name_Node) and not isinstance(offset,ast_internal_classes.BinOp_Node): + if not isinstance(offset, ast_internal_classes.Name_Node) and not isinstance( + offset, ast_internal_classes.BinOp_Node): #check if offset is a number try: offset = int(offset) @@ -1442,15 +1468,15 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No raise ValueError(f"Offset {offset} is not a number") offset = ast_internal_classes.Int_Literal_Node(value=str(offset)) newbody.append( - ast_internal_classes.BinOp_Node( - op="=", - lval=ast_internal_classes.Name_Node(name=tmp_name), - rval=ast_internal_classes.BinOp_Node( - op="-", - lval=self.replacements[tmp_name][0], - rval=offset, - line_number=child.line_number,parent=child.parent), - line_number=child.line_number)) + ast_internal_classes.BinOp_Node(op="=", + lval=ast_internal_classes.Name_Node(name=tmp_name), + 
rval=ast_internal_classes.BinOp_Node( + op="-", + lval=self.replacements[tmp_name][0], + rval=offset, + line_number=child.line_number, + parent=child.parent), + line_number=child.line_number)) else: newbody.append( ast_internal_classes.BinOp_Node( @@ -1460,8 +1486,10 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No op="-", lval=self.replacements[tmp_name][0], rval=ast_internal_classes.Int_Literal_Node(value="1"), - line_number=child.line_number,parent=child.parent), - line_number=child.line_number,parent=child.parent)) + line_number=child.line_number, + parent=child.parent), + line_number=child.line_number, + parent=child.parent)) newbody.append(tmp_child) return ast_internal_classes.Execution_Part_Node(execution=newbody) @@ -1478,7 +1506,8 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): cond = ast_internal_classes.BinOp_Node(op=">=", rval=ast_internal_classes.Real_Literal_Node(value="0.0"), lval=args[1], - line_number=node.line_number,parent=node.parent) + line_number=node.line_number, + parent=node.parent) body_if = ast_internal_classes.Execution_Part_Node(execution=[ ast_internal_classes.BinOp_Node(lval=copy.deepcopy(lval), op="=", @@ -1486,10 +1515,12 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): name=ast_internal_classes.Name_Node(name="abs"), type="DOUBLE", args=[copy.deepcopy(args[0])], - line_number=node.line_number,parent=node.parent, - subroutine=False,), - - line_number=node.line_number,parent=node.parent) + line_number=node.line_number, + parent=node.parent, + subroutine=False, + ), + line_number=node.line_number, + parent=node.parent) ]) body_else = ast_internal_classes.Execution_Part_Node(execution=[ ast_internal_classes.BinOp_Node(lval=copy.deepcopy(lval), @@ -1502,14 +1533,18 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): args=[copy.deepcopy(args[0])], type="DOUBLE", subroutine=False, - line_number=node.line_number,parent=node.parent), - 
line_number=node.line_number,parent=node.parent), - line_number=node.line_number,parent=node.parent) + line_number=node.line_number, + parent=node.parent), + line_number=node.line_number, + parent=node.parent), + line_number=node.line_number, + parent=node.parent) ]) return (ast_internal_classes.If_Stmt_Node(cond=cond, body=body_if, body_else=body_else, - line_number=node.line_number,parent=node.parent)) + line_number=node.line_number, + parent=node.parent)) else: return self.generic_visit(node) @@ -1623,12 +1658,7 @@ def optionalArgsHandleFunction(func): specifiers = [] for i in func.specification_part.specifications: specifiers.append(i) - specifiers.append( - ast_internal_classes.Decl_Stmt_Node( - vardecl=vardecls, - line_number=func.line_number - ) - ) + specifiers.append(ast_internal_classes.Decl_Stmt_Node(vardecl=vardecls, line_number=func.line_number)) func.specification_part.specifications.clear() func.specification_part.specifications = specifiers @@ -1636,6 +1666,7 @@ def optionalArgsHandleFunction(func): class OptionalArgsTransformer(NodeTransformer): + def __init__(self, funcs_with_opt_args): self.funcs_with_opt_args = funcs_with_opt_args @@ -1730,6 +1761,7 @@ def optionalArgsExpander(node=ast_internal_classes.Program_Node): return node + class AllocatableFunctionLister(NodeVisitor): def __init__(self): @@ -1760,12 +1792,13 @@ def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine if len(vars) > 0: self.functions[node.name.name] = vars + class AllocatableReplacerVisitor(NodeVisitor): def __init__(self, functions_with_alloc): self.allocate_var_names = [] self.deallocate_var_names = [] - self.call_nodes = [] + self.call_nodes = [] self.functions_with_alloc = functions_with_alloc def visit_Allocate_Stmt_Node(self, node: ast_internal_classes.Allocate_Stmt_Node): @@ -1783,6 +1816,7 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): for node.name.name in self.functions_with_alloc: 
self.call_nodes.append(node) + class AllocatableReplacerTransformer(NodeTransformer): def __init__(self, functions_with_alloc: Dict[str, List[str]]): @@ -1801,27 +1835,21 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No name = f'__f2dace_ALLOCATED_{alloc_node}' newbody.append( - ast_internal_classes.BinOp_Node( - op="=", - lval=ast_internal_classes.Name_Node(name=name), - rval=ast_internal_classes.Int_Literal_Node(value="1"), - line_number=child.line_number, - parent=child.parent - ) - ) + ast_internal_classes.BinOp_Node(op="=", + lval=ast_internal_classes.Name_Node(name=name), + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=child.line_number, + parent=child.parent)) for dealloc_node in lister.deallocate_var_names: name = f'__f2dace_ALLOCATED_{dealloc_node}' newbody.append( - ast_internal_classes.BinOp_Node( - op="=", - lval=ast_internal_classes.Name_Node(name=name), - rval=ast_internal_classes.Int_Literal_Node(value="0"), - line_number=child.line_number, - parent=child.parent - ) - ) + ast_internal_classes.BinOp_Node(op="=", + lval=ast_internal_classes.Name_Node(name=name), + rval=ast_internal_classes.Int_Literal_Node(value="0"), + line_number=child.line_number, + parent=child.parent)) for call_node in lister.call_nodes: @@ -1829,15 +1857,12 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No for alloc_name in alloc_nodes: name = f'__f2dace_ALLOCATED_{alloc_name}' - call_node.args.append( - ast_internal_classes.Name_Node(name=name) - ) + call_node.args.append(ast_internal_classes.Name_Node(name=name)) newbody.append(child) return ast_internal_classes.Execution_Part_Node(execution=newbody) - def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node): node.execution_part = self.visit(node.execution_part) @@ -1859,23 +1884,20 @@ def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine init = 
ast_internal_classes.Int_Literal_Node(value="0") # if it's an arg, then we don't initialize - if node.name.name in self.functions_with_alloc and var_decl.name in self.functions_with_alloc[node.name.name]: + if node.name.name in self.functions_with_alloc and var_decl.name in self.functions_with_alloc[ + node.name.name]: init = None - args.append( - ast_internal_classes.Name_Node(name=name) - ) - - var = ast_internal_classes.Var_Decl_Node( - name=name, - type='LOGICAL', - alloc=False, - sizes=None, - offsets=None, - kind=None, - optional=False, - init=init, - line_number=var_decl.line_number - ) + args.append(ast_internal_classes.Name_Node(name=name)) + + var = ast_internal_classes.Var_Decl_Node(name=name, + type='LOGICAL', + alloc=False, + sizes=None, + offsets=None, + kind=None, + optional=False, + init=init, + line_number=var_decl.line_number) newdecls.append(var) if len(newdecls) > 0: @@ -1884,12 +1906,11 @@ def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine if len(newspec) > 0: node.specification_part.specifications.append(*newspec) - return ast_internal_classes.Subroutine_Subprogram_Node( - name=node.name, - args=args, - specification_part=node.specification_part, - execution_part=node.execution_part - ) + return ast_internal_classes.Subroutine_Subprogram_Node(name=node.name, + args=args, + specification_part=node.specification_part, + execution_part=node.execution_part) + def allocatableReplacer(node=ast_internal_classes.Program_Node): @@ -1898,6 +1919,7 @@ def allocatableReplacer(node=ast_internal_classes.Program_Node): return AllocatableReplacerTransformer(visitor.functions).visit(node) + def functionStatementEliminator(node=ast_internal_classes.Program_Node): """ Eliminates function statements from the AST @@ -2048,8 +2070,7 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, structures: Structures, declaration=True, main_iterator_ranges: Optional[list] = None, - allow_scalars = False - ): + 
allow_scalars=False): """ Helper function for the transformation of array operations and sums to loops :param node: The AST to be transformed @@ -2102,9 +2123,10 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, dims = len(array_sizes) node = ast_internal_classes.Array_Subscript_Node( - name=cur_node, parent=node.parent, type=var_def.type, - indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * dims - ) + name=cur_node, + parent=node.parent, + type=var_def.type, + indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * dims) break @@ -2122,7 +2144,8 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, lower_boundary = None if offsets[idx] != 1: # support symbols and integer literals - if isinstance(offsets[idx], ast_internal_classes.Name_Node) or isinstance(offset,ast_internal_classes.BinOp_Node): + if isinstance(offsets[idx], ast_internal_classes.Name_Node) or isinstance( + offset, ast_internal_classes.BinOp_Node): lower_boundary = offsets[idx] else: #check if offset is a number @@ -2148,7 +2171,8 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, upper_boundary = ast_internal_classes.Name_Range_Node(name="f2dace_MAX", type="INTEGER", arrname=ast_internal_classes.Name_Node( - name=array_name, type="VOID", + name=array_name, + type="VOID", line_number=node.line_number), pos=idx) """ @@ -2158,7 +2182,8 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, if offsets[idx] != 1: # support symbols and integer literals - if isinstance(offsets[idx], ast_internal_classes.Name_Node) or isinstance(offset,ast_internal_classes.BinOp_Node): + if isinstance(offsets[idx], ast_internal_classes.Name_Node) or isinstance( + offset, ast_internal_classes.BinOp_Node): offset = offsets[idx] else: try: @@ -2167,16 +2192,9 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, raise ValueError(f"Offset {offsets[idx]} is not a number") offset = 
ast_internal_classes.Int_Literal_Node(value=str(offset_value)) + upper_boundary = ast_internal_classes.BinOp_Node(lval=upper_boundary, op="+", rval=offset) upper_boundary = ast_internal_classes.BinOp_Node( - lval=upper_boundary, - op="+", - rval=offset - ) - upper_boundary = ast_internal_classes.BinOp_Node( - lval=upper_boundary, - op="-", - rval=ast_internal_classes.Int_Literal_Node(value="1") - ) + lval=upper_boundary, op="-", rval=ast_internal_classes.Int_Literal_Node(value="1")) ranges.append([lower_boundary, upper_boundary]) rangeslen.append(-1) @@ -2200,26 +2218,23 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, if isinstance(end, int) and isinstance(start, int): rangeslen.append(end - start + 1) else: - add = ast_internal_classes.BinOp_Node( - lval=start, - op="+", - rval=ast_internal_classes.Int_Literal_Node(value="1") - ) - substr = ast_internal_classes.BinOp_Node( - lval=end, - op="-", - rval=add - ) + add = ast_internal_classes.BinOp_Node(lval=start, + op="+", + rval=ast_internal_classes.Int_Literal_Node(value="1")) + substr = ast_internal_classes.BinOp_Node(lval=end, op="-", rval=add) rangeslen.append(substr) rangepos.append(currentindex) if declaration: newbody.append( ast_internal_classes.Decl_Stmt_Node(vardecl=[ - ast_internal_classes.Symbol_Decl_Node( - name="tmp_parfor_" + str(count + len(rangepos) - 1), type="INTEGER", sizes=None, init=None,parent=node.parent, line_number=node.line_number) + ast_internal_classes.Symbol_Decl_Node(name="tmp_parfor_" + str(count + len(rangepos) - 1), + type="INTEGER", + sizes=None, + init=None, + parent=node.parent, + line_number=node.line_number) ])) - """ To account for ranges with different starting offsets inside the same loop, we need to adapt array accesses. @@ -2234,11 +2249,8 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, For LHS, we don't need to adjust - we dictate the loop iterator. 
""" - indices.append( - ast_internal_classes.Name_Node(name="tmp_parfor_" + str(count + len(rangepos) - 1)) - ) + indices.append(ast_internal_classes.Name_Node(name="tmp_parfor_" + str(count + len(rangepos) - 1))) else: - """ For RHS, we adjust starting array position by taking consideration the initial value of the loop iterator. @@ -2248,16 +2260,14 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, current_lower_boundary = main_iterator_ranges[currentindex][0] indices.append( - ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(count + len(rangepos) - 1)), - op="+", - rval=ast_internal_classes.BinOp_Node( - lval=lower_boundary, - op="-", - rval=current_lower_boundary,parent=node.parent - ),parent=node.parent - ) - ) + ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + + str(count + len(rangepos) - 1)), + op="+", + rval=ast_internal_classes.BinOp_Node(lval=lower_boundary, + op="-", + rval=current_lower_boundary, + parent=node.parent), + parent=node.parent)) currentindex += 1 elif allow_scalars: @@ -2294,92 +2304,100 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No res_range = lister.range_nodes #Transpose breaks Array to loop transformation, and fixing it is not trivial - and will likely not involve array to loop at all. 
- calls=[i for i in mywalk(child) if isinstance(i, ast_internal_classes.Call_Expr_Node)] + calls = [i for i in mywalk(child) if isinstance(i, ast_internal_classes.Call_Expr_Node)] skip_because_of_transpose = False for i in calls: if "__dace_transpose" in i.name.name.lower(): skip_because_of_transpose = True if skip_because_of_transpose: - newbody.append(child) - continue + newbody.append(child) + continue try: - if res is not None and len(res) > 0: - - current = child.lval - ranges = [] - par_Decl_Range_Finder(current, ranges, [], self.count, newbody, self.scope_vars, - self.ast.structures, True) - - # if res_range is not None and len(res_range) > 0: - - # catch cases where an array is used as name, without range expression - visitor = ReplaceImplicitParDecls(self.scope_vars) - child.rval = visitor.visit(child.rval) - - rvals = [i for i in mywalk(child.rval) if isinstance(i, ast_internal_classes.Array_Subscript_Node)] - for i in rvals: - rangesrval = [] - - par_Decl_Range_Finder(i, rangesrval, [], self.count, newbody, self.scope_vars, - self.ast.structures, False, ranges) - for i, j in zip(ranges, rangesrval): - if i != j: - if isinstance(i, list) and isinstance(j, list) and len(i) == len(j): - for k, l in zip(i, j): - if k != l: - if isinstance(k, ast_internal_classes.Name_Range_Node) and isinstance( - l, ast_internal_classes.Name_Range_Node): - if k.name != l.name: - raise NotImplementedError("Ranges must be the same") - else: - # this is not actually illegal. 
- # raise NotImplementedError("Ranges must be the same") - continue - else: - raise NotImplementedError("Ranges must be identical") - - range_index = 0 - body = ast_internal_classes.BinOp_Node(lval=current, op="=", rval=child.rval, - line_number=child.line_number,parent=child.parent) - - for i in ranges: - initrange = i[0] - finalrange = i[1] - init = ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), - op="=", - rval=initrange, - line_number=child.line_number,parent=child.parent) - cond = ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), - op="<=", - rval=finalrange, - line_number=child.line_number,parent=child.parent) - iter = ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), - op="=", - rval=ast_internal_classes.BinOp_Node( + if res is not None and len(res) > 0: + + current = child.lval + ranges = [] + par_Decl_Range_Finder(current, ranges, [], self.count, newbody, self.scope_vars, + self.ast.structures, True) + + # if res_range is not None and len(res_range) > 0: + + # catch cases where an array is used as name, without range expression + visitor = ReplaceImplicitParDecls(self.scope_vars) + child.rval = visitor.visit(child.rval) + + rvals = [i for i in mywalk(child.rval) if isinstance(i, ast_internal_classes.Array_Subscript_Node)] + for i in rvals: + rangesrval = [] + + par_Decl_Range_Finder(i, rangesrval, [], self.count, newbody, self.scope_vars, + self.ast.structures, False, ranges) + for i, j in zip(ranges, rangesrval): + if i != j: + if isinstance(i, list) and isinstance(j, list) and len(i) == len(j): + for k, l in zip(i, j): + if k != l: + if isinstance(k, ast_internal_classes.Name_Range_Node) and isinstance( + l, ast_internal_classes.Name_Range_Node): + if k.name != l.name: + raise NotImplementedError("Ranges must be the same") + else: + 
# this is not actually illegal. + # raise NotImplementedError("Ranges must be the same") + continue + else: + raise NotImplementedError("Ranges must be identical") + + range_index = 0 + body = ast_internal_classes.BinOp_Node(lval=current, + op="=", + rval=child.rval, + line_number=child.line_number, + parent=child.parent) + + for i in ranges: + initrange = i[0] + finalrange = i[1] + init = ast_internal_classes.BinOp_Node( lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), - op="+", - rval=ast_internal_classes.Int_Literal_Node(value="1"),parent=child.parent), - line_number=child.line_number,parent=child.parent) - current_for = ast_internal_classes.Map_Stmt_Node( - init=init, - cond=cond, - iter=iter, - body=ast_internal_classes.Execution_Part_Node(execution=[body]), - line_number=child.line_number,parent=child.parent) - body = current_for - range_index += 1 - - newbody.append(body) - - self.count = self.count + range_index - else: - newbody.append(self.visit(child)) + op="=", + rval=initrange, + line_number=child.line_number, + parent=child.parent) + cond = ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), + op="<=", + rval=finalrange, + line_number=child.line_number, + parent=child.parent) + iter = ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), + op="=", + rval=ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), + op="+", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + parent=child.parent), + line_number=child.line_number, + parent=child.parent) + current_for = ast_internal_classes.Map_Stmt_Node( + init=init, + cond=cond, + iter=iter, + body=ast_internal_classes.Execution_Part_Node(execution=[body]), + line_number=child.line_number, + parent=child.parent) + body = current_for + range_index += 1 + + 
newbody.append(body) + + self.count = self.count + range_index + else: + newbody.append(self.visit(child)) except Exception as e: - print("Error in ArrayToLoop, exception caught at line: "+str(child.line_number)) - newbody.append(child) + print("Error in ArrayToLoop, exception caught at line: " + str(child.line_number)) + newbody.append(child) return ast_internal_classes.Execution_Part_Node(execution=newbody) @@ -2398,6 +2416,7 @@ def mywalk(node): class RenameVar(NodeTransformer): + def __init__(self, oldname: str, newname: str): self.oldname = oldname self.newname = newname @@ -2407,6 +2426,7 @@ def visit_Name_Node(self, node: ast_internal_classes.Name_Node): class PartialRenameVar(NodeTransformer): + def __init__(self, oldname: str, newname: str): self.oldname = oldname self.newname = newname @@ -2423,9 +2443,14 @@ def visit_Name_Node(self, node: ast_internal_classes.Name_Node): type="VOID") if self.oldname in node.name else node def visit_Symbol_Decl_Node(self, node: ast_internal_classes.Symbol_Decl_Node): - return ast_internal_classes.Symbol_Decl_Node(name=node.name.replace(self.oldname, self.newname), type=node.type, - sizes=node.sizes, init=node.init, line_number=node.line_number, - kind=node.kind, alloc=node.alloc, offsets=node.offsets) + return ast_internal_classes.Symbol_Decl_Node(name=node.name.replace(self.oldname, self.newname), + type=node.type, + sizes=node.sizes, + init=node.init, + line_number=node.line_number, + kind=node.kind, + alloc=node.alloc, + offsets=node.offsets) class IfConditionExtractor(NodeTransformer): @@ -2439,7 +2464,7 @@ def __init__(self): def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): newbody = [] for child in node.execution: - + if isinstance(child, ast_internal_classes.If_Stmt_Node): old_cond = child.cond newbody.append( @@ -2447,29 +2472,35 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No ast_internal_classes.Var_Decl_Node( name="_if_cond_" + 
str(self.count), type="INTEGER", sizes=None, init=None) ])) - newbody.append(ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="_if_cond_" + str(self.count)), - op="=", - rval=old_cond, - line_number=child.line_number, - parent=child.parent)) - newcond = ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name="_if_cond_" + str(self.count)), - op="==", - rval=ast_internal_classes.Int_Literal_Node(value="1"), - line_number=child.line_number,parent=old_cond.parent) + newbody.append( + ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name="_if_cond_" + + str(self.count)), + op="=", + rval=old_cond, + line_number=child.line_number, + parent=child.parent)) + newcond = ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name="_if_cond_" + + str(self.count)), + op="==", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=child.line_number, + parent=old_cond.parent) newifbody = self.visit(child.body) newelsebody = self.visit(child.body_else) - - newif = ast_internal_classes.If_Stmt_Node(cond=newcond, body=newifbody, body_else=newelsebody, - line_number=child.line_number, parent=child.parent) + + newif = ast_internal_classes.If_Stmt_Node(cond=newcond, + body=newifbody, + body_else=newelsebody, + line_number=child.line_number, + parent=child.parent) self.count += 1 - + newbody.append(newif) else: newbody.append(self.visit(child)) return ast_internal_classes.Execution_Part_Node(execution=newbody) - + class ForDeclarer(NodeTransformer): """ @@ -2494,15 +2525,20 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No final_assign = ast_internal_classes.BinOp_Node(lval=child.init.lval, op="=", rval=child.cond.rval, - line_number=child.line_number,parent=child.parent) + line_number=child.line_number, + parent=child.parent) newfbody = RenameVar(child.init.lval.name, "_for_it_" + str(self.count)).visit(child.body) newcond = RenameVar(child.cond.lval.name, 
"_for_it_" + str(self.count)).visit(child.cond) newiter = RenameVar(child.iter.lval.name, "_for_it_" + str(self.count)).visit(child.iter) newinit = child.init newinit.lval = RenameVar(child.init.lval.name, "_for_it_" + str(self.count)).visit(child.init.lval) - newfor = ast_internal_classes.For_Stmt_Node(init=newinit, cond=newcond, iter=newiter, body=newfbody, - line_number=child.line_number, parent=child.parent) + newfor = ast_internal_classes.For_Stmt_Node(init=newinit, + cond=newcond, + iter=newiter, + body=newfbody, + line_number=child.line_number, + parent=child.parent) self.count += 1 newfor = self.visit(newfor) newbody.append(newfor) @@ -2547,10 +2583,12 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No shape = ["10"] for i in child.args: if isinstance(i, ast_internal_classes.Name_Node): - newargs.append(ast_internal_classes.Array_Subscript_Node(name=i, indices=[ - ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count))], - line_number=child.line_number, - type=i.type)) + newargs.append( + ast_internal_classes.Array_Subscript_Node( + name=i, + indices=[ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count))], + line_number=child.line_number, + type=i.type)) if i.name.startswith("tmp_call_"): for j in newbody: if isinstance(j, ast_internal_classes.Decl_Stmt_Node): @@ -2560,35 +2598,40 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No else: raise NotImplementedError("Only name nodes are supported") - newbody.append(ast_internal_classes.For_Stmt_Node( - init=ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), - op="=", - rval=ast_internal_classes.Int_Literal_Node(value="1"), - line_number=child.line_number,parent=child.parent), - cond=ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), - op="<=", - 
rval=ast_internal_classes.Int_Literal_Node(value=shape[0]), - line_number=child.line_number,parent=child.parent), - body=ast_internal_classes.Execution_Part_Node(execution=[ - ast_internal_classes.Call_Expr_Node(type=child.type, - name=child.name, - args=newargs, - line_number=child.line_number,parent=child.parent) - ]), line_number=child.line_number, - iter=ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), - op="=", - rval=ast_internal_classes.BinOp_Node( + newbody.append( + ast_internal_classes.For_Stmt_Node( + init=ast_internal_classes.BinOp_Node( lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), - op="+", - rval=ast_internal_classes.Int_Literal_Node(value="1"),parent=child.parent), - line_number=child.line_number,parent=child.parent) - )) + op="=", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=child.line_number, + parent=child.parent), + cond=ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), + op="<=", + rval=ast_internal_classes.Int_Literal_Node(value=shape[0]), + line_number=child.line_number, + parent=child.parent), + body=ast_internal_classes.Execution_Part_Node(execution=[ + ast_internal_classes.Call_Expr_Node(type=child.type, + name=child.name, + args=newargs, + line_number=child.line_number, + parent=child.parent) + ]), + line_number=child.line_number, + iter=ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), + op="=", + rval=ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="_for_elem_it_" + str(self.count)), + op="+", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + parent=child.parent), + line_number=child.line_number, + parent=child.parent))) self.count += 1 - else: newbody.append(self.visit(child)) return ast_internal_classes.Execution_Part_Node(execution=newbody) @@ -2598,7 
+2641,7 @@ class TypeInference(NodeTransformer): """ """ - def __init__(self, ast, assert_voids=True, assign_scopes=True, scope_vars = None): + def __init__(self, ast, assert_voids=True, assign_scopes=True, scope_vars=None): self.assert_voids = assert_voids self.ast = ast @@ -2643,7 +2686,6 @@ def visit_Parenthesis_Expr_Node(self, node: ast_internal_classes.Parenthesis_Exp return node def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): - """ Simple implementation of type promotion in binary ops. """ @@ -2651,14 +2693,7 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): node.lval = self.visit(node.lval) node.rval = self.visit(node.rval) - type_hierarchy = [ - 'VOID', - 'LOGICAL', - 'CHAR', - 'INTEGER', - 'REAL', - 'DOUBLE' - ] + type_hierarchy = ['VOID', 'LOGICAL', 'CHAR', 'INTEGER', 'REAL', 'DOUBLE'] idx_left = type_hierarchy.index(self._get_type(node.lval)) idx_right = type_hierarchy.index(self._get_type(node.rval)) @@ -2691,9 +2726,7 @@ def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node): if node.type != 'VOID': return node - struct, variable = self.structures.find_definition( - self.scope_vars, node - ) + struct, variable = self.structures.find_definition(self.scope_vars, node) if variable.type != 'VOID': node.type = variable.type node.dims = len(variable.sizes) if variable.sizes is not None else 1 @@ -2853,18 +2886,12 @@ def visit_Array_Subscript_Node(self, node: ast_internal_classes.Array_Subscript_ original_ref_node = self.nodes[node.name.name] cur_ref_node = original_ref_node - new_ref_node = ast_internal_classes.Data_Ref_Node( - parent_ref=cur_ref_node.parent_ref, - part_ref=None - ) + new_ref_node = ast_internal_classes.Data_Ref_Node(parent_ref=cur_ref_node.parent_ref, part_ref=None) newer_ref_node = new_ref_node while isinstance(cur_ref_node.part_ref, ast_internal_classes.Data_Ref_Node): cur_ref_node = cur_ref_node.part_ref - newest_ref_node = ast_internal_classes.Data_Ref_Node( - 
parent_ref=cur_ref_node.parent_ref, - part_ref=None - ) + newest_ref_node = ast_internal_classes.Data_Ref_Node(parent_ref=cur_ref_node.parent_ref, part_ref=None) newer_ref_node.part_ref = newest_ref_node newer_ref_node = newest_ref_node @@ -2904,13 +2931,11 @@ def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine else: specification_part = node.specification_part - return ast_internal_classes.Subroutine_Subprogram_Node( - name=node.name, - args=node.args, - specification_part=specification_part, - execution_part=execution_part, - line_number=node.line_number - ) + return ast_internal_classes.Subroutine_Subprogram_Node(name=node.name, + args=node.args, + specification_part=specification_part, + execution_part=execution_part, + line_number=node.line_number) def visit_Specification_Part_Node(self, node: ast_internal_classes.Specification_Part_Node): @@ -2948,13 +2973,11 @@ def visit_Specification_Part_Node(self, node: ast_internal_classes.Specification else: new_symbols = None - return ast_internal_classes.Specification_Part_Node( - specifications=newspec, - symbols=new_symbols, - typedecls=node.typedecls, - uses=node.uses, - enums=node.enums - ) + return ast_internal_classes.Specification_Part_Node(specifications=newspec, + symbols=new_symbols, + typedecls=node.typedecls, + uses=node.uses, + enums=node.enums) class ArgumentPruner(NodeVisitor): @@ -3132,6 +3155,7 @@ def visit_Name_Node(self, node: ast_internal_classes.Name_Node): class IfEvaluator(NodeTransformer): + def __init__(self): self.replacements = 0 @@ -3161,6 +3185,7 @@ def visit_If_Stmt_Node(self, node): class AssignmentLister(NodeTransformer): + def __init__(self, correction=[]): self.simple_assignments = [] self.correction = correction @@ -3179,13 +3204,16 @@ def visit_BinOp_Node(self, node): class AssignmentPropagator(NodeTransformer): + def __init__(self, simple_assignments): self.simple_assignments = simple_assignments self.replacements = 0 def visit_If_Stmt_Node(self, node): 
test = self.generic_visit(node) - return ast_internal_classes.If_Stmt_Node(line_number=node.line_number, cond=test.cond, body=test.body, + return ast_internal_classes.If_Stmt_Node(line_number=node.line_number, + cond=test.cond, + body=test.body, body_else=test.body_else) def generic_visit(self, node: ast_internal_classes.FNode): @@ -3243,6 +3271,7 @@ def generic_visit(self, node: ast_internal_classes.FNode): class getCalls(NodeVisitor): + def __init__(self): self.calls = [] @@ -3254,6 +3283,7 @@ def visit_Call_Expr_Node(self, node): class FindUnusedFunctions(NodeVisitor): + def __init__(self, root, parse_order): self.root = root self.parse_order = parse_order @@ -3284,16 +3314,15 @@ def visit_Name_Node(self, node: ast_internal_classes.Name_Node): if var.sizes is not None: indices = [ast_internal_classes.ParDecl_Node(type='ALL')] * len(var.sizes) - return ast_internal_classes.Array_Subscript_Node( - name=node, - type=var.type, - parent=node.parent, - indices=indices, - line_number=node.line_number - ) + return ast_internal_classes.Array_Subscript_Node(name=node, + type=var.type, + parent=node.parent, + indices=indices, + line_number=node.line_number) else: return node + class ReplaceStructArgsLibraryNodesVisitor(NodeVisitor): """ Finds all intrinsic operations that have to be transformed to loops in the AST @@ -3302,10 +3331,7 @@ class ReplaceStructArgsLibraryNodesVisitor(NodeVisitor): def __init__(self): self.nodes: List[ast_internal_classes.FNode] = [] - self.FUNCS_TO_REPLACE = [ - "transpose", - "matmul" - ] + self.FUNCS_TO_REPLACE = ["transpose", "matmul"] def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): @@ -3316,6 +3342,7 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): return + class ReplaceStructArgsLibraryNodes(NodeTransformer): def __init__(self, ast): @@ -3328,10 +3355,7 @@ def __init__(self, ast): self.counter = 0 - 
FUNCS_TO_REPLACE = [ - "transpose", - "matmul" - ] + FUNCS_TO_REPLACE = ["transpose", "matmul"] # FIXME: copy-paste from intrinsics def _parse_struct_ref(self, node: ast_internal_classes.Data_Ref_Node) -> ast_internal_classes.FNode: @@ -3389,38 +3413,29 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No node.parent.specification_part.specifications.append( ast_internal_classes.Decl_Stmt_Node(vardecl=[ ast_internal_classes.Var_Decl_Node( - name=tmp_var_name, - type=var.type, - sizes=var.sizes, - offsets=var.offsets, - init=None - ) - ]) - ) + name=tmp_var_name, type=var.type, sizes=var.sizes, offsets=var.offsets, init=None) + ])) dest_node = ast_internal_classes.Array_Subscript_Node( name=ast_internal_classes.Name_Node(name=tmp_var_name), - parent=call_node.parent, type=var.type, - indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * len(var.sizes) - ) + parent=call_node.parent, + type=var.type, + indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * len(var.sizes)) if isinstance(arg.part_ref, ast_internal_classes.Name_Node): arg.part_ref = ast_internal_classes.Array_Subscript_Node( name=arg.part_ref, - parent=call_node.parent, type=arg.part_ref.type, - indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * len(var.sizes) - ) + parent=call_node.parent, + type=arg.part_ref.type, + indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * len(var.sizes)) newbody.append( - ast_internal_classes.BinOp_Node( - op="=", - lval=dest_node, - rval=arg, - line_number=child.line_number, - parent=child.parent - ) - ) + ast_internal_classes.BinOp_Node(op="=", + lval=dest_node, + rval=arg, + line_number=child.line_number, + parent=child.parent)) self.counter += 1 @@ -3434,4 +3449,3 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No newbody.append(child) return ast_internal_classes.Execution_Part_Node(execution=newbody) - diff --git a/dace/frontend/fortran/ast_utils.py b/dace/frontend/fortran/ast_utils.py 
index 5de91f71bb..5251fc10bc 100644 --- a/dace/frontend/fortran/ast_utils.py +++ b/dace/frontend/fortran/ast_utils.py @@ -101,7 +101,7 @@ def get_name(node: ast_internal_classes.FNode): if isinstance(node, ast_internal_classes.Actual_Arg_Spec_Node): actual_node = node.arg else: - actual_node = node + actual_node = node if isinstance(actual_node, ast_internal_classes.Name_Node): return actual_node.name elif isinstance(actual_node, ast_internal_classes.Array_Subscript_Node): @@ -142,8 +142,7 @@ def __init__(self, input_changes: List[str] = None, placeholders={}, placeholders_offsets={}, - rename_dict=None - ): + rename_dict=None): self.outputs = outputs self.outputs_changes = outputs_changes self.sdfg = sdfg @@ -182,7 +181,7 @@ def pardecl2string(self, node: ast_internal_classes.ParDecl_Node): def actualarg2string(self, node: ast_internal_classes.Actual_Arg_Spec_Node): return self.write_code(node.arg) - + def arrayconstructor2string(self, node: ast_internal_classes.Array_Constructor_Node): str_to_return = "[ " for i in node.value_list: @@ -255,8 +254,8 @@ def name2string(self, node): if sdfg_name is None: return name else: - if self.sdfg.arrays[sdfg_name].shape is None or ( - len(self.sdfg.arrays[sdfg_name].shape) == 1 and self.sdfg.arrays[sdfg_name].shape[0] == 1): + if self.sdfg.arrays[sdfg_name].shape is None or (len(self.sdfg.arrays[sdfg_name].shape) == 1 + and self.sdfg.arrays[sdfg_name].shape[0] == 1): return "1" size = self.sdfg.arrays[sdfg_name].shape[location[1]] return self.write_code(str(size)) @@ -267,8 +266,8 @@ def name2string(self, node): if sdfg_name is None: return name else: - if self.sdfg.arrays[sdfg_name].shape is None or ( - len(self.sdfg.arrays[sdfg_name].shape) == 1 and self.sdfg.arrays[sdfg_name].shape[0] == 1): + if self.sdfg.arrays[sdfg_name].shape is None or (len(self.sdfg.arrays[sdfg_name].shape) == 1 + and self.sdfg.arrays[sdfg_name].shape[0] == 1): return "0" offset = self.sdfg.arrays[sdfg_name].offset[location[1]] return 
self.write_code(str(offset)) @@ -406,7 +405,10 @@ def generate_memlet(op, top_sdfg, state, offset_normalization=False): if i.type == 'ALL': indices.append(None) else: - tw = TaskletWriter([], [], top_sdfg, state.name_mapping, placeholders=state.placeholders, + tw = TaskletWriter([], [], + top_sdfg, + state.name_mapping, + placeholders=state.placeholders, placeholders_offsets=state.placeholders_offsets) text_start = tw.write_code(i.range[0]) text_end = tw.write_code(i.range[1]) @@ -414,7 +416,10 @@ def generate_memlet(op, top_sdfg, state, offset_normalization=False): symb_end = sym.pystr_to_symbolic(text_end) indices.append([symb_start, symb_end]) else: - tw = TaskletWriter([], [], top_sdfg, state.name_mapping, placeholders=state.placeholders, + tw = TaskletWriter([], [], + top_sdfg, + state.name_mapping, + placeholders=state.placeholders, placeholders_offsets=state.placeholders_offsets) text = tw.write_code(i) # This might need to be replaced with the name in the context of the top/current sdfg @@ -426,8 +431,8 @@ def generate_memlet(op, top_sdfg, state, offset_normalization=False): all_indices = indices + [None] * (len(shape) - len(indices)) if offset_normalization: - subset = subsets.Range( - [(i[0], i[1], 1) if i is not None else (0, s - 1, 1) for i, s in zip(all_indices, shape)]) + subset = subsets.Range([(i[0], i[1], 1) if i is not None else (0, s - 1, 1) + for i, s in zip(all_indices, shape)]) else: subset = subsets.Range([(i[0], i[1], 1) if i is not None else (1, s, 1) for i, s in zip(all_indices, shape)]) return subset @@ -493,6 +498,7 @@ def namerange2string(self, node: ast_internal_classes.Name_Range_Node): class Context: + def __init__(self, name): self.name = name self.constants = {} @@ -503,6 +509,7 @@ def __init__(self, name): class NameMap(dict): + def __getitem__(self, k): assert isinstance(k, SDFG) if k not in self: @@ -519,6 +526,7 @@ def __setitem__(self, k, v) -> None: class ModuleMap(dict): + def __getitem__(self, k): assert isinstance(k, 
ast_internal_classes.Module_Node) if k not in self: @@ -535,6 +543,7 @@ def __setitem__(self, k, v) -> None: class FunctionSubroutineLister: + def __init__(self): self.list_of_functions = [] self.names_in_functions = {} @@ -562,7 +571,6 @@ def get_functions_and_subroutines(self, node: Base): self.names_in_types[name] += list_descendent_typenames(i) self.list_of_types.append(name) - elif isinstance(i, Function_Stmt): fn_name = singular(children_of_type(i, Name)).string self.names_in_functions[fn_name] = list_descendent_names(node) @@ -602,6 +610,7 @@ def get_functions_and_subroutines(self, node: Base): def list_descendent_typenames(node: Base) -> List[str]: + def _list_descendent_typenames(_node: Base, _list_of_names: List[str]) -> List[str]: for c in _node.children: if isinstance(c, Type_Name): @@ -615,6 +624,7 @@ def _list_descendent_typenames(_node: Base, _list_of_names: List[str]) -> List[s def list_descendent_names(node: Base) -> List[str]: + def _list_descendent_names(_node: Base, _list_of_names: List[str]) -> List[str]: for c in _node.children: if isinstance(c, Name): @@ -628,6 +638,7 @@ def _list_descendent_names(_node: Base, _list_of_names: List[str]) -> List[str]: def get_defined_modules(node: Base) -> List[str]: + def _get_defined_modules(_node: Base, _defined_modules: List[str]) -> List[str]: for m in _node.children: if isinstance(m, Module_Stmt): @@ -640,6 +651,7 @@ def _get_defined_modules(_node: Base, _defined_modules: List[str]) -> List[str]: class UseAllPruneList: + def __init__(self, module: str, identifiers: List[str]): """ Keeps a list of referenced identifiers to intersect with the identifiers available in the module. diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 87e5881175..36cc2bbf30 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -1,19 +1,21 @@ # Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
import copy -from dataclasses import dataclass import os import warnings from copy import deepcopy as dpcp +from dataclasses import dataclass from itertools import chain from pathlib import Path -from typing import List, Optional, Set, Dict, Tuple, Union +from typing import Dict, List, Optional, Set, Tuple, Union import networkx as nx -from fparser.common.readfortran import FortranFileReader as ffr, FortranStringReader, FortranFileReader +from fparser.common.readfortran import FortranFileReader as ffr +from fparser.common.readfortran import FortranStringReader from fparser.common.readfortran import FortranStringReader as fsr -from fparser.two.Fortran2003 import Program, Name, Subroutine_Subprogram, Module_Stmt -from fparser.two.parser import ParserFactory as pf, ParserFactory +from fparser.two.Fortran2003 import Module_Stmt, Name, Program +from fparser.two.parser import ParserFactory +from fparser.two.parser import ParserFactory as pf from fparser.two.symbol_table import SymbolTable from fparser.two.utils import Base, walk @@ -21,22 +23,24 @@ import dace.frontend.fortran.ast_internal_classes as ast_internal_classes import dace.frontend.fortran.ast_transforms as ast_transforms import dace.frontend.fortran.ast_utils as ast_utils +from dace import SDFG, InterstateEdge from dace import Language as lang -from dace import SDFG, InterstateEdge, Memlet, pointer, nodes, SDFGState +from dace import Memlet, SDFGState from dace import data as dat -from dace import dtypes +from dace import dtypes, nodes, pointer from dace import subsets as subs from dace import symbolic as sym from dace.data import Scalar, Structure -from dace.frontend.fortran.ast_desugaring import SPEC, ENTRY_POINT_OBJECT_TYPES, find_name_of_stmt, find_name_of_node, \ - identifier_specs, append_children, correct_for_function_calls, remove_access_statements, sort_modules, \ - deconstruct_enums, deconstruct_interface_calls, deconstruct_procedure_calls, prune_unused_objects, \ - deconstruct_associations, 
assign_globally_unique_subprogram_names, assign_globally_unique_variable_names, \ - consolidate_uses, prune_branches, const_eval_nodes, lower_identifier_names, \ - remove_access_statements, ident_spec, NAMED_STMTS_OF_INTEREST_TYPES +from dace.frontend.fortran.ast_desugaring import (ENTRY_POINT_OBJECT, ENTRY_POINT_OBJECT_TYPES, NAMED_STMTS_OF_INTEREST, + SPEC, append_children, consolidate_uses, const_eval_nodes, + correct_for_function_calls, deconstruct_associations, + deconstruct_enums, deconstruct_interface_calls, + deconstruct_procedure_calls, find_name_of_node, find_name_of_stmt, + ident_spec, identifier_specs, lower_identifier_names, prune_branches, + prune_unused_objects, remove_access_statements, sort_modules) from dace.frontend.fortran.ast_internal_classes import FNode, Main_Program_Node -from dace.frontend.fortran.ast_utils import UseAllPruneList, children_of_type -from dace.frontend.fortran.intrinsics import IntrinsicSDFGTransformation, NeedsTypeInferenceException +from dace.frontend.fortran.ast_utils import children_of_type +from dace.frontend.fortran.intrinsics import (IntrinsicSDFGTransformation, NeedsTypeInferenceException) from dace.properties import CodeBlock global_struct_instance_counter = 0 @@ -99,17 +103,19 @@ def add_views_recursive(sdfg, name, datatype_to_add, struct_views, name_mapping, sdfg.arrays[name_mapping[name] + join_chain + "_" + i] = view_to_member else: if sdfg.arrays.get(name_mapping[name] + join_chain + "_" + i) is None: - sdfg.add_view(name_mapping[name] + join_chain + "_" + i, datatype_to_add.members[i].shape, - datatype_to_add.members[i].dtype, strides=datatype_to_add.members[i].strides) + sdfg.add_view(name_mapping[name] + join_chain + "_" + i, + datatype_to_add.members[i].shape, + datatype_to_add.members[i].dtype, + strides=datatype_to_add.members[i].strides) if names_of_object_in_parent_sdfg.get(name_mapping[name]) is not None: - if actual_offsets_per_parent_sdfg.get( - names_of_object_in_parent_sdfg[name_mapping[name]] + 
join_chain + "_" + i) is not None: + if actual_offsets_per_parent_sdfg.get(names_of_object_in_parent_sdfg[name_mapping[name]] + join_chain + + "_" + i) is not None: actual_offsets_per_sdfg[name_mapping[name] + join_chain + "_" + i] = actual_offsets_per_parent_sdfg[ names_of_object_in_parent_sdfg[name_mapping[name]] + join_chain + "_" + i] else: # print("No offsets in sdfg: ",sdfg.name ," for: ",names_of_object_in_parent_sdfg[name_mapping[name]]+ join_chain + "_" + i) - actual_offsets_per_sdfg[name_mapping[name] + join_chain + "_" + i] = [1] * len( - datatype_to_add.members[i].shape) + actual_offsets_per_sdfg[name_mapping[name] + join_chain + "_" + + i] = [1] * len(datatype_to_add.members[i].shape) name_mapping[name_mapping[name] + join_chain + "_" + i] = name_mapping[name] + join_chain + "_" + i struct_views[name_mapping[name] + join_chain + "_" + i] = [name_mapping[name]] + chain + [i] @@ -226,12 +232,13 @@ def add_deferred_shape_assigns_for_structs(structures: ast_transforms.Structures # view=sdfg.arrays[viewname] strides = [dat._prod(shapelist[:i]) for i in range(len(shapelist))] if isinstance(object.members[ast_struct_type.name], dat.ContainerArray): - tmpobject = dat.ContainerArray(object.members[ast_struct_type.name].stype, shape_replace, + tmpobject = dat.ContainerArray(object.members[ast_struct_type.name].stype, + shape_replace, strides=strides) - elif isinstance(object.members[ast_struct_type.name], dat.Array): - tmpobject = dat.Array(object.members[ast_struct_type.name].dtype, shape_replace, + tmpobject = dat.Array(object.members[ast_struct_type.name].dtype, + shape_replace, strides=strides) else: @@ -251,9 +258,15 @@ class AST_translator: This class is responsible for translating the internal AST into a SDFG. 
""" - def __init__(self, source: str, multiple_sdfgs: bool = False, startpoint=None, sdfg_path=None, - toplevel_subroutine: Optional[str] = None, subroutine_used_names: Optional[Set[str]] = None, - normalize_offsets=False, do_not_make_internal_variables_argument: bool = False): + def __init__(self, + source: str, + multiple_sdfgs: bool = False, + startpoint=None, + sdfg_path=None, + toplevel_subroutine: Optional[str] = None, + subroutine_used_names: Optional[Set[str]] = None, + normalize_offsets=False, + do_not_make_internal_variables_argument: bool = False): """ :ast: The internal fortran AST to be used for translation :source: The source file name from which the AST was generated @@ -480,11 +493,9 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): if decl.name in sdfg.symbols: continue add_deferred_shape_assigns_for_structs(self.structures, decl, sdfg, assign_state, decl.name, - decl.name, self.placeholders, - self.placeholders_offsets, + decl.name, self.placeholders, self.placeholders_offsets, sdfg.arrays[self.name_mapping[sdfg][decl.name]], - self.replace_names, - self.actual_offsets_per_sdfg[sdfg]) + self.replace_names, self.actual_offsets_per_sdfg[sdfg]) if not isinstance(self.startpoint, Main_Program_Node): # this works with CloudSC @@ -506,11 +517,15 @@ def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_S :param sdfg: The SDFG to which the node should be translated """ if self.name_mapping[sdfg][node.name_pointer.name] in sdfg.arrays: - shapenames = [sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].shape[i] for i in - range(len(sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].shape))] + shapenames = [ + sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].shape[i] + for i in range(len(sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].shape)) + ] offsetnames = self.actual_offsets_per_sdfg[sdfg][node.name_pointer.name] - 
[sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].offset[i] for i in - range(len(sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].offset))] + [ + sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].offset[i] + for i in range(len(sdfg.arrays[self.name_mapping[sdfg][node.name_pointer.name]].offset)) + ] # for i in shapenames: # if str(i) in sdfg.symbols: # sdfg.symbols.pop(str(i)) @@ -530,14 +545,17 @@ def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_S raise ValueError("Unknown variable " + node.name_target.name) if isinstance(node.name_target.part_ref, ast_internal_classes.Data_Ref_Node): self.name_mapping[sdfg][node.name_pointer.name] = self.name_mapping[sdfg][ - node.name_target.parent_ref.name + "_" + node.name_target.part_ref.parent_ref.name + "_" + node.name_target.part_ref.part_ref.name] + node.name_target.parent_ref.name + "_" + node.name_target.part_ref.parent_ref.name + "_" + + node.name_target.part_ref.part_ref.name] # self.replace_names[node.name_pointer.name]=self.name_mapping[sdfg][node.name_target.parent_ref.name+"_"+node.name_target.part_ref.parent_ref.name+"_"+node.name_target.part_ref.part_ref.name] - target = sdfg.arrays[self.name_mapping[sdfg][ - node.name_target.parent_ref.name + "_" + node.name_target.part_ref.parent_ref.name + "_" + node.name_target.part_ref.part_ref.name]] + target = sdfg.arrays[self.name_mapping[sdfg][node.name_target.parent_ref.name + "_" + + node.name_target.part_ref.parent_ref.name + "_" + + node.name_target.part_ref.part_ref.name]] # for i in self.actual_offsets_per_sdfg[sdfg]: # print(i) - actual_offsets = self.actual_offsets_per_sdfg[sdfg][ - node.name_target.parent_ref.name + "_" + node.name_target.part_ref.parent_ref.name + "_" + node.name_target.part_ref.part_ref.name] + actual_offsets = self.actual_offsets_per_sdfg[sdfg][node.name_target.parent_ref.name + "_" + + node.name_target.part_ref.parent_ref.name + "_" + + node.name_target.part_ref.part_ref.name] 
for i in shapenames: self.replace_names[str(i)] = str(target.shape[shapenames.index(i)]) @@ -546,12 +564,13 @@ def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_S else: self.name_mapping[sdfg][node.name_pointer.name] = self.name_mapping[sdfg][ node.name_target.parent_ref.name + "_" + node.name_target.part_ref.name] - self.replace_names[node.name_pointer.name] = self.name_mapping[sdfg][ - node.name_target.parent_ref.name + "_" + node.name_target.part_ref.name] - target = sdfg.arrays[ - self.name_mapping[sdfg][node.name_target.parent_ref.name + "_" + node.name_target.part_ref.name]] - actual_offsets = self.actual_offsets_per_sdfg[sdfg][ - node.name_target.parent_ref.name + "_" + node.name_target.part_ref.name] + self.replace_names[node.name_pointer.name] = self.name_mapping[sdfg][node.name_target.parent_ref.name + + "_" + + node.name_target.part_ref.name] + target = sdfg.arrays[self.name_mapping[sdfg][node.name_target.parent_ref.name + "_" + + node.name_target.part_ref.name]] + actual_offsets = self.actual_offsets_per_sdfg[sdfg][node.name_target.parent_ref.name + "_" + + node.name_target.part_ref.name] for i in shapenames: self.replace_names[str(i)] = str(target.shape[shapenames.index(i)]) for i in offsetnames: @@ -596,7 +615,10 @@ def derivedtypedef2sdfg(self, node: ast_internal_classes.Derived_Type_Def_Node, offset = [] offset_value = 0 if self.normalize_offsets else -1 for i in k.sizes: - tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + tw = ast_utils.TaskletWriter([], [], + sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names) text = tw.write_code(i) @@ -649,7 +671,10 @@ def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDF sizes = [] offset = [] for j in i.shape.shape_list: - tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + tw = 
ast_utils.TaskletWriter([], [], + sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names) text = tw.write_code(j) @@ -698,10 +723,10 @@ def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG): sdfg.add_edge(guard_substate, body_ifstart_state, InterstateEdge(condition)) if self.last_sdfg_states[sdfg] not in [ - self.last_loop_breaks.get(sdfg), - self.last_loop_continues.get(sdfg), - self.last_returns.get(sdfg), - self.already_has_edge_back_continue.get(sdfg) + self.last_loop_breaks.get(sdfg), + self.last_loop_continues.get(sdfg), + self.last_returns.get(sdfg), + self.already_has_edge_back_continue.get(sdfg) ]: body_ifend_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"BodyIfEnd{name}") sdfg.add_edge(body_ifend_state, final_substate, InterstateEdge()) @@ -726,15 +751,15 @@ def whilestmt2sdfg(self, node: ast_internal_classes.While_Stmt_Node, sdfg: SDFG) guard_substate = sdfg.add_state("Guard" + name) final_substate = sdfg.add_state("Merge" + name) self.last_sdfg_states[sdfg] = final_substate - + sdfg.add_edge(begin_state, guard_substate, InterstateEdge()) - condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + condition = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(node.cond) - - begin_loop_state = sdfg.add_state("BeginWhile" + name) end_loop_state = sdfg.add_state("EndWhile" + name) self.last_sdfg_states[sdfg] = begin_loop_state @@ -755,7 +780,6 @@ def whilestmt2sdfg(self, node: ast_internal_classes.While_Stmt_Node, sdfg: SDFG) else: self.last_loop_continues[sdfg] = None - def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG): """ This function is responsible for translating Fortran for statements into a SDFG. 
@@ -778,19 +802,25 @@ def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG): iter_name = self.name_mapping[sdfg][decl_node.lval.name] else: raise ValueError("Unknown variable " + decl_node.lval.name) - entry[iter_name] = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + entry[iter_name] = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(decl_node.rval) sdfg.add_edge(begin_state, guard_substate, InterstateEdge(assignments=entry)) - condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + condition = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(node.cond) increment = "i+0+1" if isinstance(node.iter, ast_internal_classes.BinOp_Node): - increment = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + increment = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(node.iter.rval) entry = {iter_name: increment} @@ -845,7 +875,10 @@ def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG): elif isinstance(node.init, ast_internal_classes.Name_Node): self.contexts[sdfg.name].constants[node.name] = self.contexts[sdfg.name].constants[node.init.name] else: - tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + tw = ast_utils.TaskletWriter([], [], + sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names) if node.init is not None: @@ -860,7 +893,10 @@ def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: 
SDFG): self.last_sdfg_states[sdfg] = bstate if node.init is not None: substate = sdfg.add_state(f"Dummystate_{node.name}") - increment = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + increment = ast_utils.TaskletWriter([], [], + sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(node.init) @@ -1060,18 +1096,19 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, raise NotImplementedError("Index in ParDecl should be ALL") else: - text = ast_utils.ProcessedWriter(sdfg, self.name_mapping, - placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, - rename_dict=self.replace_names).write_code( - i) + text = ast_utils.ProcessedWriter( + sdfg, + self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(i) local_index_list.append(sym.pystr_to_symbolic(text)) local_strides.pop(local_indices - changed_indices) local_offsets.pop(local_indices - changed_indices) changed_indices += 1 local_indices = local_indices + 1 - local_all_indices = [None] * ( - len(local_shape) - len(local_index_list)) + local_index_list + local_all_indices = [None + ] * (len(local_shape) - len(local_index_list)) + local_index_list if self.normalize_offsets: subset = subs.Range([(i, i, 1) if i is not None else (0, s - 1, 1) for i, s in zip(local_all_indices, local_shape)]) @@ -1084,20 +1121,20 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, if len(new_shape) == 0: stype = current_member.stype view_to_container = dat.View.view(current_member) - sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)] = view_to_container + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + + str(self.struct_view_count)] = view_to_container while 
isinstance(stype, dat.ContainerArray): stype = stype.stype bonus_step = True # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.shape,current_member.dtype) view_to_member = dat.View.view(stype) - sdfg.arrays[concatenated_name + "_" + current_member_name + "_m_" + str( - self.struct_view_count)] = view_to_member + sdfg.arrays[concatenated_name + "_" + current_member_name + "_m_" + + str(self.struct_view_count)] = view_to_member # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.stype.dtype) else: view_to_member = dat.View.view(current_member) - sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)] = view_to_member + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + + str(self.struct_view_count)] = view_to_member # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.shape,current_member.dtype,strides=current_member.strides,offset=current_member.offset) already_there_1 = False @@ -1120,9 +1157,9 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, re, already_there_1 = find_access_in_sources(substate, substate_sources, current_parent_structure_name) - wv, already_there_2 = find_access_in_destinations(substate, substate_destinations, - concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)) + wv, already_there_2 = find_access_in_destinations( + substate, substate_destinations, concatenated_name + "_" + current_member_name + + "_" + str(self.struct_view_count)) if not bonus_step: mem = Memlet.simple(current_parent_structure_name + "." + current_member_name, @@ -1131,25 +1168,25 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, else: firstmem = Memlet.simple( current_parent_structure_name + "." 
+ current_member_name, - subs.Range.from_array(sdfg.arrays[ - concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)])) - wv2, already_there_22 = find_access_in_destinations(substate, - substate_destinations, - concatenated_name + "_" + current_member_name + "_m_" + str( - self.struct_view_count)) - mem = Memlet.simple(concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count), subset) + subs.Range.from_array( + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + + str(self.struct_view_count)])) + wv2, already_there_22 = find_access_in_destinations( + substate, substate_destinations, concatenated_name + "_" + + current_member_name + "_m_" + str(self.struct_view_count)) + mem = Memlet.simple( + concatenated_name + "_" + current_member_name + "_" + + str(self.struct_view_count), subset) substate.add_edge(re, None, wv, "views", dpcp(firstmem)) substate.add_edge(wv, None, wv2, "views", dpcp(mem)) if local_name.name in write_names: - wr, already_there_3 = find_access_in_destinations(substate, substate_destinations, - current_parent_structure_name) - rv, already_there_4 = find_access_in_sources(substate, substate_sources, - concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)) + wr, already_there_3 = find_access_in_destinations( + substate, substate_destinations, current_parent_structure_name) + rv, already_there_4 = find_access_in_sources( + substate, substate_sources, concatenated_name + "_" + current_member_name + + "_" + str(self.struct_view_count)) if not bonus_step: mem2 = Memlet.simple(current_parent_structure_name + "." + current_member_name, @@ -1158,14 +1195,15 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, else: firstmem = Memlet.simple( current_parent_structure_name + "." 
+ current_member_name, - subs.Range.from_array(sdfg.arrays[ - concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)])) - wr2, already_there_33 = find_access_in_sources(substate, substate_sources, - concatenated_name + "_" + current_member_name + "_m_" + str( - self.struct_view_count)) - mem2 = Memlet.simple(concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count), subset) + subs.Range.from_array( + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + + str(self.struct_view_count)])) + wr2, already_there_33 = find_access_in_sources( + substate, substate_sources, concatenated_name + "_" + current_member_name + + "_m_" + str(self.struct_view_count)) + mem2 = Memlet.simple( + concatenated_name + "_" + current_member_name + "_" + + str(self.struct_view_count), subset) substate.add_edge(wr2, "views", rv, None, dpcp(mem2)) substate.add_edge(rv, "views", wr, None, dpcp(firstmem)) @@ -1233,13 +1271,13 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, if isinstance(array, dat.ContainerArray): view_to_member = dat.View.view(array) - sdfg.arrays[concatenated_name + "_" + array_name + "_" + str( - self.struct_view_count)] = view_to_member + sdfg.arrays[concatenated_name + "_" + array_name + "_" + + str(self.struct_view_count)] = view_to_member else: view_to_member = dat.View.view(array) - sdfg.arrays[concatenated_name + "_" + array_name + "_" + str( - self.struct_view_count)] = view_to_member + sdfg.arrays[concatenated_name + "_" + array_name + "_" + + str(self.struct_view_count)] = view_to_member # sdfg.add_view(concatenated_name+"_"+array_name+"_"+str(self.struct_view_count),array.shape,array.dtype,strides=array.strides,offset=array.offset) last_view_name_read = None @@ -1271,8 +1309,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, already_there_2 = True break if not already_there_2: - wv = substate.add_write( - concatenated_name + "_" 
+ array_name + "_" + str(self.struct_view_count)) + wv = substate.add_write(concatenated_name + "_" + array_name + "_" + + str(self.struct_view_count)) mem = Memlet.from_array(last_view_name + "." + member_name, array) substate.add_edge(re, None, wv, "views", dpcp(mem)) @@ -1294,8 +1332,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, already_there_4 = True break if not already_there_4: - rv = substate.add_read( - concatenated_name + "_" + array_name + "_" + str(self.struct_view_count)) + rv = substate.add_read(concatenated_name + "_" + array_name + "_" + + str(self.struct_view_count)) mem2 = Memlet.from_array(last_view_name + "." + member_name, array) substate.add_edge(rv, "views", wr, None, dpcp(mem2)) @@ -1337,8 +1375,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, offsets = list(array.offset) self.struct_view_count += 1 - if isinstance(array, dat.ContainerArray) and isinstance(tmpvar, - ast_internal_classes.Array_Subscript_Node): + if isinstance(array, dat.ContainerArray) and isinstance( + tmpvar, ast_internal_classes.Array_Subscript_Node): current_member_name = ast_utils.get_name(tmpvar) current_member = current_parent_structure.members[current_member_name] concatenated_name = "_".join(name_chain) @@ -1359,18 +1397,19 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, else: raise NotImplementedError("Index in ParDecl should be ALL") else: - text = ast_utils.ProcessedWriter(sdfg, self.name_mapping, - placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, - rename_dict=self.replace_names).write_code( - i) + text = ast_utils.ProcessedWriter( + sdfg, + self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(i) local_index_list.append(sym.pystr_to_symbolic(text)) local_strides.pop(local_indices - changed_indices) 
local_offsets.pop(local_indices - changed_indices) changed_indices += 1 local_indices = local_indices + 1 - local_all_indices = [None] * ( - len(local_shape) - len(local_index_list)) + local_index_list + local_all_indices = [None] * (len(local_shape) - + len(local_index_list)) + local_index_list if self.normalize_offsets: subset = subs.Range([(i, i, 1) if i is not None else (0, s - 1, 1) for i, s in zip(local_all_indices, local_shape)]) @@ -1386,13 +1425,13 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, bonus_step = True # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.shape,current_member.dtype) view_to_member = dat.View.view(stype) - sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)] = view_to_member + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + + str(self.struct_view_count)] = view_to_member # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.stype.dtype) else: view_to_member = dat.View.view(current_member) - sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)] = view_to_member + sdfg.arrays[concatenated_name + "_" + current_member_name + "_" + + str(self.struct_view_count)] = view_to_member # sdfg.add_view(concatenated_name+"_"+current_member_name+"_"+str(self.struct_view_count),current_member.shape,current_member.dtype,strides=current_member.strides,offset=current_member.offset) already_there_1 = False @@ -1423,9 +1462,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, already_there_2 = True break if not already_there_2: - wv = substate.add_write( - concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)) + wv = substate.add_write(concatenated_name + "_" + current_member_name + + "_" + str(self.struct_view_count)) if 
isinstance(current_member, dat.ContainerArray): mem = Memlet.simple(last_view_name, subset) @@ -1450,9 +1488,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, already_there_4 = True break if not already_there_4: - rv = substate.add_read( - concatenated_name + "_" + current_member_name + "_" + str( - self.struct_view_count)) + rv = substate.add_read(concatenated_name + "_" + current_member_name + "_" + + str(self.struct_view_count)) if isinstance(current_member, dat.ContainerArray): mem2 = Memlet.simple(last_view_name, subset) @@ -1491,7 +1528,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, self.struct_view_count) self.views = self.views + 1 views.append( - [mapped_name_overwrite, wv, rv, variables_in_call.index(variable_in_call)]) + [mapped_name_overwrite, wv, rv, + variables_in_call.index(variable_in_call)]) strides = list(view_to_member.strides) offsets = list(view_to_member.offset) @@ -1509,20 +1547,26 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, else: start = i.range[0] stop = i.range[1] - text_start = ast_utils.ProcessedWriter(sdfg, self.name_mapping, - placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, - rename_dict=self.replace_names).write_code(start) - text_stop = ast_utils.ProcessedWriter(sdfg, self.name_mapping, - placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, - rename_dict=self.replace_names).write_code(stop) - shape.append("( "+ text_stop + ") - ( "+ text_start + ") ") - mysize=mysize*sym.pystr_to_symbolic("( "+ text_stop + ") - ( "+ text_start + ") ") + text_start = ast_utils.ProcessedWriter( + sdfg, + self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(start) + text_stop = ast_utils.ProcessedWriter( + sdfg, + self.name_mapping, + placeholders=self.placeholders, + 
placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(stop) + shape.append("( " + text_stop + ") - ( " + text_start + ") ") + mysize = mysize * sym.pystr_to_symbolic("( " + text_stop + ") - ( " + + text_start + ") ") index_list.append(None) # raise NotImplementedError("Index in ParDecl should be ALL") else: - text = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + text = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(i) @@ -1532,14 +1576,12 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, changed_indices += 1 indices = indices + 1 - - elif isinstance(tmpvar, ast_internal_classes.Name_Node): shape = list(array.shape) else: raise NotImplementedError("Unknown part_ref type") - if shape == () or shape == (1,) or shape == [] or shape == [1]: + if shape == () or shape == (1, ) or shape == [] or shape == [1]: # FIXME 6.03.2024 # print(array,array.__class__.__name__) if isinstance(array, dat.ContainerArray): @@ -1645,24 +1687,28 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, else: start = i.range[0] stop = i.range[1] - text_start = ast_utils.ProcessedWriter(sdfg, self.name_mapping, - placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, - rename_dict=self.replace_names).write_code( - start) - text_stop = ast_utils.ProcessedWriter(sdfg, self.name_mapping, - placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, - rename_dict=self.replace_names).write_code( - stop) + text_start = ast_utils.ProcessedWriter( + sdfg, + self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(start) + text_stop = ast_utils.ProcessedWriter( + sdfg, + self.name_mapping, + placeholders=self.placeholders, + 
placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(stop) symb_size = sym.pystr_to_symbolic(text_stop + " - ( " + text_start + " )") shape.append(symb_size) mysize = mysize * symb_size index_list.append( - [sym.pystr_to_symbolic(text_start), sym.pystr_to_symbolic(text_stop)]) + [sym.pystr_to_symbolic(text_start), + sym.pystr_to_symbolic(text_stop)]) # raise NotImplementedError("Index in ParDecl should be ALL") else: - text = ast_utils.ProcessedWriter(sdfg, self.name_mapping, + text = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(i) @@ -1677,7 +1723,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, # print("Data_Ref_Node") # Functionally, this identifies the case where the array is in fact a scalar - if shape == () or shape == (1,) or shape == [] or shape == [1]: + if shape == () or shape == (1, ) or shape == [] or shape == [1]: if hasattr(array, "name") and array.name in self.registered_types: datatype = self.get_dace_type(array.name) datatype_to_add = copy.deepcopy(array) @@ -1697,7 +1743,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, array.storage) else: # This is the case where the array is not a scalar and we need to create a view - if not (shape == () or shape == (1,) or shape == [] or shape == [1]): + if not (shape == () or shape == (1, ) or shape == [] or shape == [1]): offsets_zero = [] for index in offsets: offsets_zero.append(0) @@ -1761,7 +1807,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, shape = array.shape[indices:] - if shape == () or shape == (1,): + if shape == () or shape == (1, ): new_sdfg.add_scalar(self.name_mapping[new_sdfg][local_name.name], array.dtype, array.storage) else: @@ -1807,7 +1853,7 @@ def subroutine2sdfg(self, node: 
ast_internal_classes.Subroutine_Subprogram_Node, namefinder.visit(i) names_list = namefinder.names # This handles the case where the function is called with read variables found in a module - cached_names=[a[0] for a in self.module_vars] + cached_names = [a[0] for a in self.module_vars] for i in not_found_read_names: if i in names_list: continue @@ -2014,7 +2060,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, self.all_array_names.append(self.name_mapping[parent_sdfg][i]) array_in_global = self.globalsdfg.arrays[self.name_mapping[self.globalsdfg][i]] if isinstance(array_in_global, Scalar): - parent_sdfg.add_scalar(self.name_mapping[parent_sdfg][i], array_in_global.dtype, + parent_sdfg.add_scalar(self.name_mapping[parent_sdfg][i], + array_in_global.dtype, transient=False) elif (hasattr(array_in_global, 'type') and array_in_global.type == "Array") or isinstance( array_in_global, dat.Array): @@ -2038,12 +2085,10 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, self.normalize_offsets) if local_name.name in write_names: ast_utils.add_memlet_write(nested_sdfg.parent, self.name_mapping[parent_sdfg][i], - nested_sdfg.parent_nsdfg_node, - self.name_mapping[nested_sdfg][i], memlet) + nested_sdfg.parent_nsdfg_node, self.name_mapping[nested_sdfg][i], memlet) if local_name.name in read_names: ast_utils.add_memlet_read(nested_sdfg.parent, self.name_mapping[parent_sdfg][i], - nested_sdfg.parent_nsdfg_node, - self.name_mapping[nested_sdfg][i], memlet) + nested_sdfg.parent_nsdfg_node, self.name_mapping[nested_sdfg][i], memlet) if not found: nested_sdfg = parent_sdfg parent_sdfg = parent_sdfg.parent_sdfg @@ -2055,10 +2100,9 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, for k in j.list: if self.contexts.get(new_sdfg.name) is None: self.contexts[new_sdfg.name] = ast_utils.Context(name=new_sdfg.name) - if self.contexts[new_sdfg.name].constants.get( - ast_utils.get_name(k)) is 
None and self.contexts[ - self.globalsdfg.name].constants.get( - ast_utils.get_name(k)) is not None: + if self.contexts[new_sdfg.name].constants.get(ast_utils.get_name( + k)) is None and self.contexts[self.globalsdfg.name].constants.get( + ast_utils.get_name(k)) is not None: self.contexts[new_sdfg.name].constants[ast_utils.get_name(k)] = self.contexts[ self.globalsdfg.name].constants[ast_utils.get_name(k)] @@ -2099,8 +2143,9 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): if len(calls.nodes) == 1: augmented_call = calls.nodes[0] from dace.frontend.fortran.intrinsics import FortranIntrinsics - if augmented_call.name.name not in ["pow", "atan2", "tanh", "__dace_epsilon", - *FortranIntrinsics.retained_function_names()]: + if augmented_call.name.name not in [ + "pow", "atan2", "tanh", "__dace_epsilon", *FortranIntrinsics.retained_function_names() + ]: augmented_call.args.append(node.lval) augmented_call.hasret = True self.call2sdfg(augmented_call, sdfg) @@ -2165,9 +2210,15 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): for i, j, k in zip(output_names, output_names_tasklet, output_names_changed): memlet_range = self.get_memlet_range(sdfg, output_vars, i, j) ast_utils.add_memlet_write(substate, i, tasklet, k, memlet_range) - tw = ast_utils.TaskletWriter(output_names, output_names_changed, sdfg, self.name_mapping, input_names, - input_names_tasklet, placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names) + tw = ast_utils.TaskletWriter(output_names, + output_names_changed, + sdfg, + self.name_mapping, + input_names, + input_names_tasklet, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names) text = tw.write_code(node) # print(sdfg.name,node.line_number,output_names,output_names_changed,input_names,input_names_tasklet) @@ -2248,9 +2299,13 @@ def call2sdfg(self, node: 
ast_internal_classes.Call_Expr_Node, sdfg: SDFG): for o, o_t in zip(output_names, output_names_tasklet): output_names_changed.append(o_t + "_out") - tw = ast_utils.TaskletWriter(output_names_tasklet.copy(), output_names_changed.copy(), sdfg, - self.name_mapping, placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names) + tw = ast_utils.TaskletWriter(output_names_tasklet.copy(), + output_names_changed.copy(), + sdfg, + self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names) if not isinstance(rettype, ast_internal_classes.Void) and hasret: if isinstance(retval, ast_internal_classes.Name_Node): special_list_in[retval.name] = pointer(self.get_dace_type(rettype)) @@ -2376,28 +2431,35 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): symname = "tmp_struct_symbol_" + str(count) if sdfg.parent_sdfg is not None: sdfg.parent_sdfg.add_symbol("tmp_struct_symbol_" + str(count), dtypes.int32) - sdfg.parent_nsdfg_node.symbol_mapping[ - "tmp_struct_symbol_" + str(count)] = "tmp_struct_symbol_" + str(count) + sdfg.parent_nsdfg_node.symbol_mapping["tmp_struct_symbol_" + + str(count)] = "tmp_struct_symbol_" + str(count) for edge in sdfg.parent.parent_graph.in_edges(sdfg.parent): - assign = ast_utils.ProcessedWriter(sdfg.parent_sdfg, self.name_mapping, + assign = ast_utils.ProcessedWriter(sdfg.parent_sdfg, + self.name_mapping, placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(i) edge.data.assignments["tmp_struct_symbol_" + str(count)] = assign # print(edge) else: - assign = ast_utils.ProcessedWriter(sdfg, self.name_mapping, placeholders=self.placeholders, + assign = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(i) 
sdfg.append_global_code(f"{dtypes.int32.ctype} {symname};\n") - sdfg.append_init_code( - "tmp_struct_symbol_" + str(count) + "=" + assign.replace(".", "->") + ";\n") - tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + sdfg.append_init_code("tmp_struct_symbol_" + str(count) + "=" + assign.replace(".", "->") + + ";\n") + tw = ast_utils.TaskletWriter([], [], + sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names) text = tw.write_code( - ast_internal_classes.Name_Node(name="tmp_struct_symbol_" + str(count), type="INTEGER", + ast_internal_classes.Name_Node(name="tmp_struct_symbol_" + str(count), + type="INTEGER", line_number=node.line_number)) sizes.append(sym.pystr_to_symbolic(text)) actual_offset_value = node.offsets[node.sizes.index(i)] @@ -2416,7 +2478,10 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): offset.append(offset_value) self.count_of_struct_symbols_lifted += 1 else: - tw = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, placeholders=self.placeholders, + tw = ast_utils.TaskletWriter([], [], + sdfg, + self.name_mapping, + placeholders=self.placeholders, placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names) text = tw.write_code(i) @@ -2538,7 +2603,9 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): if hasattr(node, "init") and node.init is not None: self.translate( ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name=node.name, type=node.type), - op="=", rval=node.init, line_number=node.line_number), sdfg) + op="=", + rval=node.init, + line_number=node.line_number), sdfg) def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG): @@ -2551,13 +2618,11 @@ def continue2sdfg(self, node: ast_internal_classes.Continue_Node, sdfg: SDFG): self.already_has_edge_back_continue[sdfg] = 
self.last_sdfg_states[sdfg] -def create_ast_from_string( - source_string: str, - sdfg_name: str, - transform: bool = False, - normalize_offsets: bool = False, - multiple_sdfgs: bool = False -): +def create_ast_from_string(source_string: str, + sdfg_name: str, + transform: bool = False, + normalize_offsets: bool = False, + multiple_sdfgs: bool = False): """ Creates an AST from a Fortran file in a string :param source_string: The fortran file as a string @@ -2612,6 +2677,7 @@ def create_ast_from_string( class ParseConfig: + def __init__(self, main: Union[None, Path, str] = None, sources: Union[None, List[Path], Dict[str, str]] = None, @@ -2658,8 +2724,10 @@ def create_internal_ast(cfg: ParseConfig) -> Tuple[ast_components.InternalFortra if not cfg.entry_points: # Keep all the possible entry points. - entry_points = [ident_spec(ast_utils.singular(children_of_type(c, NAMED_STMTS_OF_INTEREST_TYPES))) - for c in ast.children if isinstance(c, ENTRY_POINT_OBJECT_TYPES)] + entry_points = [ + ident_spec(ast_utils.singular(children_of_type(c, NAMED_STMTS_OF_INTEREST))) for c in ast.children + if isinstance(c, ENTRY_POINT_OBJECT) + ] else: eps = cfg.entry_points if isinstance(eps, tuple): @@ -2677,6 +2745,7 @@ def create_internal_ast(cfg: ParseConfig) -> Tuple[ast_components.InternalFortra class SDFGConfig: + def __init__(self, entry_points: Dict[str, Union[str, List[str]]], normalize_offsets: bool = True, @@ -2764,8 +2833,8 @@ def create_sdfg_from_internal_ast(own_ast: ast_components.InternalFortranAst, pr point_name = struct_dep_graph.get_edge_data(i, cycle[(cycle.index(i) + 1) % len(cycle)])["point_name"] # print(i,is_pointer) if is_pointer: - actually_used_pointer_node_finder = ast_transforms.StructPointerChecker(i, cycle[ - (cycle.index(i) + 1) % len(cycle)], point_name) + actually_used_pointer_node_finder = ast_transforms.StructPointerChecker( + i, cycle[(cycle.index(i) + 1) % len(cycle)], point_name) actually_used_pointer_node_finder.visit(program) # 
print(actually_used_pointer_node_finder.nodes) if len(actually_used_pointer_node_finder.nodes) == 0: @@ -2808,8 +2877,12 @@ def create_sdfg_from_internal_ast(own_ast: ast_components.InternalFortranAst, pr fn = fn[0] # Do the actual translation. - ast2sdfg = AST_translator(__file__, multiple_sdfgs=cfg.multiple_sdfgs, startpoint=fn, toplevel_subroutine=None, - normalize_offsets=cfg.normalize_offsets, do_not_make_internal_variables_argument=True) + ast2sdfg = AST_translator(__file__, + multiple_sdfgs=cfg.multiple_sdfgs, + startpoint=fn, + toplevel_subroutine=None, + normalize_offsets=cfg.normalize_offsets, + do_not_make_internal_variables_argument=True) g = SDFG(ep) ast2sdfg.functions_and_subroutines = ast_transforms.FindFunctionAndSubroutines.from_node(program).names ast2sdfg.structures = program.structures @@ -2827,11 +2900,11 @@ def create_sdfg_from_internal_ast(own_ast: ast_components.InternalFortranAst, pr def create_sdfg_from_string( - source_string: str, - sdfg_name: str, - normalize_offsets: bool = True, - multiple_sdfgs: bool = False, - sources: List[str] = None, + source_string: str, + sdfg_name: str, + normalize_offsets: bool = True, + multiple_sdfgs: bool = False, + sources: List[str] = None, ): """ Creates an SDFG from a fortran file in a string @@ -2855,7 +2928,6 @@ def create_sdfg_from_string( program = ast_transforms.CallToArray(functions_and_subroutines_builder).visit(program) program = ast_transforms.IfConditionExtractor().visit(program) program = ast_transforms.CallExtractor().visit(program) - program = ast_transforms.FunctionCallTransformer().visit(program) program = ast_transforms.FunctionToSubroutineDefiner().visit(program) @@ -2916,8 +2988,8 @@ def create_sdfg_from_string( point_name = struct_dep_graph.get_edge_data(i, cycle[(cycle.index(i) + 1) % len(cycle)])["point_name"] # print(i,is_pointer) if is_pointer: - actually_used_pointer_node_finder = ast_transforms.StructPointerChecker(i, cycle[ - (cycle.index(i) + 1) % len(cycle)], point_name) + 
actually_used_pointer_node_finder = ast_transforms.StructPointerChecker( + i, cycle[(cycle.index(i) + 1) % len(cycle)], point_name) actually_used_pointer_node_finder.visit(program) # print(actually_used_pointer_node_finder.nodes) if len(actually_used_pointer_node_finder.nodes) == 0: @@ -2932,7 +3004,9 @@ def create_sdfg_from_string( # program = # ast_transforms.ArgumentPruner(functions_and_subroutines_builder.nodes).visit(program) - ast2sdfg = AST_translator(__file__, multiple_sdfgs=multiple_sdfgs, toplevel_subroutine=sdfg_name, + ast2sdfg = AST_translator(__file__, + multiple_sdfgs=multiple_sdfgs, + toplevel_subroutine=sdfg_name, normalize_offsets=normalize_offsets) sdfg = SDFG(sdfg_name) ast2sdfg.functions_and_subroutines = functions_and_subroutines_builder.names @@ -3139,7 +3213,8 @@ def collect_floating_subprograms(ast: Program, source_list: Dict[str, str], incl return ast -def name_and_rename_dict_creator(parse_order: list,dep_graph:nx.DiGraph)->Tuple[Dict[str, List[str]], Dict[str, Dict[str, str]]]: +def name_and_rename_dict_creator(parse_order: list, + dep_graph: nx.DiGraph) -> Tuple[Dict[str, List[str]], Dict[str, Dict[str, str]]]: name_dict = {} rename_dict = {} for i in parse_order: @@ -3171,18 +3246,16 @@ class FindUsedFunctionsConfig: skip_functions: List[str] -def create_sdfg_from_fortran_file_with_options( - cfg: ParseConfig, - ast: Program, - sdfgs_dir, - subroutine_name: Optional[str] = None, - normalize_offsets: bool = True, - propagation_info=None, - enum_propagator_files: Optional[List[str]] = None, - enum_propagator_ast=None, - used_functions_config: Optional[FindUsedFunctionsConfig] = None, - already_parsed_ast=False -): +def create_sdfg_from_fortran_file_with_options(cfg: ParseConfig, + ast: Program, + sdfgs_dir, + subroutine_name: Optional[str] = None, + normalize_offsets: bool = True, + propagation_info=None, + enum_propagator_files: Optional[List[str]] = None, + enum_propagator_ast=None, + used_functions_config: 
Optional[FindUsedFunctionsConfig] = None, + already_parsed_ast=False): """ Creates an SDFG from a fortran file :param source_string: The fortran file name @@ -3203,7 +3276,7 @@ def create_sdfg_from_fortran_file_with_options( #ast = assign_globally_unique_variable_names(ast, {'config'}) ast = consolidate_uses(ast) else: - ast = correct_for_function_calls(ast) + ast = correct_for_function_calls(ast) dep_graph = compute_dep_graph(ast, 'radiation_interface') parse_order = list(reversed(list(nx.topological_sort(dep_graph)))) @@ -3301,7 +3374,7 @@ def create_sdfg_from_fortran_file_with_options( # program = ast_transforms.CallToArray(functions_and_subroutines_builder, rename_dict).visit(program) # program = ast_transforms.TypeInterference(program).visit(program) # program = ast_transforms.ReplaceInterfaceBlocks(program, functions_and_subroutines_builder).visit(program) - + program = ast_transforms.IfConditionExtractor().visit(program) program = ast_transforms.TypeInference(program, assert_voids=False).visit(program) @@ -3357,8 +3430,8 @@ def create_sdfg_from_fortran_file_with_options( if_eval = ast_transforms.IfEvaluator() program = if_eval.visit(program) replacements += if_eval.replacements - print("Made " + str(replacements) + " replacements in step " + str(step) + " Prop: " + str( - prop.replacements) + " If: " + str(if_eval.replacements)) + print("Made " + str(replacements) + " replacements in step " + str(step) + " Prop: " + + str(prop.replacements) + " If: " + str(if_eval.replacements)) step += 1 if used_functions_config is not None: @@ -3460,8 +3533,8 @@ def create_sdfg_from_fortran_file_with_options( point_name = struct_dep_graph.get_edge_data(i, cycle[(cycle.index(i) + 1) % len(cycle)])["point_name"] # print(i,is_pointer) if is_pointer: - actually_used_pointer_node_finder = ast_transforms.StructPointerChecker(i, cycle[ - (cycle.index(i) + 1) % len(cycle)], point_name, structs_lister, struct_dep_graph, "simple") + actually_used_pointer_node_finder = 
ast_transforms.StructPointerChecker( + i, cycle[(cycle.index(i) + 1) % len(cycle)], point_name, structs_lister, struct_dep_graph, "simple") actually_used_pointer_node_finder.visit(program) # print(actually_used_pointer_node_finder.nodes) if len(actually_used_pointer_node_finder.nodes) == 0: @@ -3525,16 +3598,19 @@ def create_sdfg_from_fortran_file_with_options( for j in program.subroutine_definitions: if subroutine_name is not None: - if not subroutine_name+"_decon" in j.name.name : - print("Skipping 1 ", j.name.name) - continue + if not subroutine_name + "_decon" in j.name.name: + print("Skipping 1 ", j.name.name) + continue if j.execution_part is None: continue print(f"Building SDFG {j.name.name}") startpoint = j - ast2sdfg = AST_translator(__file__, multiple_sdfgs=False, startpoint=startpoint, sdfg_path=sdfgs_dir, + ast2sdfg = AST_translator(__file__, + multiple_sdfgs=False, + startpoint=startpoint, + sdfg_path=sdfgs_dir, normalize_offsets=normalize_offsets) sdfg = SDFG(j.name.name) ast2sdfg.functions_and_subroutines = functions_and_subroutines_builder.names @@ -3581,12 +3657,12 @@ def create_sdfg_from_fortran_file_with_options( if subroutine_name is not None: #special for radiation - if subroutine_name=='radiation': - if not 'radiation' == j.name.name : + if subroutine_name == 'radiation': + if not 'radiation' == j.name.name: print("Skipping ", j.name.name) continue - elif not subroutine_name in j.name.name : + elif not subroutine_name in j.name.name: print("Skipping ", j.name.name) continue @@ -3601,8 +3677,7 @@ def create_sdfg_from_fortran_file_with_options( sdfg_path=sdfgs_dir, # toplevel_subroutine_arg_names=arg_pruner.visited_funcs[toplevel_subroutine], # subroutine_used_names=arg_pruner.used_in_all_functions, - normalize_offsets=normalize_offsets - ) + normalize_offsets=normalize_offsets) sdfg = SDFG(j.name.name) ast2sdfg.functions_and_subroutines = functions_and_subroutines_builder.names ast2sdfg.structures = program.structures diff --git 
a/dace/frontend/fortran/icon_config_propagation.py b/dace/frontend/fortran/icon_config_propagation.py index b6cfa48383..5b978b7b6b 100644 --- a/dace/frontend/fortran/icon_config_propagation.py +++ b/dace/frontend/fortran/icon_config_propagation.py @@ -85,7 +85,7 @@ def parse_assignments(assignments: list[str]) -> list[tuple[str, str]]: ecrad_ast = create_fparser_ast(parse_cfg) already_parsed_ast_bool = False else: - mini_parser=pf().create(std="f2008") + mini_parser = pf().create(std="f2008") ecrad_ast = mini_parser(ffr(file_candidate=already_parsed_ast)) already_parsed_ast_bool = True @@ -200,12 +200,12 @@ def parse_assignments(assignments: list[str]) -> list[tuple[str, str]]: # f"{base_icon_path}/externals/ecrad/ifsaux/ecradhook.F90" # ] - cfg = fortran_parser.FindUsedFunctionsConfig( - root='radiation', - needed_functions=[['radiation_interface', 'radiation']], - skip_functions=['radiation_monochromatic', 'radiation_cloudless_sw', - 'radiation_tripleclouds_sw', 'radiation_homogeneous_sw'] - ) + cfg = fortran_parser.FindUsedFunctionsConfig(root='radiation', + needed_functions=[['radiation_interface', 'radiation']], + skip_functions=[ + 'radiation_monochromatic', 'radiation_cloudless_sw', + 'radiation_tripleclouds_sw', 'radiation_homogeneous_sw' + ]) # generate_propagation_info(propagation_info) @@ -226,5 +226,4 @@ def parse_assignments(assignments: list[str]) -> list[tuple[str, str]]: #enum_propagator_ast=radiation_config_ast, #enum_propagator_files=enum_propagator_files, used_functions_config=cfg, - already_parsed_ast=already_parsed_ast_bool - ) + already_parsed_ast=already_parsed_ast_bool) diff --git a/dace/frontend/fortran/intrinsics.py b/dace/frontend/fortran/intrinsics.py index d5023e71fc..8488a94837 100644 --- a/dace/frontend/fortran/intrinsics.py +++ b/dace/frontend/fortran/intrinsics.py @@ -18,6 +18,7 @@ FASTNode = Any + class NeedsTypeInferenceException(BaseException): def __init__(self, func_name, line_number): @@ -25,6 +26,7 @@ def __init__(self, 
func_name, line_number): self.line_number = line_number self.func_name = func_name + class IntrinsicTransformation: @staticmethod @@ -81,11 +83,9 @@ def _parse_struct_ref(self, node: ast_internal_classes.Data_Ref_Node) -> ast_int else: raise NotImplementedError() - def get_var_declaration(self, - parent: ast_internal_classes.FNode, - variable: Union[ - ast_internal_classes.Data_Ref_Node, ast_internal_classes.Name_Node, - ast_internal_classes.Array_Subscript_Node]): + def get_var_declaration(self, parent: ast_internal_classes.FNode, + variable: Union[ast_internal_classes.Data_Ref_Node, ast_internal_classes.Name_Node, + ast_internal_classes.Array_Subscript_Node]): if isinstance(variable, ast_internal_classes.Data_Ref_Node): variable = self._parse_struct_ref(variable) return variable @@ -148,7 +148,8 @@ def visit_BinOp_Node(self, binop_node: ast_internal_classes.BinOp_Node): # we handle extracted call variables this way # but we can also have different shapes, e.g., `maxval(something) > something_else` # hence the check - if isinstance(var, (ast_internal_classes.Name_Node, ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node)): + if isinstance(var, (ast_internal_classes.Name_Node, ast_internal_classes.Array_Subscript_Node, + ast_internal_classes.Data_Ref_Node)): var_decl = self.get_var_declaration(var.parent, var) var_decl.type = input_type @@ -171,18 +172,13 @@ def replace_size(transformer: IntrinsicNodeTransformer, var: ast_internal_classe if len(var_decl.sizes) == 1: return (var_decl.sizes[0], "INTEGER") - ret = ast_internal_classes.BinOp_Node( - lval=var_decl.sizes[0], - rval=ast_internal_classes.Name_Node(name="INTRINSIC_TEMPORARY"), - op="*" - ) + ret = ast_internal_classes.BinOp_Node(lval=var_decl.sizes[0], + rval=ast_internal_classes.Name_Node(name="INTRINSIC_TEMPORARY"), + op="*") cur_node = ret for i in range(1, len(var_decl.sizes) - 1): cur_node.rval = ast_internal_classes.BinOp_Node( - lval=var_decl.sizes[i], - 
rval=ast_internal_classes.Name_Node(name="INTRINSIC_TEMPORARY"), - op="*" - ) + lval=var_decl.sizes[i], rval=ast_internal_classes.Name_Node(name="INTRINSIC_TEMPORARY"), op="*") cur_node = cur_node.rval cur_node.rval = var_decl.sizes[-1] @@ -205,7 +201,7 @@ def _replace_lbound_ubound(func: str, transformer: IntrinsicNodeTransformer, # get variable declaration for the first argument var_decl = transformer.get_var_declaration(var.parent, var.args[0]) - # one arg to LBOUND/UBOUND? not needed currently + # one arg to LBOUND/UBOUND? not needed currently if len(var.args) == 1: raise NotImplementedError() @@ -217,8 +213,9 @@ def _replace_lbound_ubound(func: str, transformer: IntrinsicNodeTransformer, rank_value = int(rank.value) - is_assumed = isinstance(var_decl.offsets[rank_value - 1], ast_internal_classes.Name_Node) and var_decl.offsets[ - rank_value - 1].name.startswith("__f2dace_") + is_assumed = isinstance( + var_decl.offsets[rank_value - 1], + ast_internal_classes.Name_Node) and var_decl.offsets[rank_value - 1].name.startswith("__f2dace_") if func == 'lbound': @@ -245,17 +242,14 @@ def _replace_lbound_ubound(func: str, transformer: IntrinsicNodeTransformer, else: offset = ast_internal_classes.Int_Literal_Node(value=var_decl.offsets[rank_value - 1]) - value = ast_internal_classes.BinOp_Node( - op="+", - lval=size, - rval=ast_internal_classes.BinOp_Node( - op="-", - lval=offset, - rval=ast_internal_classes.Int_Literal_Node(value="1"), - line_number=line - ), - line_number=line - ) + value = ast_internal_classes.BinOp_Node(op="+", + lval=size, + rval=ast_internal_classes.BinOp_Node( + op="-", + lval=offset, + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=line), + line_number=line) return (value, "INTEGER") @@ -288,9 +282,8 @@ def replace_int_kind(args: ast_internal_classes.Arg_List_Node, line, symbols: li raise ValueError("Only symbols can be names in selector") else: raise ValueError("Only literals or symbols can be arguments in selector") - 
return ast_internal_classes.Int_Literal_Node(value=str( - math.ceil((math.log2(math.pow(10, int(arg0))) + 1) / 8)), - line_number=line) + return ast_internal_classes.Int_Literal_Node(value=str(math.ceil((math.log2(math.pow(10, int(arg0))) + 1) / 8)), + line_number=line) def replace_real_kind(args: ast_internal_classes.Arg_List_Node, line, symbols: list): if isinstance(args.args[0], ast_internal_classes.Int_Literal_Node): @@ -498,9 +491,10 @@ def _skip_result_assignment(self): def _update_result_type(self, var: ast_internal_classes.Name_Node): pass - def _parse_array(self, node: ast_internal_classes.Execution_Part_Node, - arg: ast_internal_classes.FNode, dims_count: Optional[int] = -1 - ) -> ast_internal_classes.Array_Subscript_Node: + def _parse_array(self, + node: ast_internal_classes.Execution_Part_Node, + arg: ast_internal_classes.FNode, + dims_count: Optional[int] = -1) -> ast_internal_classes.Array_Subscript_Node: # supports syntax func(arr) if isinstance(arg, (ast_internal_classes.Name_Node, ast_internal_classes.Data_Ref_Node)): @@ -520,7 +514,9 @@ def _parse_array(self, node: ast_internal_classes.Execution_Part_Node, dims = len(array_sizes) array_node = ast_internal_classes.Array_Subscript_Node( - name=arg, parent=arg.parent, type='VOID', + name=arg, + parent=arg.parent, + type='VOID', indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * dims) return array_node @@ -537,7 +533,6 @@ def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_i Optional[ast_internal_classes.Array_Subscript_Node], ast_internal_classes.BinOp_Node ]: - """ Supports passing binary operations as an input to function. 
In both cases, we extract the arrays used, and return a brand @@ -596,7 +591,6 @@ def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_i def _adjust_array_ranges(self, node: ast_internal_classes.FNode, array: ast_internal_classes.Array_Subscript_Node, loop_ranges_main: list, loop_ranges_array: list): - """ When given a binary operator with arrays as an argument to the intrinsic, one array will dictate loop range. @@ -614,20 +608,15 @@ def _adjust_array_ranges(self, node: ast_internal_classes.FNode, array: ast_inte idx_var = array.indices[i] start_loop = loop_ranges_main[i][0] end_loop = loop_ranges_array[i][0] - - - difference = ast_internal_classes.BinOp_Node( - lval=end_loop, - op="-", - rval=start_loop, - line_number=node.line_number - ) - new_index = ast_internal_classes.BinOp_Node( - lval=idx_var, - op="+", - rval=difference, - line_number=node.line_number - ) + + difference = ast_internal_classes.BinOp_Node(lval=end_loop, + op="-", + rval=start_loop, + line_number=node.line_number) + new_index = ast_internal_classes.BinOp_Node(lval=idx_var, + op="+", + rval=difference, + line_number=node.line_number) array.indices[i] = new_index #difference = int(end_loop.value) - int(start_loop.value) #if difference != 0: @@ -721,7 +710,6 @@ def _initialize(self): self.argument_variable = None def _update_result_type(self, var: ast_internal_classes.Name_Node): - """ For both SUM and PRODUCT, the result type depends on the input variable. 
""" @@ -750,8 +738,8 @@ def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, n self.argument_variable = self.rvals[0] - par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], self.count, new_func_body, - self.scope_vars, self.ast.structures, True) + par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], self.count, new_func_body, self.scope_vars, + self.ast.structures, True) def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: @@ -759,22 +747,17 @@ def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_c lval=node.lval, op="=", rval=ast_internal_classes.Int_Literal_Node(value=self._result_init_value()), - line_number=node.line_number - ) + line_number=node.line_number) def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: - return ast_internal_classes.BinOp_Node( - lval=node.lval, - op="=", - rval=ast_internal_classes.BinOp_Node( - lval=node.lval, - op=self._result_update_op(), - rval=self.argument_variable, - line_number=node.line_number - ), - line_number=node.line_number - ) + return ast_internal_classes.BinOp_Node(lval=node.lval, + op="=", + rval=ast_internal_classes.BinOp_Node(lval=node.lval, + op=self._result_update_op(), + rval=self.argument_variable, + line_number=node.line_number), + line_number=node.line_number) class Sum(LoopBasedReplacement): @@ -834,7 +817,6 @@ def _initialize(self): self.cond = None def _update_result_type(self, var: ast_internal_classes.Name_Node): - """ For all functions, the result type is INTEGER. 
Theoretically, we should return LOGICAL for ANY and ALL, @@ -853,12 +835,10 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): array_node = self._parse_array(node, arg) if array_node is not None: self.first_array = array_node - self.cond = ast_internal_classes.BinOp_Node( - op="==", - rval=ast_internal_classes.Int_Literal_Node(value="1"), - lval=self.first_array, - line_number=node.line_number - ) + self.cond = ast_internal_classes.BinOp_Node(op="==", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + lval=self.first_array, + line_number=node.line_number) else: self.first_array, self.second_array, self.cond = self._parse_binary_op(node, arg) @@ -890,15 +870,12 @@ def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_c init_value = self._result_init_value() - return ast_internal_classes.BinOp_Node( - lval=node.lval, - op="=", - rval=ast_internal_classes.Int_Literal_Node(value=init_value), - line_number=node.line_number - ) + return ast_internal_classes.BinOp_Node(lval=node.lval, + op="=", + rval=ast_internal_classes.Int_Literal_Node(value=init_value), + line_number=node.line_number) def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: - """ For any, we check if the condition is true and then set the value to true For all, we check if the condition is NOT true and then set the value to false @@ -914,12 +891,10 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_ # ) ]) - return ast_internal_classes.If_Stmt_Node( - cond=self._loop_condition(), - body=body_if, - body_else=ast_internal_classes.Execution_Part_Node(execution=[]), - line_number=node.line_number - ) + return ast_internal_classes.If_Stmt_Node(cond=self._loop_condition(), + body=body_if, + body_else=ast_internal_classes.Execution_Part_Node(execution=[]), + line_number=node.line_number) class Any(LoopBasedReplacement): @@ -951,12 +926,10 @@ def _result_init_value(self): 
return "0" def _result_loop_update(self, node: ast_internal_classes.FNode): - return ast_internal_classes.BinOp_Node( - lval=copy.deepcopy(node.lval), - op="=", - rval=ast_internal_classes.Int_Literal_Node(value="1"), - line_number=node.line_number - ) + return ast_internal_classes.BinOp_Node(lval=copy.deepcopy(node.lval), + op="=", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=node.line_number) def _loop_condition(self): return self.cond @@ -980,18 +953,13 @@ def _result_init_value(self): return "1" def _result_loop_update(self, node: ast_internal_classes.FNode): - return ast_internal_classes.BinOp_Node( - lval=copy.deepcopy(node.lval), - op="=", - rval=ast_internal_classes.Int_Literal_Node(value="0"), - line_number=node.line_number - ) + return ast_internal_classes.BinOp_Node(lval=copy.deepcopy(node.lval), + op="=", + rval=ast_internal_classes.Int_Literal_Node(value="0"), + line_number=node.line_number) def _loop_condition(self): - return ast_internal_classes.UnOp_Node( - op="not", - lval=self.cond - ) + return ast_internal_classes.UnOp_Node(op="not", lval=self.cond) @staticmethod def func_name() -> str: @@ -1014,18 +982,14 @@ def _result_init_value(self): return "0" def _result_loop_update(self, node: ast_internal_classes.FNode): - update = ast_internal_classes.BinOp_Node( - lval=copy.deepcopy(node.lval), - op="+", - rval=ast_internal_classes.Int_Literal_Node(value="1"), - line_number=node.line_number - ) - return ast_internal_classes.BinOp_Node( - lval=copy.deepcopy(node.lval), - op="=", - rval=update, - line_number=node.line_number - ) + update = ast_internal_classes.BinOp_Node(lval=copy.deepcopy(node.lval), + op="+", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=node.line_number) + return ast_internal_classes.BinOp_Node(lval=copy.deepcopy(node.lval), + op="=", + rval=update, + line_number=node.line_number) def _loop_condition(self): return self.cond @@ -1042,7 +1006,6 @@ def _initialize(self): 
self.argument_variable = None def _update_result_type(self, var: ast_internal_classes.Name_Node): - """ For both MINVAL and MAXVAL, the result type depends on the input variable. """ @@ -1076,38 +1039,35 @@ def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, n self.argument_variable = self.rvals[0] - par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], self.count, new_func_body, - self.scope_vars, self.ast.structures, declaration=True) + par_Decl_Range_Finder(self.argument_variable, + self.loop_ranges, [], + self.count, + new_func_body, + self.scope_vars, + self.ast.structures, + declaration=True) def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: - return ast_internal_classes.BinOp_Node( - lval=node.lval, - op="=", - rval=self._result_init_value(self.argument_variable), - line_number=node.line_number - ) + return ast_internal_classes.BinOp_Node(lval=node.lval, + op="=", + rval=self._result_init_value(self.argument_variable), + line_number=node.line_number) def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: - cond = ast_internal_classes.BinOp_Node( - lval=self.argument_variable, - op=self._condition_op(), - rval=node.lval, - line_number=node.line_number - ) - body_if = ast_internal_classes.BinOp_Node( - lval=node.lval, - op="=", - rval=self.argument_variable, - line_number=node.line_number - ) - return ast_internal_classes.If_Stmt_Node( - cond=cond, - body=body_if, - body_else=ast_internal_classes.Execution_Part_Node(execution=[]), - line_number=node.line_number - ) + cond = ast_internal_classes.BinOp_Node(lval=self.argument_variable, + op=self._condition_op(), + rval=node.lval, + line_number=node.line_number) + body_if = ast_internal_classes.BinOp_Node(lval=node.lval, + op="=", + rval=self.argument_variable, + line_number=node.line_number) + return ast_internal_classes.If_Stmt_Node(cond=cond, + body=body_if, + 
body_else=ast_internal_classes.Execution_Part_Node(execution=[]), + line_number=node.line_number) class MinVal(LoopBasedReplacement): @@ -1175,6 +1135,7 @@ def func_name() -> str: class Merge(LoopBasedReplacement): + class Transformation(LoopBasedReplacementTransformation): def _initialize(self): @@ -1228,11 +1189,11 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): len_pardecls_first_array += len(pardecls) for ind in self.second_array.indices: pardecls = [i for i in mywalk(ind) if isinstance(i, ast_internal_classes.ParDecl_Node)] - len_pardecls_second_array += len(pardecls) + len_pardecls_second_array += len(pardecls) assert len_pardecls_first_array == len_pardecls_second_array if len_pardecls_first_array == 0: self.uses_scalars = True - else: + else: self.uses_scalars = False # Last argument is either an array or a binary op @@ -1241,12 +1202,10 @@ def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): if array_node is not None: self.mask_first_array = array_node - self.mask_cond = ast_internal_classes.BinOp_Node( - op="==", - rval=ast_internal_classes.Int_Literal_Node(value="1"), - lval=self.mask_first_array, - line_number=node.line_number - ) + self.mask_cond = ast_internal_classes.BinOp_Node(op="==", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + lval=self.mask_first_array, + line_number=node.line_number) else: @@ -1259,15 +1218,26 @@ def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, n self.destination_array = node.lval return - # The first main argument is an array -> this dictates loop boundaries # Other arrays, regardless if they appear as the second array or mask, need to have the same loop boundary. 
- par_Decl_Range_Finder(self.first_array, self.loop_ranges, [], self.count, new_func_body, - self.scope_vars, self.ast.structures, True, allow_scalars=True) + par_Decl_Range_Finder(self.first_array, + self.loop_ranges, [], + self.count, + new_func_body, + self.scope_vars, + self.ast.structures, + True, + allow_scalars=True) loop_ranges = [] - par_Decl_Range_Finder(self.second_array, loop_ranges, [], self.count, new_func_body, - self.scope_vars, self.ast.structures, True, allow_scalars=True) + par_Decl_Range_Finder(self.second_array, + loop_ranges, [], + self.count, + new_func_body, + self.scope_vars, + self.ast.structures, + True, + allow_scalars=True) self._adjust_array_ranges(node, self.second_array, self.loop_ranges, loop_ranges) # parse destination @@ -1283,9 +1253,10 @@ def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, n else: dims = len(array_decl.sizes) self.destination_array = ast_internal_classes.Array_Subscript_Node( - name=node.lval, parent=node.lval.parent, type='VOID', - indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * dims - ) + name=node.lval, + parent=node.lval.parent, + type='VOID', + indices=[ast_internal_classes.ParDecl_Node(type='ALL')] * dims) # type inference! this is necessary when the destination array is # not known exactly, e.g., in recursive calls. 
@@ -1296,19 +1267,31 @@ def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, n array_decl.offsets = [1] * len(array_decl.sizes) array_decl.type = first_input.type - par_Decl_Range_Finder(self.destination_array, [], [], self.count, - new_func_body, self.scope_vars, self.ast.structures, True) + par_Decl_Range_Finder(self.destination_array, [], [], self.count, new_func_body, self.scope_vars, + self.ast.structures, True) if self.mask_first_array is not None: loop_ranges = [] - par_Decl_Range_Finder(self.mask_first_array, loop_ranges, [], self.count, new_func_body, - self.scope_vars, self.ast.structures, True, allow_scalars=True) + par_Decl_Range_Finder(self.mask_first_array, + loop_ranges, [], + self.count, + new_func_body, + self.scope_vars, + self.ast.structures, + True, + allow_scalars=True) self._adjust_array_ranges(node, self.mask_first_array, self.loop_ranges, loop_ranges) if self.mask_second_array is not None: loop_ranges = [] - par_Decl_Range_Finder(self.mask_second_array, loop_ranges, [], self.count, new_func_body, - self.scope_vars, self.ast.structures, True, allow_scalars=True) + par_Decl_Range_Finder(self.mask_second_array, + loop_ranges, [], + self.count, + new_func_body, + self.scope_vars, + self.ast.structures, + True, + allow_scalars=True) self._adjust_array_ranges(node, self.mask_second_array, self.loop_ranges, loop_ranges) def _initialize_result(self, node: ast_internal_classes.FNode) -> Optional[ast_internal_classes.BinOp_Node]: @@ -1318,40 +1301,29 @@ def _initialize_result(self, node: ast_internal_classes.FNode) -> Optional[ast_i return None def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: - """ We check if the condition is true. If yes, then we write from the first array. Otherwise, we copy data from the second array. 
""" - copy_first = ast_internal_classes.BinOp_Node( - lval=copy.deepcopy(self.destination_array), - op="=", - rval=self.first_array, - line_number=node.line_number - ) + copy_first = ast_internal_classes.BinOp_Node(lval=copy.deepcopy(self.destination_array), + op="=", + rval=self.first_array, + line_number=node.line_number) - copy_second = ast_internal_classes.BinOp_Node( - lval=copy.deepcopy(self.destination_array), - op="=", - rval=self.second_array, - line_number=node.line_number - ) + copy_second = ast_internal_classes.BinOp_Node(lval=copy.deepcopy(self.destination_array), + op="=", + rval=self.second_array, + line_number=node.line_number) - body_if = ast_internal_classes.Execution_Part_Node(execution=[ - copy_first - ]) + body_if = ast_internal_classes.Execution_Part_Node(execution=[copy_first]) - body_else = ast_internal_classes.Execution_Part_Node(execution=[ - copy_second - ]) + body_else = ast_internal_classes.Execution_Part_Node(execution=[copy_second]) - return ast_internal_classes.If_Stmt_Node( - cond=self.mask_cond, - body=body_if, - body_else=body_else, - line_number=node.line_number - ) + return ast_internal_classes.If_Stmt_Node(cond=self.mask_cond, + body=body_if, + body_else=body_else, + line_number=node.line_number) class IntrinsicSDFGTransformation(xf.SingleStateTransformation): @@ -1364,18 +1336,7 @@ def blas_dot(self, state: SDFGState, sdfg: SDFG): dot_libnode(None, sdfg, state, self.array1.data, self.array2.data, self.out.data) def blas_matmul(self, state: SDFGState, sdfg: SDFG): - gemm_libnode( - None, - sdfg, - state, - self.array1.data, - self.array2.data, - self.out.data, - 1.0, - 0.0, - False, - False - ) + gemm_libnode(None, sdfg, state, self.array1.data, self.array2.data, self.out.data, 1.0, 0.0, False, False) def transpose(self, state: SDFGState, sdfg: SDFG): @@ -1467,12 +1428,7 @@ def generate_scale(arg: ast_internal_classes.Call_Expr_Node): subroutine=False, ) - mult = ast_internal_classes.BinOp_Node( - op="*", - lval=x, - rval=rval, 
- line_number=line - ) + mult = ast_internal_classes.BinOp_Node(op="*", lval=x, rval=rval, line_number=line) # pack it into parentheses, just to be sure return ast_internal_classes.Parenthesis_Expr_Node(expr=mult) @@ -1604,7 +1560,7 @@ def replace_call(self, old_call: ast_internal_classes.Call_Expr_Node, new_call: raise NotImplementedError() def visit_BinOp_Node(self, binop_node: ast_internal_classes.BinOp_Node): - + if not isinstance(binop_node.rval, ast_internal_classes.Call_Expr_Node): return binop_node @@ -1721,8 +1677,10 @@ def transformations(self) -> List[NodeTransformer]: def function_names() -> List[str]: # list of all functions that are created by initial transformation, before doing full replacement # this prevents other parser components from replacing our function calls with array subscription nodes - return [*list(LoopBasedReplacement.INTRINSIC_TO_DACE.values()), *MathFunctions.temporary_functions(), - *DirectReplacement.temporary_functions()] + return [ + *list(LoopBasedReplacement.INTRINSIC_TO_DACE.values()), *MathFunctions.temporary_functions(), + *DirectReplacement.temporary_functions() + ] @staticmethod def retained_function_names() -> List[str]: @@ -1755,7 +1713,7 @@ def replace_function_name(self, node: FASTNode) -> ast_internal_classes.Name_Nod "DATE_AND_TIME": "__dace_date_and_time", "RESHAPE": "__dace_reshape", } - + if func_name in replacements: return ast_internal_classes.Name_Node(name=replacements[func_name]) elif DirectReplacement.replacable_name(func_name): @@ -1796,13 +1754,18 @@ def replace_function_reference(self, name: ast_internal_classes.Name_Node, args: if name.name in func_types: # FIXME: this will be progressively removed call_type = func_types[name.name] - return ast_internal_classes.Call_Expr_Node(name=name, type=call_type, args=args.args, line_number=line,subroutine=False) + return ast_internal_classes.Call_Expr_Node(name=name, + type=call_type, + args=args.args, + line_number=line, + subroutine=False) elif 
DirectReplacement.replacable(name.name): return DirectReplacement.replace(name.name, args, line, symbols) else: # We will do the actual type replacement later # To that end, we need to know the input types - but these we do not know at the moment. - return ast_internal_classes.Call_Expr_Node( - name=name, type="VOID", subroutine=False, - args=args.args, line_number=line - ) + return ast_internal_classes.Call_Expr_Node(name=name, + type="VOID", + subroutine=False, + args=args.args, + line_number=line) diff --git a/tests/fortran/prune_unused_children_test.py b/tests/fortran/prune_unused_children_test.py index 1e7d921930..50d69628ed 100644 --- a/tests/fortran/prune_unused_children_test.py +++ b/tests/fortran/prune_unused_children_test.py @@ -6,8 +6,11 @@ from fparser.two.parser import ParserFactory from fparser.two.utils import walk +from dace.frontend.fortran.ast_desugaring import (ENTRY_POINT_OBJECT, + ENTRY_POINT_OBJECT_TYPES, + find_name_of_node, + prune_unused_objects) from dace.frontend.fortran.fortran_parser import recursive_ast_improver -from dace.frontend.fortran.ast_desugaring import ENTRY_POINT_OBJECT_TYPES, find_name_of_node, prune_unused_objects from tests.fortran.fortran_test_helper import SourceCodeBuilder @@ -23,8 +26,8 @@ def parse_and_improve(sources: Dict[str, str]): def find_entrypoint_objects_named(ast: Program, name: str) -> List[ENTRY_POINT_OBJECT_TYPES]: objs: List[ENTRY_POINT_OBJECT_TYPES] = [] - for n in walk(ast, ENTRY_POINT_OBJECT_TYPES): - assert isinstance(n, ENTRY_POINT_OBJECT_TYPES) + for n in walk(ast, ENTRY_POINT_OBJECT): + assert isinstance(n, ENTRY_POINT_OBJECT) if not isinstance(n.parent, Program): continue if find_name_of_node(n) == name: From 2194c9f19f0ec89a69a9edf13c564fdab6d9d211 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 20 Dec 2024 10:17:25 +0100 Subject: [PATCH 04/12] Formatting --- tests/fortran/advanced_optional_args_test.py | 11 +- tests/fortran/arg_extract_test.py | 19 +- 
tests/fortran/array_attributes_test.py | 19 +- tests/fortran/array_to_loop_offset_test.py | 66 ++++--- tests/fortran/ast_desugaring_test.py | 2 +- tests/fortran/ast_utils_test.py | 1 + tests/fortran/call_extract_test.py | 3 +- tests/fortran/cond_type_test.py | 6 +- tests/fortran/create_internal_ast_test.py | 181 ++++++++++-------- tests/fortran/dace_support_test.py | 1 - tests/fortran/empty_test.py | 3 +- tests/fortran/fortran_language_test.py | 21 +- tests/fortran/fortran_test_helper.py | 11 +- tests/fortran/future/fortran_class_test.py | 11 +- tests/fortran/global_test.py | 19 +- tests/fortran/ifcycle_test.py | 29 +-- tests/fortran/init_test.py | 23 +-- tests/fortran/intrinsic_all_test.py | 13 +- tests/fortran/intrinsic_any_test.py | 13 +- tests/fortran/intrinsic_basic_test.py | 40 ++-- tests/fortran/intrinsic_blas_test.py | 7 +- tests/fortran/intrinsic_bound_test.py | 60 ++++-- tests/fortran/intrinsic_count_test.py | 11 +- tests/fortran/intrinsic_math_test.py | 43 ++++- tests/fortran/intrinsic_merge_test.py | 38 ++-- tests/fortran/intrinsic_minmaxval_test.py | 9 +- tests/fortran/intrinsic_product_test.py | 4 + tests/fortran/intrinsic_sum_test.py | 7 +- tests/fortran/long_tasklet_test.py | 8 +- tests/fortran/missing_func_test.py | 13 +- tests/fortran/nested_array_test.py | 20 +- .../non-interactive/fortran_int_init_test.py | 9 +- .../fortran/non-interactive/function_test.py | 24 ++- .../fortran/non-interactive/pointers_test.py | 8 +- tests/fortran/non-interactive/view_test.py | 42 ++-- tests/fortran/offset_normalizer_test.py | 79 ++++---- tests/fortran/optional_args_test.py | 8 +- tests/fortran/parent_test.py | 6 +- tests/fortran/pointer_removal_test.py | 31 +-- tests/fortran/prune_test.py | 9 +- tests/fortran/prune_unused_children_test.py | 4 +- tests/fortran/ranges_test.py | 24 ++- tests/fortran/recursive_ast_improver_test.py | 6 +- tests/fortran/rename_test.py | 11 +- tests/fortran/scope_arrays_test.py | 1 + tests/fortran/struct_test.py | 7 +- 
tests/fortran/sum_to_loop_offset_test.py | 7 +- tests/fortran/tasklet_test.py | 5 +- tests/fortran/type_array_test.py | 38 ++-- tests/fortran/type_test.py | 77 +++++--- tests/fortran/while_test.py | 3 +- 51 files changed, 666 insertions(+), 445 deletions(-) diff --git a/tests/fortran/advanced_optional_args_test.py b/tests/fortran/advanced_optional_args_test.py index 860eeb74dd..6fee8b09d7 100644 --- a/tests/fortran/advanced_optional_args_test.py +++ b/tests/fortran/advanced_optional_args_test.py @@ -5,6 +5,7 @@ from dace.frontend.fortran import fortran_parser + def test_fortran_frontend_optional_adv(): test_string = """ PROGRAM adv_intrinsic_optional_test_function @@ -73,9 +74,12 @@ def test_fortran_frontend_optional_adv(): END SUBROUTINE get_indices_c """ - sources={} - sources["adv_intrinsic_optional_test_function"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_optional_test_function", True,sources=sources) + sources = {} + sources["adv_intrinsic_optional_test_function"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "intrinsic_optional_test_function", + True, + sources=sources) sdfg.simplify(verbose=True) sdfg.compile() @@ -87,6 +91,7 @@ def test_fortran_frontend_optional_adv(): assert res[0] == 5 assert res2[0] == 0 + if __name__ == "__main__": test_fortran_frontend_optional_adv() diff --git a/tests/fortran/arg_extract_test.py b/tests/fortran/arg_extract_test.py index f0085d1f1a..d4bd313cb7 100644 --- a/tests/fortran/arg_extract_test.py +++ b/tests/fortran/arg_extract_test.py @@ -5,6 +5,7 @@ from dace.frontend.fortran import fortran_parser + def test_fortran_frontend_arg_extract(): test_string = """ PROGRAM arg_extract @@ -32,13 +33,11 @@ def test_fortran_frontend_arg_extract(): sdfg = fortran_parser.create_sdfg_from_string(test_string, "arg_extract", normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() - - input = np.full([2], 42, order="F", dtype=np.float32) res = np.full([2], 
42, order="F", dtype=np.float32) sdfg(d=input, res=res) - assert np.allclose(res, [3,7]) + assert np.allclose(res, [3, 7]) def test_fortran_frontend_arg_extract2(): @@ -68,13 +67,11 @@ def test_fortran_frontend_arg_extract2(): sdfg = fortran_parser.create_sdfg_from_string(test_string, "arg_extract2", normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() - - input = np.full([2], 42, order="F", dtype=np.float32) res = np.full([2], 42, order="F", dtype=np.float32) sdfg(d=input, res=res) - assert np.allclose(res, [3,7]) + assert np.allclose(res, [3, 7]) def test_fortran_frontend_arg_extract3(): @@ -106,13 +103,11 @@ def test_fortran_frontend_arg_extract3(): sdfg = fortran_parser.create_sdfg_from_string(test_string, "arg_extract3", normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() - - input = np.full([2], 42, order="F", dtype=np.float32) res = np.full([2], 42, order="F", dtype=np.float32) sdfg(d=input, res=res) - assert np.allclose(res, [10,52]) + assert np.allclose(res, [10, 52]) def test_fortran_frontend_arg_extract4(): @@ -148,13 +143,12 @@ def test_fortran_frontend_arg_extract4(): sdfg = fortran_parser.create_sdfg_from_string(test_string, "arg_extract4", normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() - - input = np.full([2], 42, order="F", dtype=np.float32) res = np.full([2], 42, order="F", dtype=np.float32) sdfg(d=input, res=res) - assert np.allclose(res, [10,52]) + assert np.allclose(res, [10, 52]) + if __name__ == "__main__": @@ -162,4 +156,3 @@ def test_fortran_frontend_arg_extract4(): #test_fortran_frontend_arg_extract2() test_fortran_frontend_arg_extract3() test_fortran_frontend_arg_extract4() - diff --git a/tests/fortran/array_attributes_test.py b/tests/fortran/array_attributes_test.py index 115946d703..6746b78c95 100644 --- a/tests/fortran/array_attributes_test.py +++ b/tests/fortran/array_attributes_test.py @@ -221,7 +221,8 @@ def test_fortran_frontend_array_offset_symbol(): def 
test_fortran_frontend_array_arbitrary(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ subroutine main(d, arrsize, arrsize2) integer :: arrsize integer :: arrsize2 @@ -269,7 +270,8 @@ def test_fortran_frontend_array_arbitrary_attribute(): def test_fortran_frontend_array_arbitrary_attribute2(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ module lib contains subroutine main(d, d2) @@ -296,9 +298,16 @@ def test_fortran_frontend_array_arbitrary_attribute2(): arrsize4 = 7 a = np.full([arrsize, arrsize2], 42, order="F", dtype=np.float64) b = np.full([arrsize3, arrsize4], 42, order="F", dtype=np.float64) - sdfg(d=a, __f2dace_A_d_d_0_s_0=arrsize, __f2dace_A_d_d_1_s_1=arrsize2, - d2=b, __f2dace_A_d2_d_0_s_2=arrsize3, __f2dace_A_d2_d_1_s_3=arrsize4, - arrsize=arrsize, arrsize2=arrsize2, arrsize3=arrsize3, arrsize4=arrsize4) + sdfg(d=a, + __f2dace_A_d_d_0_s_0=arrsize, + __f2dace_A_d_d_1_s_1=arrsize2, + d2=b, + __f2dace_A_d2_d_0_s_2=arrsize3, + __f2dace_A_d2_d_1_s_3=arrsize4, + arrsize=arrsize, + arrsize2=arrsize2, + arrsize3=arrsize3, + arrsize4=arrsize4) assert a[1, 1] == arrsize assert a[1, 2] == arrsize2 assert a[1, 3] == arrsize3 diff --git a/tests/fortran/array_to_loop_offset_test.py b/tests/fortran/array_to_loop_offset_test.py index fe16ed1418..38b4589c71 100644 --- a/tests/fortran/array_to_loop_offset_test.py +++ b/tests/fortran/array_to_loop_offset_test.py @@ -4,6 +4,7 @@ from dace.frontend.fortran import ast_transforms, fortran_parser + def test_fortran_frontend_arr2loop_without_offset(): """ Tests that the generated array map correctly handles offsets. 
@@ -36,11 +37,12 @@ def test_fortran_frontend_arr2loop_without_offset(): assert sdfg.data('d').shape[0] == 5 assert sdfg.data('d').shape[1] == 3 - a = np.full([5,9], 42, order="F", dtype=np.float64) + a = np.full([5, 9], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(1,6): - for j in range(1,4): - assert a[i-1, j-1] == i * 2 + for i in range(1, 6): + for j in range(1, 4): + assert a[i - 1, j - 1] == i * 2 + def test_fortran_frontend_arr2loop_1d_offset(): """ @@ -73,8 +75,9 @@ def test_fortran_frontend_arr2loop_1d_offset(): a = np.full([6], 42, order="F", dtype=np.float64) sdfg(d=a) assert a[0] == 42 - for i in range(2,7): - assert a[i-1] == 5 + for i in range(2, 7): + assert a[i - 1] == 5 + def test_fortran_frontend_arr2loop_2d_offset(): """ @@ -108,11 +111,12 @@ def test_fortran_frontend_arr2loop_2d_offset(): assert sdfg.data('d').shape[0] == 5 assert sdfg.data('d').shape[1] == 3 - a = np.full([5,9], 42, order="F", dtype=np.float64) + a = np.full([5, 9], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(1,6): - for j in range(7,10): - assert a[i-1, j-1] == i * 2 + for i in range(1, 6): + for j in range(7, 10): + assert a[i - 1, j - 1] == i * 2 + def test_fortran_frontend_arr2loop_2d_offset2(): """ @@ -143,22 +147,23 @@ def test_fortran_frontend_arr2loop_2d_offset2(): assert sdfg.data('d').shape[0] == 5 assert sdfg.data('d').shape[1] == 3 - a = np.full([5,9], 42, order="F", dtype=np.float64) + a = np.full([5, 9], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(1,6): - for j in range(7,10): - assert a[i-1, j-1] == 43 + for i in range(1, 6): + for j in range(7, 10): + assert a[i - 1, j - 1] == 43 sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) sdfg.simplify(verbose=True) sdfg.compile() - a = np.full([5,3], 42, order="F", dtype=np.float64) + a = np.full([5, 3], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(0,5): - for j in range(0,3): + for i in range(0, 5): + for j in range(0, 3): 
assert a[i, j] == 43 + def test_fortran_frontend_arr2loop_2d_offset3(): """ Tests that the generated array map correctly handles offsets. @@ -188,34 +193,35 @@ def test_fortran_frontend_arr2loop_2d_offset3(): assert sdfg.data('d').shape[0] == 5 assert sdfg.data('d').shape[1] == 3 - a = np.full([5,9], 42, order="F", dtype=np.float64) + a = np.full([5, 9], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(2,4): - for j in range(7,9): - assert a[i-1, j-1] == 43 - for j in range(9,10): - assert a[i-1, j-1] == 42 + for i in range(2, 4): + for j in range(7, 9): + assert a[i - 1, j - 1] == 43 + for j in range(9, 10): + assert a[i - 1, j - 1] == 42 for i in [1, 5]: - for j in range(7,10): - assert a[i-1, j-1] == 42 + for j in range(7, 10): + assert a[i - 1, j - 1] == 42 sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) sdfg.simplify(verbose=True) sdfg.compile() - a = np.full([5,3], 42, order="F", dtype=np.float64) + a = np.full([5, 3], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(1,4): - for j in range(0,2): + for i in range(1, 4): + for j in range(0, 2): assert a[i, j] == 43 - for j in range(2,3): + for j in range(2, 3): assert a[i, j] == 42 for i in [0, 4]: - for j in range(0,3): + for j in range(0, 3): assert a[i, j] == 42 + if __name__ == "__main__": test_fortran_frontend_arr2loop_1d_offset() diff --git a/tests/fortran/ast_desugaring_test.py b/tests/fortran/ast_desugaring_test.py index 909170dda5..5392d4ece1 100644 --- a/tests/fortran/ast_desugaring_test.py +++ b/tests/fortran/ast_desugaring_test.py @@ -1196,7 +1196,7 @@ def test_globally_unique_names(): """).check_with_gfortran().get() ast = parse_and_improve(sources) ast = correct_for_function_calls(ast) - ast = assign_globally_unique_subprogram_names(ast, {('main',)}) + ast = assign_globally_unique_subprogram_names(ast, {('main', )}) ast = assign_globally_unique_variable_names(ast, set()) got = ast.tofortran() diff --git a/tests/fortran/ast_utils_test.py 
b/tests/fortran/ast_utils_test.py index 4ab7b87f35..d7a7031f47 100644 --- a/tests/fortran/ast_utils_test.py +++ b/tests/fortran/ast_utils_test.py @@ -6,6 +6,7 @@ def test_floatlit2string(): + def parse(fl: str) -> float: t = TaskletWriter([], []) # The parameters won't matter. return t.floatlit2string(Real_Literal_Node(value=fl)) diff --git a/tests/fortran/call_extract_test.py b/tests/fortran/call_extract_test.py index c004083f31..b8b09cb283 100644 --- a/tests/fortran/call_extract_test.py +++ b/tests/fortran/call_extract_test.py @@ -5,6 +5,7 @@ from dace.frontend.fortran import fortran_parser + def test_fortran_frontend_call_extract(): test_string = """ PROGRAM intrinsic_call_extract @@ -27,7 +28,7 @@ def test_fortran_frontend_call_extract(): sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_call_extract", normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() - + input = np.full([2], 42, order="F", dtype=np.float32) res = np.full([2], 42, order="F", dtype=np.float32) sdfg(d=input, res=res) diff --git a/tests/fortran/cond_type_test.py b/tests/fortran/cond_type_test.py index a395047db1..89de808c36 100644 --- a/tests/fortran/cond_type_test.py +++ b/tests/fortran/cond_type_test.py @@ -52,9 +52,9 @@ def test_fortran_frontend_cond_type(): endif END SUBROUTINE cond_type_test_function """ - sources={} - sources["type_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_test",sources=sources) + sources = {} + sources["type_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) diff --git a/tests/fortran/create_internal_ast_test.py b/tests/fortran/create_internal_ast_test.py index 47193f3445..267eaf05e7 100644 --- a/tests/fortran/create_internal_ast_test.py +++ b/tests/fortran/create_internal_ast_test.py @@ -36,16 +36,18 @@ def test_minimal(): # Verify assert not 
iast.fortran_intrinsics().transformations() - m = M(Program_Node, has_attr={ - 'main_program': M(Main_Program_Node), - 'subroutine_definitions': [ - M(Subroutine_Subprogram_Node, { - 'name': M.NAMED('fun'), - 'args': [M.NAMED('d')], - }), - ], - 'structures': M(Structures, has_empty_attr={'structures'}) - }, has_empty_attr={'function_definitions', 'modules', 'placeholders', 'placeholders_offsets'}) + m = M(Program_Node, + has_attr={ + 'main_program': M(Main_Program_Node), + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + ], + 'structures': M(Structures, has_empty_attr={'structures'}) + }, + has_empty_attr={'function_definitions', 'modules', 'placeholders', 'placeholders_offsets'}) m.check(prog) @@ -72,19 +74,22 @@ def test_standalone_subroutines(): # Verify assert not iast.fortran_intrinsics().transformations() - m = M(Program_Node, has_attr={ - 'subroutine_definitions': [ - M(Subroutine_Subprogram_Node, { - 'name': M.NAMED('fun'), - 'args': [M.NAMED('d')], - }), - M(Subroutine_Subprogram_Node, { - 'name': M.NAMED('not_fun'), - 'args': [M.NAMED('d'), M.NAMED('val')], - }), - ], - 'structures': M(Structures, has_empty_attr={'structures'}) - }, has_empty_attr={'main_program', 'function_definitions', 'modules', 'placeholders', 'placeholders_offsets'}) + m = M(Program_Node, + has_attr={ + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('not_fun'), + 'args': [M.NAMED('d'), M.NAMED('val')], + }), + ], + 'structures': + M(Structures, has_empty_attr={'structures'}) + }, + has_empty_attr={'main_program', 'function_definitions', 'modules', 'placeholders', 'placeholders_offsets'}) m.check(prog) @@ -123,22 +128,30 @@ def test_subroutines_from_module(): # Verify assert not iast.fortran_intrinsics().transformations() - m = M(Program_Node, has_attr={ - 'main_program': 
M(Main_Program_Node), - 'modules': [M(Module_Node, has_attr={ - 'subroutine_definitions': [ - M(Subroutine_Subprogram_Node, { - 'name': M.NAMED('fun'), - 'args': [M.NAMED('d')], - }), - M(Subroutine_Subprogram_Node, { - 'name': M.NAMED('not_fun'), - 'args': [M.NAMED('d'), M.NAMED('val')], - }), - ], - }, has_empty_attr={'function_definitions', 'interface_blocks'})], - 'structures': M(Structures, has_empty_attr={'structures'}) - }, has_empty_attr={'function_definitions', 'subroutine_definitions', 'placeholders', 'placeholders_offsets'}) + m = M(Program_Node, + has_attr={ + 'main_program': + M(Main_Program_Node), + 'modules': [ + M(Module_Node, + has_attr={ + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('not_fun'), + 'args': [M.NAMED('d'), M.NAMED('val')], + }), + ], + }, + has_empty_attr={'function_definitions', 'interface_blocks'}) + ], + 'structures': + M(Structures, has_empty_attr={'structures'}) + }, + has_empty_attr={'function_definitions', 'subroutine_definitions', 'placeholders', 'placeholders_offsets'}) m.check(prog) @@ -161,15 +174,17 @@ def test_subroutine_with_local_variable(): # Verify assert not iast.fortran_intrinsics().transformations() - m = M(Program_Node, has_attr={ - 'subroutine_definitions': [ - M(Subroutine_Subprogram_Node, { - 'name': M.NAMED('fun'), - 'args': [M.NAMED('d')], - }), - ], - 'structures': M(Structures, has_empty_attr={'structures'}) - }, has_empty_attr={'main_program', 'function_definitions', 'modules', 'placeholders', 'placeholders_offsets'}) + m = M(Program_Node, + has_attr={ + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + ], + 'structures': M(Structures, has_empty_attr={'structures'}) + }, + has_empty_attr={'main_program', 'function_definitions', 'modules', 'placeholders', 'placeholders_offsets'}) m.check(prog) @@ -207,18 
+222,26 @@ def test_subroutine_contains_function(): # Verify assert not iast.fortran_intrinsics().transformations() - m = M(Program_Node, has_attr={ - 'main_program': M(Main_Program_Node), - 'modules': [M(Module_Node, has_attr={ - 'subroutine_definitions': [ - M(Subroutine_Subprogram_Node, { - 'name': M.NAMED('fun'), - 'args': [M.NAMED('d')], - }), - ], - }, has_empty_attr={'function_definitions', 'interface_blocks'})], - 'structures': M(Structures, has_empty_attr={'structures'}) - }, has_empty_attr={'function_definitions', 'subroutine_definitions', 'placeholders', 'placeholders_offsets'}) + m = M(Program_Node, + has_attr={ + 'main_program': + M(Main_Program_Node), + 'modules': [ + M(Module_Node, + has_attr={ + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + ], + }, + has_empty_attr={'function_definitions', 'interface_blocks'}) + ], + 'structures': + M(Structures, has_empty_attr={'structures'}) + }, + has_empty_attr={'function_definitions', 'subroutine_definitions', 'placeholders', 'placeholders_offsets'}) m.check(prog) # TODO: We cannot handle during the internal AST construction (it works just fine before during parsing etc.) 
when a @@ -264,19 +287,27 @@ def test_module_contains_types(): # Verify assert not iast.fortran_intrinsics().transformations() - m = M(Program_Node, has_attr={ - 'main_program': M(Main_Program_Node), - 'modules': [M(Module_Node, has_attr={ - 'specification_part': M(Specification_Part_Node, {'typedecls': M.IGNORE(1)}) - }, has_empty_attr={'function_definitions', 'interface_blocks'})], - 'subroutine_definitions': [ - M(Subroutine_Subprogram_Node, { - 'name': M.NAMED('fun'), - 'args': [M.NAMED('d')], - }), - ], - 'structures': M(Structures, { - 'structures': {'used_type': M(Structure)}, - }) - }, has_empty_attr={'function_definitions', 'placeholders', 'placeholders_offsets'}) + m = M(Program_Node, + has_attr={ + 'main_program': + M(Main_Program_Node), + 'modules': [ + M(Module_Node, + has_attr={'specification_part': M(Specification_Part_Node, {'typedecls': M.IGNORE(1)})}, + has_empty_attr={'function_definitions', 'interface_blocks'}) + ], + 'subroutine_definitions': [ + M(Subroutine_Subprogram_Node, { + 'name': M.NAMED('fun'), + 'args': [M.NAMED('d')], + }), + ], + 'structures': + M(Structures, { + 'structures': { + 'used_type': M(Structure) + }, + }) + }, + has_empty_attr={'function_definitions', 'placeholders', 'placeholders_offsets'}) m.check(prog) diff --git a/tests/fortran/dace_support_test.py b/tests/fortran/dace_support_test.py index 096ea25a18..54d9f229f6 100644 --- a/tests/fortran/dace_support_test.py +++ b/tests/fortran/dace_support_test.py @@ -7,7 +7,6 @@ import numpy as np import pytest - from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic from dace.frontend.fortran import fortran_parser from fparser.two.symbol_table import SymbolTable diff --git a/tests/fortran/empty_test.py b/tests/fortran/empty_test.py index 7e07aa09df..9ec83f4d47 100644 --- a/tests/fortran/empty_test.py +++ b/tests/fortran/empty_test.py @@ -17,7 +17,8 @@ def test_fortran_frontend_empty(): fun_with_no_arguments = (process_mpi_all_size <= 1) end function 
fun_with_no_arguments end module module_mpi -""").add_file(""" +""").add_file( + """ subroutine main(d) use module_mpi, only: fun_with_no_arguments double precision d(2, 3) diff --git a/tests/fortran/fortran_language_test.py b/tests/fortran/fortran_language_test.py index e784503e57..cd64301241 100644 --- a/tests/fortran/fortran_language_test.py +++ b/tests/fortran/fortran_language_test.py @@ -10,7 +10,8 @@ def test_fortran_frontend_real_kind_selector(): """ Tests that the size intrinsics are correctly parsed and translated to DaCe. """ - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ subroutine main(d) implicit none integer, parameter :: JPRB = selected_real_kind(13, 300) @@ -35,7 +36,8 @@ def test_fortran_frontend_if1(): """ Tests that the if/else construct is correctly parsed and translated to DaCe. """ - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ subroutine main(d) implicit none double precision d(3, 4, 5), ZFAC(10) @@ -66,7 +68,8 @@ def test_fortran_frontend_loop1(): """ Tests that the loop construct is correctly parsed and translated to DaCe. """ - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ subroutine main(d) logical d(3, 4, 5), ZFAC(10) integer :: a, JK, JL, JM @@ -97,7 +100,8 @@ def test_fortran_frontend_function_statement1(): """ Tests that the function statement are correctly removed recursively. """ - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ subroutine main(d) double precision d(3, 4, 5) double precision :: PTARE, RTT(2), FOEDELTA, FOELDCP @@ -126,7 +130,8 @@ def test_fortran_frontend_pow1(): """ Tests that the power intrinsic is correctly parsed and translated to DaCe. 
(should become a*a) """ - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ subroutine main(d) implicit none double precision d(3, 4, 5) @@ -152,7 +157,8 @@ def test_fortran_frontend_pow2(): """ Tests that the power intrinsic is correctly parsed and translated to DaCe (this time it's p sqrt p). """ - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ subroutine main(d) implicit none double precision d(3, 4, 5) @@ -178,7 +184,8 @@ def test_fortran_frontend_sign1(): """ Tests that the sign intrinsic is correctly parsed and translated to DaCe. """ - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ subroutine main(d) implicit none double precision d(3, 4, 5) diff --git a/tests/fortran/fortran_test_helper.py b/tests/fortran/fortran_test_helper.py index a82f392f35..14ef321f46 100644 --- a/tests/fortran/fortran_test_helper.py +++ b/tests/fortran/fortran_test_helper.py @@ -125,8 +125,7 @@ def __init__(self, has_attr: Optional[Dict[str, Union["FortranASTMatcher", List["FortranASTMatcher"]]]] = None, has_value: Optional[str] = None): # TODO: Include Set[Self] to `has_children` type? - assert not ((set() if has_attr is None else has_attr.keys()) - & {'children'}) + assert not ((set() if has_attr is None else has_attr.keys()) & {'children'}) self.is_type = is_type self.has_children = has_children self.has_attr = has_attr @@ -220,7 +219,8 @@ class InternalASTMatcher: def __init__(self, is_type: Optional[Type] = None, - has_attr: Optional[Dict[str, Union["InternalASTMatcher", List["InternalASTMatcher"], Dict[str, "InternalASTMatcher"]]]] = None, + has_attr: Optional[Dict[str, Union["InternalASTMatcher", List["InternalASTMatcher"], + Dict[str, "InternalASTMatcher"]]]] = None, has_empty_attr: Optional[Collection[str]] = None, has_value: Optional[str] = None): # TODO: Include Set[Self] to `has_children` type? 
@@ -276,10 +276,7 @@ def NAMED(cls, name: str): return cls(Name_Node, {'name': cls(has_value=name)}) -def create_singular_sdfg_from_string( - sources: Dict[str, str], - entry_point: str, - normalize_offsets: bool = True): +def create_singular_sdfg_from_string(sources: Dict[str, str], entry_point: str, normalize_offsets: bool = True): entry_point = entry_point.split('.') cfg = ParseConfig(main=sources['main.f90'], sources=sources, entry_points=tuple(entry_point)) diff --git a/tests/fortran/future/fortran_class_test.py b/tests/fortran/future/fortran_class_test.py index 7e6ab50577..527d307135 100644 --- a/tests/fortran/future/fortran_class_test.py +++ b/tests/fortran/future/fortran_class_test.py @@ -7,7 +7,6 @@ import numpy as np import pytest - from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic from dace.frontend.fortran import fortran_parser from fparser.two.symbol_table import SymbolTable @@ -19,8 +18,6 @@ import dace.frontend.fortran.ast_internal_classes as ast_internal_classes - - def test_fortran_frontend_class(): """ Tests that whether clasess are translated correctly @@ -87,7 +84,7 @@ def test_fortran_frontend_class(): d(1)=p_pat%n_pnts END SUBROUTINE class_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "class_test",False,False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "class_test", False, False) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): if node.sdfg is not None: @@ -97,7 +94,7 @@ def test_fortran_frontend_class(): sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None - sdfg.reset_sdfg_list() + sdfg.reset_sdfg_list() sdfg.simplify(verbose=True) sdfg.view() sdfg.compile() @@ -108,10 +105,6 @@ def test_fortran_frontend_class(): # assert (d[0] == 400) - if __name__ == "__main__": - - test_fortran_frontend_class() - diff --git a/tests/fortran/global_test.py b/tests/fortran/global_test.py index 82dcc46db0..b0304d3dc3 100644 --- 
a/tests/fortran/global_test.py +++ b/tests/fortran/global_test.py @@ -36,9 +36,9 @@ def test_fortran_frontend_global(): """ - sources={} - sources["global_test"]=test_string - sources["global_test_module_subroutine.f90"]=""" + sources = {} + sources["global_test"] = test_string + sources["global_test_module_subroutine.f90"] = """ MODULE global_test_module_subroutine CONTAINS @@ -65,7 +65,7 @@ def test_fortran_frontend_global(): END SUBROUTINE global_test_function END MODULE global_test_module_subroutine """ - sources["global_test_module.f90"]=""" + sources["global_test_module.f90"] = """ MODULE global_test_module IMPLICIT NONE TYPE simple_type @@ -76,8 +76,8 @@ def test_fortran_frontend_global(): integer outside_init=1 END MODULE global_test_module """ - - sources["nested_one.f90"]=""" + + sources["nested_one.f90"] = """ MODULE nested_one IMPLICIT NONE CONTAINS @@ -93,7 +93,7 @@ def test_fortran_frontend_global(): END MODULE nested_one """ - sources["nested_two.f90"]=""" + sources["nested_two.f90"] = """ MODULE nested_two IMPLICIT NONE CONTAINS @@ -108,17 +108,18 @@ def test_fortran_frontend_global(): END MODULE nested_two """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "global_test",sources=sources,normalize_offsets=True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "global_test", sources=sources, normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.save('test.sdfg') a = np.full([4], 42, order="F", dtype=np.float64) - a2 = np.full([4,4,4], 42, order="F", dtype=np.float64) + a2 = np.full([4, 4, 4], 42, order="F", dtype=np.float64) #TODO Add validation - but we need python structs for this. 
#sdfg(d=a,a=a2) #assert (a[0] == 42) #assert (a[1] == 5.5) #assert (a[2] == 42) + if __name__ == "__main__": test_fortran_frontend_global() diff --git a/tests/fortran/ifcycle_test.py b/tests/fortran/ifcycle_test.py index ae7a943721..444fee1810 100644 --- a/tests/fortran/ifcycle_test.py +++ b/tests/fortran/ifcycle_test.py @@ -41,9 +41,13 @@ def test_fortran_frontend_if_cycle(): END SUBROUTINE if_cycle_test_function """ - sources={} - sources["if_cycle"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "if_cycle",normalize_offsets=True,multiple_sdfgs=False,sources=sources) + sources = {} + sources["if_cycle"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "if_cycle", + normalize_offsets=True, + multiple_sdfgs=False, + sources=sources) sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) sdfg(d=a) @@ -52,7 +56,6 @@ def test_fortran_frontend_if_cycle(): assert (a[2] == 5.5) - def test_fortran_frontend_if_nested_cycle(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
@@ -91,15 +94,19 @@ def test_fortran_frontend_if_nested_cycle(): END SUBROUTINE if_nested_cycle_test_function """ - sources={} - sources["if_nested_cycle"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "if_nested_cycle",normalize_offsets=True,multiple_sdfgs=False,sources=sources) + sources = {} + sources["if_nested_cycle"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "if_nested_cycle", + normalize_offsets=True, + multiple_sdfgs=False, + sources=sources) sdfg.simplify(verbose=True) - a = np.full([4,4], 42, order="F", dtype=np.float64) + a = np.full([4, 4], 42, order="F", dtype=np.float64) sdfg(d=a) - assert (a[0,0] == 42) - assert (a[1,0] == 6.5) - assert (a[2,0] == 42) + assert (a[0, 0] == 42) + assert (a[1, 0] == 6.5) + assert (a[2, 0] == 42) if __name__ == "__main__": diff --git a/tests/fortran/init_test.py b/tests/fortran/init_test.py index f23446bf14..e710bf92b1 100644 --- a/tests/fortran/init_test.py +++ b/tests/fortran/init_test.py @@ -33,9 +33,9 @@ def test_fortran_frontend_init(): """ - sources={} - sources["init_test"]=test_string - sources["init_test_module_subroutine.f90"]=""" + sources = {} + sources["init_test"] = test_string + sources["init_test_module_subroutine.f90"] = """ MODULE init_test_module_subroutine CONTAINS SUBROUTINE init_test_function(d) @@ -49,16 +49,16 @@ def test_fortran_frontend_init(): END SUBROUTINE init_test_function END MODULE init_test_module_subroutine """ - sources["init_test_module.f90"]=""" + sources["init_test_module.f90"] = """ MODULE init_test_module IMPLICIT NONE REAL outside_init=EPSILON(1.0) END MODULE init_test_module """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "init_test",sources=sources) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) - sdfg(d=a,outside_init=0) + sdfg(d=a, outside_init=0) assert (a[0] == 42) assert 
(a[1] == 5.5) assert (a[2] == 42) @@ -78,9 +78,9 @@ def test_fortran_frontend_init2(): """ - sources={} - sources["init2_test"]=test_string - sources["init2_test_module_subroutine.f90"]=""" + sources = {} + sources["init2_test"] = test_string + sources["init2_test_module_subroutine.f90"] = """ MODULE init2_test_module_subroutine CONTAINS SUBROUTINE init2_test_function(d) @@ -93,13 +93,13 @@ def test_fortran_frontend_init2(): END SUBROUTINE init2_test_function END MODULE init2_test_module_subroutine """ - sources["init2_test_module.f90"]=""" + sources["init2_test_module.f90"] = """ MODULE init2_test_module IMPLICIT NONE REAL, PARAMETER :: TORUS_MAX_LAT = 4.0 / 18.0 * ATAN(1.0) END MODULE init2_test_module """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "init2_test",sources=sources) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init2_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) sdfg(d=a, torus_max_lat=4.0 / 18.0 * np.arctan(1.0)) @@ -107,6 +107,7 @@ def test_fortran_frontend_init2(): assert (a[1] == 5.674532920122147) assert (a[2] == 42) + if __name__ == "__main__": test_fortran_frontend_init() diff --git a/tests/fortran/intrinsic_all_test.py b/tests/fortran/intrinsic_all_test.py index dc0c76c677..d16f96f125 100644 --- a/tests/fortran/intrinsic_all_test.py +++ b/tests/fortran/intrinsic_all_test.py @@ -110,6 +110,7 @@ def test_fortran_frontend_all_array_comparison(): for val in res: assert val == False + def test_fortran_frontend_all_array_scalar_comparison(): test_string = """ PROGRAM intrinsic_all_test @@ -160,6 +161,7 @@ def test_fortran_frontend_all_array_scalar_comparison(): sdfg(first=first, res=res) assert list(res) == [0, 0, 0, 0, 0, 0, 1] + def test_fortran_frontend_all_array_comparison_wrong_subset(): test_string = """ PROGRAM intrinsic_all_test @@ -183,6 +185,7 @@ def test_fortran_frontend_all_array_comparison_wrong_subset(): with pytest.raises(TypeError): 
fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test") + def test_fortran_frontend_all_array_2d(): test_string = """ PROGRAM intrinsic_all_test @@ -209,14 +212,15 @@ def test_fortran_frontend_all_array_2d(): d = np.full(sizes, True, order="F", dtype=np.int32) res = np.full([2], 42, order="F", dtype=np.int32) - d[2,2] = False + d[2, 2] = False sdfg(d=d, res=res) assert res[0] == False - d[2,2] = True + d[2, 2] = True sdfg(d=d, res=res) assert res[0] == True + def test_fortran_frontend_all_array_comparison_2d(): test_string = """ PROGRAM intrinsic_all_test @@ -251,7 +255,7 @@ def test_fortran_frontend_all_array_comparison_2d(): sizes = [5, 4] first = np.full(sizes, 1, order="F", dtype=np.int32) second = np.full(sizes, 1, order="F", dtype=np.int32) - second[2,2] = 2 + second[2, 2] = 2 res = np.full([7], 0, order="F", dtype=np.int32) sdfg(first=first, second=second, res=res) @@ -264,6 +268,7 @@ def test_fortran_frontend_all_array_comparison_2d(): for val in res: assert val == True + def test_fortran_frontend_all_array_comparison_2d_subset(): test_string = """ PROGRAM intrinsic_all_test @@ -306,6 +311,7 @@ def test_fortran_frontend_all_array_comparison_2d_subset(): sdfg(first=first, second=second, res=res) assert list(res) == [0, 1] + def test_fortran_frontend_all_array_comparison_2d_subset_offset(): test_string = """ PROGRAM intrinsic_all_test @@ -348,6 +354,7 @@ def test_fortran_frontend_all_array_comparison_2d_subset_offset(): sdfg(first=first, second=second, res=res) assert list(res) == [0, 1] + if __name__ == "__main__": test_fortran_frontend_all_array() diff --git a/tests/fortran/intrinsic_any_test.py b/tests/fortran/intrinsic_any_test.py index 49d0b5c12c..359d0ce7ac 100644 --- a/tests/fortran/intrinsic_any_test.py +++ b/tests/fortran/intrinsic_any_test.py @@ -112,6 +112,7 @@ def test_fortran_frontend_any_array_comparison(): for val in res: assert val == False + def test_fortran_frontend_any_array_scalar_comparison(): test_string = """ PROGRAM 
intrinsic_any_test @@ -163,6 +164,7 @@ def test_fortran_frontend_any_array_scalar_comparison(): sdfg(first=first, res=res) assert list(res) == [1, 1, 0, 1, 1, 1, 1] + def test_fortran_frontend_any_array_comparison_wrong_subset(): test_string = """ PROGRAM intrinsic_any_test @@ -186,6 +188,7 @@ def test_fortran_frontend_any_array_comparison_wrong_subset(): with pytest.raises(TypeError): fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + def test_fortran_frontend_any_array_2d(): test_string = """ PROGRAM intrinsic_any_test @@ -212,14 +215,15 @@ def test_fortran_frontend_any_array_2d(): d = np.full(sizes, False, order="F", dtype=np.int32) res = np.full([2], 42, order="F", dtype=np.int32) - d[2,2] = True + d[2, 2] = True sdfg(d=d, res=res) assert res[0] == True - d[2,2] = False + d[2, 2] = False sdfg(d=d, res=res) assert res[0] == False + def test_fortran_frontend_any_array_comparison_2d(): test_string = """ PROGRAM intrinsic_any_test @@ -254,7 +258,7 @@ def test_fortran_frontend_any_array_comparison_2d(): sizes = [5, 4] first = np.full(sizes, 1, order="F", dtype=np.int32) second = np.full(sizes, 2, order="F", dtype=np.int32) - second[2,2] = 1 + second[2, 2] = 1 res = np.full([7], 0, order="F", dtype=np.int32) sdfg(first=first, second=second, res=res) @@ -267,6 +271,7 @@ def test_fortran_frontend_any_array_comparison_2d(): for val in res: assert val == False + def test_fortran_frontend_any_array_comparison_2d_subset(): test_string = """ PROGRAM intrinsic_any_test @@ -309,6 +314,7 @@ def test_fortran_frontend_any_array_comparison_2d_subset(): sdfg(first=first, second=second, res=res) assert list(res) == [0, 1] + def test_fortran_frontend_any_array_comparison_2d_subset_offset(): test_string = """ PROGRAM intrinsic_any_test @@ -351,6 +357,7 @@ def test_fortran_frontend_any_array_comparison_2d_subset_offset(): sdfg(first=first, second=second, res=res) assert list(res) == [0, 1] + if __name__ == "__main__": test_fortran_frontend_any_array() 
diff --git a/tests/fortran/intrinsic_basic_test.py b/tests/fortran/intrinsic_basic_test.py index 4a2e10d8d6..8b305b7a3b 100644 --- a/tests/fortran/intrinsic_basic_test.py +++ b/tests/fortran/intrinsic_basic_test.py @@ -6,6 +6,7 @@ from dace.frontend.fortran import fortran_parser from tests.fortran.fortran_test_helper import create_singular_sdfg_from_string, SourceCodeBuilder + def test_fortran_frontend_bit_size(): test_string = """ PROGRAM intrinsic_math_test_bit_size @@ -39,6 +40,7 @@ def test_fortran_frontend_bit_size(): assert np.allclose(res, [32, 32, 32, 64]) + def test_fortran_frontend_bit_size_symbolic(): test_string = """ PROGRAM intrinsic_math_test_bit_size @@ -81,16 +83,17 @@ def test_fortran_frontend_bit_size_symbolic(): size3 = 7 res = np.full([size], 42, order="F", dtype=np.int32) res2 = np.full([size, size2, size3], 42, order="F", dtype=np.int32) - res3 = np.full([size+size2, size2*5, size3 + size*size2], 42, order="F", dtype=np.int32) + res3 = np.full([size + size2, size2 * 5, size3 + size * size2], 42, order="F", dtype=np.int32) sdfg(res=res, res2=res2, res3=res3, arrsize=size, arrsize2=size2, arrsize3=size3) assert res[0] == size - assert res[1] == size*size2*size3 - assert res[2] == (size + size2) * (size2 * 5) * (size3 + size2*size) - assert res[3] == size * 2 + assert res[1] == size * size2 * size3 + assert res[2] == (size + size2) * (size2 * 5) * (size3 + size2 * size) + assert res[3] == size * 2 assert res[4] == res[0] * res[1] * res[2] assert res[5] == size + size2 + size3 - assert res[6] == size + size2 + size2*5 + size3 + size*size2 + assert res[6] == size + size2 + size2 * 5 + size3 + size * size2 + def test_fortran_frontend_size_arbitrary(): test_string = """ @@ -113,18 +116,23 @@ def test_fortran_frontend_size_arbitrary(): END SUBROUTINE intrinsic_basic_size_arbitrary_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_basic_size_arbitrary_test", True,) + sdfg = 
fortran_parser.create_sdfg_from_string( + test_string, + "intrinsic_basic_size_arbitrary_test", + True, + ) sdfg.simplify(verbose=True) sdfg.compile() size = 7 size2 = 5 res = np.full([size, size2], 42, order="F", dtype=np.int32) - sdfg(res=res,arrsize=size,arrsize2=size2) + sdfg(res=res, arrsize=size, arrsize2=size2) + + assert res[0, 0] == size * size2 + assert res[1, 0] == size + assert res[2, 0] == size2 - assert res[0,0] == size*size2 - assert res[1,0] == size - assert res[2,0] == size2 def test_fortran_frontend_present(): test_string = """ @@ -212,6 +220,7 @@ def test_fortran_frontend_present(): assert res[0] == 1 assert res2[0] == 0 + def test_fortran_frontend_bitwise_ops(): sources, main = SourceCodeBuilder().add_file(""" SUBROUTINE bitwise_ops(input, res) @@ -246,12 +255,13 @@ def test_fortran_frontend_bitwise_ops(): input = np.full([size], 42, order="F", dtype=np.int32) res = np.full([size], 42, order="F", dtype=np.int32) - input = [32, 32, 33, 1073741825, 53, 530, 12, 1, 128, 1073741824, 12 ] + input = [32, 32, 33, 1073741825, 53, 530, 12, 1, 128, 1073741824, 12] sdfg(input=input, res=res) assert np.allclose(res, [33, 1073741856, 32, 1, 10, 1010, 384, 1073741824, 4, 1, 12]) + def test_fortran_frontend_bitwise_ops2(): sources, main = SourceCodeBuilder().add_file(""" SUBROUTINE bitwise_ops(input, res) @@ -278,12 +288,13 @@ def test_fortran_frontend_bitwise_ops2(): input = np.full([size], 42, order="F", dtype=np.int32) res = np.full([size], 42, order="F", dtype=np.int32) - input = [2147483647, 16, 3, 31, 30, 630] + input = [2147483647, 16, 3, 31, 30, 630] sdfg(input=input, res=res) assert np.allclose(res, [0, 16, 1, 0, 30, 78]) + def test_fortran_frontend_allocated(): # FIXME: this pattern is generally not supported. 
# this needs an update once defered allocs are merged @@ -317,12 +328,14 @@ def test_fortran_frontend_allocated(): assert np.allclose(res, [0, 1, 0]) + def test_fortran_frontend_allocated_nested(): # FIXME: this pattern is generally not supported. # this needs an update once defered allocs are merged - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE allocated_test_interface INTERFACE SUBROUTINE allocated_test_nested(data, res) @@ -370,6 +383,7 @@ def test_fortran_frontend_allocated_nested(): assert np.allclose(res, [0, 1, 0]) + if __name__ == "__main__": test_fortran_frontend_bit_size() diff --git a/tests/fortran/intrinsic_blas_test.py b/tests/fortran/intrinsic_blas_test.py index 2a04c7e1f3..c3ec668ef0 100644 --- a/tests/fortran/intrinsic_blas_test.py +++ b/tests/fortran/intrinsic_blas_test.py @@ -59,6 +59,7 @@ def test_fortran_frontend_dot_range(): sdfg(arg1=arg1, arg2=arg2, res1=res1) assert res1[0] == np.dot(arg1, arg2) + def test_fortran_frontend_transpose(): sources, main = SourceCodeBuilder().add_file(""" subroutine main(arg1, arg2, res1) @@ -85,8 +86,10 @@ def test_fortran_frontend_transpose(): assert np.all(np.transpose(res1) == arg1) + def test_fortran_frontend_transpose_struct(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE test_types IMPLICIT NONE @@ -134,6 +137,7 @@ def test_fortran_frontend_transpose_struct(): assert np.all(np.transpose(res1) == arg1) + def test_fortran_frontend_matmul(): sources, main = SourceCodeBuilder().add_file(""" subroutine main(arg1, arg2, res1) @@ -166,6 +170,7 @@ def test_fortran_frontend_matmul(): assert np.all(np.matmul(arg1, arg2) == res1) + if __name__ == "__main__": #test_fortran_frontend_dot() #test_fortran_frontend_dot_range() diff --git a/tests/fortran/intrinsic_bound_test.py b/tests/fortran/intrinsic_bound_test.py index af77aba186..eec0bf47f9 100644 --- 
a/tests/fortran/intrinsic_bound_test.py +++ b/tests/fortran/intrinsic_bound_test.py @@ -5,7 +5,6 @@ from dace.frontend.fortran import fortran_parser from tests.fortran.fortran_test_helper import SourceCodeBuilder, create_singular_sdfg_from_string - """ Test the implementation of LBOUND/UBOUND functions. * Standard-sized arrays. @@ -19,6 +18,7 @@ * Arrays inside structures with multiple layers of indirection + assumed size. """ + def test_fortran_frontend_bound(): test_string = """ PROGRAM intrinsic_bound_test @@ -50,6 +50,7 @@ def test_fortran_frontend_bound(): assert np.allclose(res, [1, 1, 4, 7]) + def test_fortran_frontend_bound_offsets(): test_string = """ PROGRAM intrinsic_bound_test @@ -81,8 +82,10 @@ def test_fortran_frontend_bound_offsets(): assert np.allclose(res, [3, 9, 8, 12]) + def test_fortran_frontend_bound_assumed(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE intrinsic_bound_interfaces INTERFACE SUBROUTINE intrinsic_bound_test_function2(input, res) @@ -123,8 +126,10 @@ def test_fortran_frontend_bound_assumed(): assert np.allclose(res, [1, 1, 4, 7]) + def test_fortran_frontend_bound_assumed_offsets(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE intrinsic_bound_interfaces INTERFACE SUBROUTINE intrinsic_bound_test_function2(input, res) @@ -165,8 +170,10 @@ def test_fortran_frontend_bound_assumed_offsets(): assert np.allclose(res, [1, 1, 4, 7]) + def test_fortran_frontend_bound_allocatable_offsets(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE intrinsic_bound_interfaces INTERFACE SUBROUTINE intrinsic_bound_test_function3(input, res) @@ -205,18 +212,18 @@ def test_fortran_frontend_bound_allocatable_offsets(): size = 4 res = np.full([size], 42, order="F", dtype=np.int32) - sdfg( - res=res, - __f2dace_A_input_d_0_s_0=4, - __f2dace_A_input_d_1_s_1=7, - 
__f2dace_OA_input_d_0_s_0=42, - __f2dace_OA_input_d_1_s_1=13 - ) + sdfg(res=res, + __f2dace_A_input_d_0_s_0=4, + __f2dace_A_input_d_1_s_1=7, + __f2dace_OA_input_d_0_s_0=42, + __f2dace_OA_input_d_1_s_1=13) assert np.allclose(res, [42, 13, 45, 19]) + def test_fortran_frontend_bound_structure(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE test_types IMPLICIT NONE TYPE array_container @@ -248,7 +255,9 @@ def test_fortran_frontend_bound_structure(): END SUBROUTINE END MODULE """, 'main').check_with_gfortran().get() - sdfg = create_singular_sdfg_from_string(sources, 'test_bounds.intrinsic_bound_test_function', normalize_offsets=True) + sdfg = create_singular_sdfg_from_string(sources, + 'test_bounds.intrinsic_bound_test_function', + normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() @@ -258,8 +267,10 @@ def test_fortran_frontend_bound_structure(): assert np.allclose(res, [2, 3, 5, 9]) + def test_fortran_frontend_bound_structure_override(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE test_types IMPLICIT NONE TYPE array_container @@ -293,7 +304,9 @@ def test_fortran_frontend_bound_structure_override(): END SUBROUTINE END MODULE """, 'main').check_with_gfortran().get() - sdfg = create_singular_sdfg_from_string(sources, 'test_bounds.intrinsic_bound_test_function', normalize_offsets=True) + sdfg = create_singular_sdfg_from_string(sources, + 'test_bounds.intrinsic_bound_test_function', + normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() @@ -303,8 +316,10 @@ def test_fortran_frontend_bound_structure_override(): assert np.allclose(res, [2, 3, 5, 9]) + def test_fortran_frontend_bound_structure_recursive(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE test_types IMPLICIT NONE @@ -347,7 +362,9 @@ def test_fortran_frontend_bound_structure_recursive(): 
END SUBROUTINE END MODULE """, 'main').check_with_gfortran().get() - sdfg = create_singular_sdfg_from_string(sources, 'test_bounds.intrinsic_bound_test_function', normalize_offsets=True) + sdfg = create_singular_sdfg_from_string(sources, + 'test_bounds.intrinsic_bound_test_function', + normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() @@ -357,8 +374,10 @@ def test_fortran_frontend_bound_structure_recursive(): assert np.allclose(res, [-1, 0, 2, 3]) + def test_fortran_frontend_bound_structure_recursive_allocatable(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE test_types IMPLICIT NONE @@ -406,7 +425,9 @@ def test_fortran_frontend_bound_structure_recursive_allocatable(): END SUBROUTINE END MODULE """, 'main').check_with_gfortran().get() - sdfg = create_singular_sdfg_from_string(sources, 'test_bounds.intrinsic_bound_test_function', normalize_offsets=True) + sdfg = create_singular_sdfg_from_string(sources, + 'test_bounds.intrinsic_bound_test_function', + normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() @@ -416,6 +437,7 @@ def test_fortran_frontend_bound_structure_recursive_allocatable(): assert np.allclose(res, [-1, 0, 2, 3]) + if __name__ == "__main__": test_fortran_frontend_bound() diff --git a/tests/fortran/intrinsic_count_test.py b/tests/fortran/intrinsic_count_test.py index ced135d1a6..8327c529d2 100644 --- a/tests/fortran/intrinsic_count_test.py +++ b/tests/fortran/intrinsic_count_test.py @@ -115,6 +115,7 @@ def test_fortran_frontend_count_array_comparison(): sdfg(first=first, second=second, res=res) assert list(res) == [5, 5, 5, 5, 5, 3, 2] + def test_fortran_frontend_count_array_scalar_comparison(): test_string = """ PROGRAM intrinsic_count_test @@ -166,6 +167,7 @@ def test_fortran_frontend_count_array_scalar_comparison(): sdfg(first=first, res=res) assert list(res) == [1, 1, 0, 0, 1, 1, 4, 2, size - 2] + def 
test_fortran_frontend_count_array_comparison_wrong_subset(): test_string = """ PROGRAM intrinsic_count_test @@ -189,6 +191,7 @@ def test_fortran_frontend_count_array_comparison_wrong_subset(): with pytest.raises(TypeError): fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test") + def test_fortran_frontend_count_array_2d(): test_string = """ PROGRAM intrinsic_count_test @@ -217,7 +220,7 @@ def test_fortran_frontend_count_array_2d(): sdfg(d=d, res=res) assert res[0] == 35 - d[2,2] = False + d[2, 2] = False sdfg(d=d, res=res) assert res[0] == 34 @@ -225,10 +228,11 @@ def test_fortran_frontend_count_array_2d(): sdfg(d=d, res=res) assert res[0] == 0 - d[2,2] = True + d[2, 2] = True sdfg(d=d, res=res) assert res[0] == 1 + def test_fortran_frontend_count_array_comparison_2d(): test_string = """ PROGRAM intrinsic_count_test @@ -274,6 +278,7 @@ def test_fortran_frontend_count_array_comparison_2d(): sdfg(first=first, second=second, res=res) assert list(res) == [20, 20, 20, 20, 20, 20, 4] + def test_fortran_frontend_count_array_comparison_2d_subset(): test_string = """ PROGRAM intrinsic_count_test @@ -316,6 +321,7 @@ def test_fortran_frontend_count_array_comparison_2d_subset(): sdfg(first=first, second=second, res=res) assert list(res) == [0, 4] + def test_fortran_frontend_count_array_comparison_2d_subset_offset(): test_string = """ PROGRAM intrinsic_count_test @@ -358,6 +364,7 @@ def test_fortran_frontend_count_array_comparison_2d_subset_offset(): sdfg(first=first, second=second, res=res) assert list(res) == [0, 4] + if __name__ == "__main__": test_fortran_frontend_count_array() diff --git a/tests/fortran/intrinsic_math_test.py b/tests/fortran/intrinsic_math_test.py index d407ba3bac..8041aeba1a 100644 --- a/tests/fortran/intrinsic_math_test.py +++ b/tests/fortran/intrinsic_math_test.py @@ -5,6 +5,7 @@ from dace.frontend.fortran import fortran_parser + def test_fortran_frontend_min_max(): test_string = """ PROGRAM intrinsic_math_test_min_max @@ -88,6 
+89,7 @@ def test_fortran_frontend_sqrt(): for f_res, p_res in zip(res, py_res): assert abs(f_res - p_res) < 10**-9 + def test_fortran_frontend_sqrt_structure(): test_string = """ module lib @@ -153,6 +155,7 @@ def test_fortran_frontend_sqrt_structure(): for f_res, p_res in zip(res, py_res): assert abs(f_res - p_res) < 10**-9 + def test_fortran_frontend_abs(): test_string = """ PROGRAM intrinsic_math_test_abs @@ -186,6 +189,7 @@ def test_fortran_frontend_abs(): assert res[0] == 30 assert res[1] == 40 + def test_fortran_frontend_exp(): test_string = """ PROGRAM intrinsic_math_test_exp @@ -220,6 +224,7 @@ def test_fortran_frontend_exp(): for f_res, p_res in zip(res, py_res): assert abs(f_res - p_res) < 10**-9 + def test_fortran_frontend_log(): test_string = """ PROGRAM intrinsic_math_test_log @@ -254,6 +259,7 @@ def test_fortran_frontend_log(): for f_res, p_res in zip(res, py_res): assert abs(f_res - p_res) < 10**-9 + def test_fortran_frontend_mod_float(): test_string = """ PROGRAM intrinsic_math_test_mod @@ -305,6 +311,7 @@ def test_fortran_frontend_mod_float(): assert res[4] == 1 assert res[5] == -1 + def test_fortran_frontend_mod_integer(): test_string = """ PROGRAM intrinsic_math_test_mod @@ -347,6 +354,7 @@ def test_fortran_frontend_mod_integer(): assert res[2] == 2 assert res[3] == -2 + def test_fortran_frontend_modulo_float(): test_string = """ PROGRAM intrinsic_math_test_modulo @@ -398,6 +406,7 @@ def test_fortran_frontend_modulo_float(): assert res[4] == 1.0 assert res[5] == 4.5 + def test_fortran_frontend_modulo_integer(): test_string = """ PROGRAM intrinsic_math_test_modulo @@ -441,6 +450,7 @@ def test_fortran_frontend_modulo_integer(): assert res[2] == -1 assert res[3] == -2 + def test_fortran_frontend_floor(): test_string = """ PROGRAM intrinsic_math_test_floor @@ -480,6 +490,7 @@ def test_fortran_frontend_floor(): assert res[2] == -4 assert res[3] == -64 + def test_fortran_frontend_scale(): test_string = """ PROGRAM intrinsic_math_test_scale @@ -529,6 
+540,7 @@ def test_fortran_frontend_scale(): assert res[3] == 65280. assert res[4] == 11141120. + def test_fortran_frontend_exponent(): test_string = """ PROGRAM intrinsic_math_test_exponent @@ -568,6 +580,7 @@ def test_fortran_frontend_exponent(): assert res[2] == 4 assert res[3] == 9 + def test_fortran_frontend_int(): test_string = """ PROGRAM intrinsic_math_test_int @@ -624,7 +637,7 @@ def test_fortran_frontend_int(): d[1] = 1.5 d[2] = 42.5 d[3] = -42.5 - d2 = np.full([size*2], 42, order="F", dtype=np.float32) + d2 = np.full([size * 2], 42, order="F", dtype=np.float32) d2[0] = 3.49 d2[1] = 3.5 d2[2] = 3.51 @@ -647,6 +660,7 @@ def test_fortran_frontend_int(): assert np.array_equal(res4, [3., 4., 4., 4., -3., -4., -4., -4.]) + def test_fortran_frontend_real(): test_string = """ PROGRAM intrinsic_math_test_real @@ -699,13 +713,14 @@ def test_fortran_frontend_real(): d3[0] = 7 d3[1] = 13 - res = np.full([size*3], 42, order="F", dtype=np.float64) - res2 = np.full([size*3], 42, order="F", dtype=np.float32) + res = np.full([size * 3], 42, order="F", dtype=np.float64) + res2 = np.full([size * 3], 42, order="F", dtype=np.float32) sdfg(d=d, d2=d2, d3=d3, res=res, res2=res2) assert np.allclose(res, [7.0, 13.11, 7.0, 13.11, 7., 13.]) assert np.allclose(res2, [7.0, 13.11, 7.0, 13.11, 7., 13.]) + def test_fortran_frontend_trig(): test_string = """ PROGRAM intrinsic_math_test_trig @@ -738,14 +753,15 @@ def test_fortran_frontend_trig(): size = 3 d = np.full([size], 42, order="F", dtype=np.float32) d[0] = 0 - d[1] = 3.14/2 + d[1] = 3.14 / 2 d[2] = 3.14 - res = np.full([size*2], 42, order="F", dtype=np.float32) + res = np.full([size * 2], 42, order="F", dtype=np.float32) sdfg(d=d, res=res) assert np.allclose(res, [0.0, 0.999999702, 1.59254798E-03, 1.0, 7.96274282E-04, -0.999998748]) + def test_fortran_frontend_hyperbolic(): test_string = """ PROGRAM intrinsic_math_test_hyperbolic @@ -785,10 +801,13 @@ def test_fortran_frontend_hyperbolic(): d[1] = 1 d[2] = 3.14 - res = 
np.full([size*3], 42, order="F", dtype=np.float32) + res = np.full([size * 3], 42, order="F", dtype=np.float32) sdfg(d=d, res=res) - assert np.allclose(res, [0.00000000, 1.17520118, 11.5302935, 1.00000000, 1.54308057, 11.5735760, 0.00000000, 0.761594176, 0.996260226]) + assert np.allclose( + res, + [0.00000000, 1.17520118, 11.5302935, 1.00000000, 1.54308057, 11.5735760, 0.00000000, 0.761594176, 0.996260226]) + def test_fortran_frontend_trig_inverse(): test_string = """ @@ -842,7 +861,7 @@ def test_fortran_frontend_trig_inverse(): atan_args[1] = 1.0 atan_args[2] = 3.14 - atan2_args = np.full([size*2], 42, order="F", dtype=np.float32) + atan2_args = np.full([size * 2], 42, order="F", dtype=np.float32) atan2_args[0] = 0.0 atan2_args[1] = 1.0 atan2_args[2] = 1.0 @@ -850,10 +869,14 @@ def test_fortran_frontend_trig_inverse(): atan2_args[4] = 1.0 atan2_args[5] = 0.0 - res = np.full([size*4], 42, order="F", dtype=np.float32) + res = np.full([size * 4], 42, order="F", dtype=np.float32) sdfg(sincos_args=sincos_args, tan_args=atan_args, tan2_args=atan2_args, res=res) - assert np.allclose(res, [-0.523598790, 0.00000000, 1.57079637, 2.09439516, 1.57079637, 0.00000000, 0.00000000, 0.785398185, 1.26248074, 0.00000000, 0.785398185, 1.57079637]) + assert np.allclose(res, [ + -0.523598790, 0.00000000, 1.57079637, 2.09439516, 1.57079637, 0.00000000, 0.00000000, 0.785398185, 1.26248074, + 0.00000000, 0.785398185, 1.57079637 + ]) + if __name__ == "__main__": diff --git a/tests/fortran/intrinsic_merge_test.py b/tests/fortran/intrinsic_merge_test.py index 95421d843e..933b889211 100644 --- a/tests/fortran/intrinsic_merge_test.py +++ b/tests/fortran/intrinsic_merge_test.py @@ -4,6 +4,7 @@ from dace.frontend.fortran import ast_transforms, fortran_parser + def test_fortran_frontend_merge_1d(): """ Tests that the generated array map correctly handles offsets. 
@@ -45,12 +46,12 @@ def test_fortran_frontend_merge_1d(): for val in res: assert val == 42 - for i in range(int(size/2)): + for i in range(int(size / 2)): mask[i] = 1 sdfg(input1=first, input2=second, mask=mask, res=res) - for i in range(int(size/2)): + for i in range(int(size / 2)): assert res[i] == 13 - for i in range(int(size/2), size): + for i in range(int(size / 2), size): assert res[i] == 42 mask[:] = 0 @@ -64,6 +65,7 @@ def test_fortran_frontend_merge_1d(): else: assert res[i] == 42 + def test_fortran_frontend_merge_comparison_scalar(): """ Tests that the generated array map correctly handles offsets. @@ -102,12 +104,12 @@ def test_fortran_frontend_merge_comparison_scalar(): for val in res: assert val == 42 - for i in range(int(size/2)): + for i in range(int(size / 2)): first[i] = 3 sdfg(input1=first, input2=second, res=res) - for i in range(int(size/2)): + for i in range(int(size / 2)): assert res[i] == 3 - for i in range(int(size/2), size): + for i in range(int(size / 2), size): assert res[i] == 42 first[:] = 13 @@ -121,6 +123,7 @@ def test_fortran_frontend_merge_comparison_scalar(): else: assert res[i] == 42 + def test_fortran_frontend_merge_comparison_arrays(): """ Tests that the generated array map correctly handles offsets. 
@@ -159,12 +162,12 @@ def test_fortran_frontend_merge_comparison_arrays(): for val in res: assert val == 13 - for i in range(int(size/2)): + for i in range(int(size / 2)): first[i] = 45 sdfg(input1=first, input2=second, res=res) - for i in range(int(size/2)): + for i in range(int(size / 2)): assert res[i] == 42 - for i in range(int(size/2), size): + for i in range(int(size / 2), size): assert res[i] == 13 first[:] = 13 @@ -215,8 +218,8 @@ def test_fortran_frontend_merge_comparison_arrays_offset(): # Minimum is in the beginning first = np.full([size], 13, order="F", dtype=np.float64) second = np.full([size], 42, order="F", dtype=np.float64) - mask1 = np.full([size*2], 30, order="F", dtype=np.float64) - mask2 = np.full([size*2], 0, order="F", dtype=np.float64) + mask1 = np.full([size * 2], 30, order="F", dtype=np.float64) + mask2 = np.full([size * 2], 0, order="F", dtype=np.float64) res = np.full([size], 40, order="F", dtype=np.float64) mask1[2:9] = 3 @@ -261,9 +264,9 @@ def test_fortran_frontend_merge_array_shift(): # Minimum is in the beginning first = np.full([size], 13, order="F", dtype=np.float64) - second = np.full([size*3], 42, order="F", dtype=np.float64) - mask1 = np.full([size*2], 30, order="F", dtype=np.float64) - mask2 = np.full([size*2], 0, order="F", dtype=np.float64) + second = np.full([size * 3], 42, order="F", dtype=np.float64) + mask1 = np.full([size * 2], 30, order="F", dtype=np.float64) + mask2 = np.full([size * 2], 0, order="F", dtype=np.float64) res = np.full([size], 40, order="F", dtype=np.float64) second[12:19] = 100 @@ -273,6 +276,7 @@ def test_fortran_frontend_merge_array_shift(): for val in res: assert val == 100 + def test_fortran_frontend_merge_nonarray(): """ Tests that the generated array map correctly handles offsets. @@ -314,6 +318,7 @@ def test_fortran_frontend_merge_nonarray(): sdfg(val=val, res=res) assert res[0] == 5 + def test_fortran_frontend_merge_recursive(): """ Tests that the generated array map correctly handles offsets. 
@@ -357,7 +362,7 @@ def test_fortran_frontend_merge_recursive(): mask2 = np.full([size], 1, order="F", dtype=np.int32) res = np.full([size], 40, order="F", dtype=np.float64) - for i in range(int(size/2)): + for i in range(int(size / 2)): mask1[i] = 1 mask2[-1] = 0 @@ -366,6 +371,7 @@ def test_fortran_frontend_merge_recursive(): assert np.allclose(res, [13, 13, 13, 42, 42, 42, 43]) + def test_fortran_frontend_merge_scalar(): """ Tests that the generated array map correctly handles offsets. @@ -461,6 +467,7 @@ def test_fortran_frontend_merge_scalar2(): sdfg(input1=first, input2=second, mask=mask, res=res) assert res[0] == 13 + def test_fortran_frontend_merge_scalar3(): """ Tests that the generated array map correctly handles offsets. @@ -508,6 +515,7 @@ def test_fortran_frontend_merge_scalar3(): sdfg(input1=first, input2=second, mask=mask, mask2=mask2, res=res) assert res[0] == 13 + if __name__ == "__main__": test_fortran_frontend_merge_1d() diff --git a/tests/fortran/intrinsic_minmaxval_test.py b/tests/fortran/intrinsic_minmaxval_test.py index 5c0cb2cca6..ca01216c20 100644 --- a/tests/fortran/intrinsic_minmaxval_test.py +++ b/tests/fortran/intrinsic_minmaxval_test.py @@ -5,6 +5,7 @@ from dace.frontend.fortran import ast_transforms, fortran_parser from tests.fortran.fortran_test_helper import SourceCodeBuilder, create_singular_sdfg_from_string + def test_fortran_frontend_minval_double(): """ Tests that the generated array map correctly handles offsets. @@ -59,6 +60,7 @@ def test_fortran_frontend_minval_double(): # It should be the dace max for integer assert res[3] == np.finfo(np.float64).max + def test_fortran_frontend_minval_int(): """ Tests that the generated array map correctly handles offsets. @@ -125,6 +127,7 @@ def test_fortran_frontend_minval_int(): # It should be the dace max for integer assert res[3] == np.iinfo(np.int32).max + def test_fortran_frontend_maxval_double(): """ Tests that the generated array map correctly handles offsets. 
@@ -179,6 +182,7 @@ def test_fortran_frontend_maxval_double(): # It should be the dace max for integer assert res[3] == np.finfo(np.float64).min + def test_fortran_frontend_maxval_int(): """ Tests that the generated array map correctly handles offsets. @@ -245,8 +249,10 @@ def test_fortran_frontend_maxval_int(): # It should be the dace max for integer assert res[3] == np.iinfo(np.int32).min + def test_fortran_frontend_minval_struct(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE test_types IMPLICIT NONE TYPE array_container @@ -293,6 +299,7 @@ def test_fortran_frontend_minval_struct(): sdfg(d=d, res=res) print(res) + if __name__ == "__main__": #test_fortran_frontend_minval_double() diff --git a/tests/fortran/intrinsic_product_test.py b/tests/fortran/intrinsic_product_test.py index fcf9dc8057..095eb4485f 100644 --- a/tests/fortran/intrinsic_product_test.py +++ b/tests/fortran/intrinsic_product_test.py @@ -5,6 +5,7 @@ from dace.frontend.fortran import ast_transforms, fortran_parser + def test_fortran_frontend_product_array(): """ Tests that the generated array map correctly handles offsets. @@ -44,6 +45,7 @@ def test_fortran_frontend_product_array(): assert res[1] == np.prod(d) assert res[2] == np.prod(d[1:5]) + def test_fortran_frontend_product_array_dim(): test_string = """ PROGRAM intrinsic_count_test @@ -65,6 +67,7 @@ def test_fortran_frontend_product_array_dim(): with pytest.raises(NotImplementedError): fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + def test_fortran_frontend_product_2d(): """ Tests that the generated array map correctly handles offsets. 
@@ -109,6 +112,7 @@ def test_fortran_frontend_product_2d(): assert res[2] == np.prod(d[1:4, 1]) assert res[3] == np.prod(d[1:4, 1:3]) + if __name__ == "__main__": test_fortran_frontend_product_array() diff --git a/tests/fortran/intrinsic_sum_test.py b/tests/fortran/intrinsic_sum_test.py index e933589e0f..a497271ff6 100644 --- a/tests/fortran/intrinsic_sum_test.py +++ b/tests/fortran/intrinsic_sum_test.py @@ -4,6 +4,7 @@ from dace.frontend.fortran import ast_transforms, fortran_parser + def test_fortran_frontend_sum2loop_1d_without_offset(): """ Tests that the generated array map correctly handles offsets. @@ -41,7 +42,8 @@ def test_fortran_frontend_sum2loop_1d_without_offset(): sdfg(d=d, res=res) assert res[0] == (1 + size) * size / 2 assert res[1] == (1 + size) * size / 2 - assert res[2] == (2 + size - 1) * (size - 2)/ 2 + assert res[2] == (2 + size - 1) * (size - 2) / 2 + def test_fortran_frontend_sum2loop_1d_offset(): """ @@ -82,6 +84,7 @@ def test_fortran_frontend_sum2loop_1d_offset(): assert res[1] == (1 + size) * size / 2 assert res[2] == (2 + size - 1) * (size - 2) / 2 + def test_fortran_frontend_arr2loop_2d(): """ Tests that the generated array map correctly handles offsets. @@ -126,6 +129,7 @@ def test_fortran_frontend_arr2loop_2d(): assert res[2] == 21 assert res[3] == 45 + def test_fortran_frontend_arr2loop_2d_offset(): """ Tests that the generated array map correctly handles offsets. 
@@ -168,6 +172,7 @@ def test_fortran_frontend_arr2loop_2d_offset(): assert res[1] == 190 assert res[2] == 57 + if __name__ == "__main__": test_fortran_frontend_sum2loop_1d_without_offset() diff --git a/tests/fortran/long_tasklet_test.py b/tests/fortran/long_tasklet_test.py index eb59a4bea6..92d16fc9df 100644 --- a/tests/fortran/long_tasklet_test.py +++ b/tests/fortran/long_tasklet_test.py @@ -6,6 +6,7 @@ import dace.frontend.fortran.ast_internal_classes as ast_internal_classes import numpy as np + def test_fortran_frontend_long_tasklet(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. @@ -40,15 +41,16 @@ def test_fortran_frontend_long_tasklet(): END SUBROUTINE long_tasklet_test_function """ - sources={} - sources["long_tasklet_test"]=test_string + sources = {} + sources["long_tasklet_test"] = test_string sdfg = fortran_parser.create_sdfg_from_string(test_string, "long_tasklet_test", True, sources=sources) sdfg.simplify(verbose=True) a = np.full([5], 42, order="F", dtype=np.float64) sdfg(d=a) assert (a[1] == 5.5) assert (a[0] == 4) - + + if __name__ == "__main__": test_fortran_frontend_long_tasklet() diff --git a/tests/fortran/missing_func_test.py b/tests/fortran/missing_func_test.py index 1b55dd324d..6ec9bb4cfa 100644 --- a/tests/fortran/missing_func_test.py +++ b/tests/fortran/missing_func_test.py @@ -92,8 +92,8 @@ def test_fortran_frontend_missing_func(): END SUBROUTINE acc_wait_if_requested """ - sources={} - sources["missing_test"]=test_string + sources = {} + sources["missing_test"] = test_string sdfg = fortran_parser.create_sdfg_from_string(test_string, "missing_test", True, sources=sources) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) @@ -102,6 +102,7 @@ def test_fortran_frontend_missing_func(): assert (a[1, 0] == 6.5) assert (a[2, 0] == 42) + def test_fortran_frontend_missing_extraction(): """ Tests that the Fortran frontend can parse the simplest type declaration 
and make use of it in a computation. @@ -130,9 +131,9 @@ def test_fortran_frontend_missing_extraction(): END SUBROUTINE missing_extraction_test_function """ - sources={} - sources["missing_extraction_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "missing_extraction_test",sources=sources) + sources = {} + sources["missing_extraction_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "missing_extraction_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) @@ -140,7 +141,7 @@ def test_fortran_frontend_missing_extraction(): assert (a[1, 0] == 5.5) assert (a[2, 0] == 42) + if __name__ == "__main__": test_fortran_frontend_missing_func() test_fortran_frontend_missing_extraction() - \ No newline at end of file diff --git a/tests/fortran/nested_array_test.py b/tests/fortran/nested_array_test.py index 6c4cb14535..b7805753bf 100644 --- a/tests/fortran/nested_array_test.py +++ b/tests/fortran/nested_array_test.py @@ -42,8 +42,12 @@ def test_fortran_frontend_nested_array_access(): END SUBROUTINE nested_array_access_test_function """ - sources={"nested_array_access_test_function": test_string} - sdfg = fortran_parser.create_sdfg_from_string(test_string, "nested_array_access_test",normalize_offsets=True,multiple_sdfgs=False,sources=sources) + sources = {"nested_array_access_test_function": test_string} + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "nested_array_access_test", + normalize_offsets=True, + multiple_sdfgs=False, + sources=sources) sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) sdfg(d=a) @@ -86,16 +90,20 @@ def test_fortran_frontend_nested_array_access2(): END SUBROUTINE nested_array_access2_test_function """ - sources={"nested_array_access2_test_function": test_string} - sdfg = fortran_parser.create_sdfg_from_string(test_string, 
"nested_array_access2_test",normalize_offsets=True,multiple_sdfgs=False,sources=sources) + sources = {"nested_array_access2_test_function": test_string} + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "nested_array_access2_test", + normalize_offsets=True, + multiple_sdfgs=False, + sources=sources) sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) sdfg(d=a) assert (a[0] == 42) assert (a[1] == 5.5) - assert (a[2] == 42) + assert (a[2] == 42) + if __name__ == "__main__": test_fortran_frontend_nested_array_access2() - diff --git a/tests/fortran/non-interactive/fortran_int_init_test.py b/tests/fortran/non-interactive/fortran_int_init_test.py index 7632db6d19..791337b934 100644 --- a/tests/fortran/non-interactive/fortran_int_init_test.py +++ b/tests/fortran/non-interactive/fortran_int_init_test.py @@ -7,7 +7,6 @@ import numpy as np import pytest - from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic from dace.frontend.fortran import fortran_parser from fparser.two.symbol_table import SymbolTable @@ -36,7 +35,7 @@ def test_fortran_frontend_int_init(): d(1)=INT(z'000000ffffffffff',i8) END SUBROUTINE int_init_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "int_init_test",False,False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "int_init_test", False, False) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): if node.sdfg is not None: @@ -46,7 +45,7 @@ def test_fortran_frontend_int_init(): sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None - sdfg.reset_sdfg_list() + sdfg.reset_sdfg_list() sdfg.simplify(verbose=True) sdfg.view() sdfg.compile() @@ -57,10 +56,6 @@ def test_fortran_frontend_int_init(): # assert (d[0] == 400) - if __name__ == "__main__": - - test_fortran_frontend_int_init() - diff --git a/tests/fortran/non-interactive/function_test.py b/tests/fortran/non-interactive/function_test.py index 
ec95555c8f..87cfd260c3 100644 --- a/tests/fortran/non-interactive/function_test.py +++ b/tests/fortran/non-interactive/function_test.py @@ -76,7 +76,7 @@ def test_fortran_frontend_function_test(): END FUNCTION function_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, False, False) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): if node.sdfg is not None: @@ -86,7 +86,7 @@ def test_fortran_frontend_function_test(): sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None - sdfg.reset_sdfg_list() + sdfg.reset_sdfg_list() sdfg.simplify(verbose=True) sdfg.view() @@ -136,7 +136,7 @@ def test_fortran_frontend_function_test2(): """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, False, False) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): if node.sdfg is not None: @@ -146,12 +146,11 @@ def test_fortran_frontend_function_test2(): sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None - sdfg.reset_sdfg_list() + sdfg.reset_sdfg_list() sdfg.simplify(verbose=True) sdfg.view() - def test_fortran_frontend_function_test3(): """ Tests to check whether Fortran array slices are correctly translates to DaCe views. 
@@ -256,7 +255,7 @@ def test_fortran_frontend_function_test3(): END FUNCTION function3_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, False, False) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): if node.sdfg is not None: @@ -266,13 +265,12 @@ def test_fortran_frontend_function_test3(): sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None - sdfg.reset_sdfg_list() + sdfg.reset_sdfg_list() sdfg.simplify(verbose=True) sdfg.view() sdfg.compile() - @pytest.mark.skip(reason="Interactive test (opens SDFG).") def test_fortran_frontend_function_test4(): """ @@ -323,7 +321,7 @@ def test_fortran_frontend_function_test4(): """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, False, False) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): if node.sdfg is not None: @@ -333,13 +331,12 @@ def test_fortran_frontend_function_test4(): sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None - sdfg.reset_sdfg_list() + sdfg.reset_sdfg_list() sdfg.simplify(verbose=True) sdfg.view() sdfg.compile() - @pytest.mark.skip(reason="Interactive test (opens SDFG).") def test_fortran_frontend_function_test5(): """ @@ -383,7 +380,7 @@ def test_fortran_frontend_function_test5(): """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, False, False) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): if node.sdfg is not None: @@ -393,11 +390,12 @@ def test_fortran_frontend_function_test5(): sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None - sdfg.reset_sdfg_list() + sdfg.reset_sdfg_list() 
sdfg.simplify(verbose=True) sdfg.view() sdfg.compile() + if __name__ == "__main__": #test_fortran_frontend_function_test() diff --git a/tests/fortran/non-interactive/pointers_test.py b/tests/fortran/non-interactive/pointers_test.py index 3b98595ab9..85ff922e02 100644 --- a/tests/fortran/non-interactive/pointers_test.py +++ b/tests/fortran/non-interactive/pointers_test.py @@ -58,7 +58,7 @@ def test_fortran_frontend_pointer_test(): END SUBROUTINE pointer_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, False, False) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): if node.sdfg is not None: @@ -68,14 +68,12 @@ def test_fortran_frontend_pointer_test(): sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None - sdfg.reset_sdfg_list() - sdfg.validate() + sdfg.reset_sdfg_list() + sdfg.validate() sdfg.simplify(verbose=True) sdfg.view() - - if __name__ == "__main__": test_fortran_frontend_pointer_test() diff --git a/tests/fortran/non-interactive/view_test.py b/tests/fortran/non-interactive/view_test.py index eea4ca1c90..241d93aafb 100644 --- a/tests/fortran/non-interactive/view_test.py +++ b/tests/fortran/non-interactive/view_test.py @@ -63,34 +63,34 @@ def test_fortran_frontend_view_test(): END SUBROUTINE viewlens """ - sdfg2 = fortran_parser.create_sdfg_from_string(test_string, test_name,False,False) + sdfg2 = fortran_parser.create_sdfg_from_string(test_string, test_name, False, False) sdfg2.view() - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, False, True) for state in sdfg.nodes(): for node in state.nodes(): if isinstance(node, nodes.NestedSDFG): - if node.path!="": - print("TEST: "+node.path) + if node.path != "": + print("TEST: " + node.path) tmp_sdfg = SDFG.from_file(node.path) 
node.sdfg = tmp_sdfg node.sdfg.parent = state node.sdfg.parent_sdfg = sdfg node.sdfg.update_sdfg_list([]) node.sdfg.parent_nsdfg_node = node - node.path="" - for sd in sdfg.all_sdfgs_recursive(): + node.path = "" + for sd in sdfg.all_sdfgs_recursive(): for state in sd.nodes(): - for node in state.nodes(): - if isinstance(node, nodes.NestedSDFG): - if node.path!="": - print("TEST: "+node.path) - tmp_sdfg = SDFG.from_file(node.path) - node.sdfg = tmp_sdfg - node.sdfg.parent = state - node.sdfg.parent_sdfg = sd - node.sdfg.update_sdfg_list([]) - node.sdfg.parent_nsdfg_node = node - node.path="" + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + if node.path != "": + print("TEST: " + node.path) + tmp_sdfg = SDFG.from_file(node.path) + node.sdfg = tmp_sdfg + node.sdfg.parent = state + node.sdfg.parent_sdfg = sd + node.sdfg.update_sdfg_list([]) + node.sdfg.parent_nsdfg_node = node + node.path = "" for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): if node.sdfg is not None: @@ -100,8 +100,8 @@ def test_fortran_frontend_view_test(): sdfg.parent = None sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None - sdfg.reset_sdfg_list() - sdfg.view() + sdfg.reset_sdfg_list() + sdfg.view() sdfg.simplify(verbose=True) a = np.full([10, 11, 12], 42, order="F", dtype=np.float64) b = np.full([1, 1, 2], 42, order="F", dtype=np.float64) @@ -157,7 +157,7 @@ def test_fortran_frontend_view_test_2(): END SUBROUTINE viewlens """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, False, True) #sdfg.simplify(verbose=True) a = np.full([10, 11, 12], 42, order="F", dtype=np.float64) b = np.full([10, 11, 12], 42, order="F", dtype=np.float64) @@ -211,7 +211,7 @@ def test_fortran_frontend_view_test_3(): END SUBROUTINE viewlens """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name,False,True) + sdfg = 
fortran_parser.create_sdfg_from_string(test_string, test_name, False, True) #sdfg.simplify(verbose=True) a = np.full([10, 11, 12], 42, order="F", dtype=np.float64) b = np.full([10, 11, 12], 42, order="F", dtype=np.float64) diff --git a/tests/fortran/offset_normalizer_test.py b/tests/fortran/offset_normalizer_test.py index 3f8d9fc5c2..f012300e57 100644 --- a/tests/fortran/offset_normalizer_test.py +++ b/tests/fortran/offset_normalizer_test.py @@ -4,6 +4,7 @@ from dace.frontend.fortran import ast_internal_classes, ast_transforms, fortran_parser + def test_fortran_frontend_offset_normalizer_1d(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. @@ -45,8 +46,9 @@ def test_fortran_frontend_offset_normalizer_1d(): a = np.full([5], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(0,5): - assert a[i] == (50+i)* 2 + for i in range(0, 5): + assert a[i] == (50 + i) * 2 + def test_fortran_frontend_offset_normalizer_1d_symbol(): """ @@ -90,17 +92,18 @@ def test_fortran_frontend_offset_normalizer_1d_symbol(): sdfg.compile() from dace.symbolic import evaluate - arrsize=50 - arrsize2=54 + arrsize = 50 + arrsize2 = 54 assert len(sdfg.data('d').shape) == 1 assert evaluate(sdfg.data('d').shape[0], {'arrsize': arrsize, 'arrsize2': arrsize2}) == 5 - arrsize=50 - arrsize2=54 - a = np.full([arrsize2-arrsize+1], 42, order="F", dtype=np.float64) + arrsize = 50 + arrsize2 = 54 + a = np.full([arrsize2 - arrsize + 1], 42, order="F", dtype=np.float64) sdfg(d=a, arrsize=arrsize, arrsize2=arrsize2) for i in range(0, arrsize2 - arrsize + 1): - assert a[i] == (50+i)* 2 + assert a[i] == (50 + i) * 2 + def test_fortran_frontend_offset_normalizer_2d(): """ @@ -151,11 +154,12 @@ def test_fortran_frontend_offset_normalizer_2d(): assert sdfg.data('d').shape[0] == 5 assert sdfg.data('d').shape[1] == 3 - a = np.full([5,3], 42, order="F", dtype=np.float64) + a = np.full([5, 3], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in 
range(0,5): - for j in range(0,3): - assert a[i, j] == (50+i) * 2 + 3 * (7 + j) + for i in range(0, 5): + for j in range(0, 3): + assert a[i, j] == (50 + i) * 2 + 3 * (7 + j) + def test_fortran_frontend_offset_normalizer_2d_symbol(): """ @@ -213,21 +217,17 @@ def test_fortran_frontend_offset_normalizer_2d_symbol(): sdfg.compile() from dace.symbolic import evaluate - values = { - 'arrsize': 50, - 'arrsize2': 54, - 'arrsize3': 7, - 'arrsize4': 9 - } + values = {'arrsize': 50, 'arrsize2': 54, 'arrsize3': 7, 'arrsize4': 9} assert len(sdfg.data('d').shape) == 2 assert evaluate(sdfg.data('d').shape[0], values) == 5 assert evaluate(sdfg.data('d').shape[1], values) == 3 - a = np.full([5,3], 42, order="F", dtype=np.float64) + a = np.full([5, 3], 42, order="F", dtype=np.float64) sdfg(d=a, **values) - for i in range(0,5): - for j in range(0,3): - assert a[i, j] == (50+i) * 2 + 3 * (7 + j) + for i in range(0, 5): + for j in range(0, 3): + assert a[i, j] == (50 + i) * 2 + 3 * (7 + j) + def test_fortran_frontend_offset_normalizer_2d_arr2loop(): """ @@ -277,12 +277,13 @@ def test_fortran_frontend_offset_normalizer_2d_arr2loop(): assert sdfg.data('d').shape[0] == 5 assert sdfg.data('d').shape[1] == 3 - a = np.full([5,3], 42, order="F", dtype=np.float64) + a = np.full([5, 3], 42, order="F", dtype=np.float64) sdfg(d=a) - for i in range(0,5): - for j in range(0,3): + for i in range(0, 5): + for j in range(0, 3): assert a[i, j] == (50 + i) * 2 + def test_fortran_frontend_offset_normalizer_2d_arr2loop_symbol(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
@@ -337,22 +338,18 @@ def test_fortran_frontend_offset_normalizer_2d_arr2loop_symbol(): sdfg.compile() from dace.symbolic import evaluate - values = { - 'arrsize': 50, - 'arrsize2': 54, - 'arrsize3': 7, - 'arrsize4': 9 - } + values = {'arrsize': 50, 'arrsize2': 54, 'arrsize3': 7, 'arrsize4': 9} assert len(sdfg.data('d').shape) == 2 assert evaluate(sdfg.data('d').shape[0], values) == 5 assert evaluate(sdfg.data('d').shape[1], values) == 3 - a = np.full([5,3], 42, order="F", dtype=np.float64) + a = np.full([5, 3], 42, order="F", dtype=np.float64) sdfg(d=a, **values) - for i in range(0,5): - for j in range(0,3): + for i in range(0, 5): + for j in range(0, 3): assert a[i, j] == (50 + i) * 2 + def test_fortran_frontend_offset_normalizer_struct(): test_string = """ PROGRAM index_offset_test @@ -402,22 +399,18 @@ def test_fortran_frontend_offset_normalizer_struct(): sdfg.compile() from dace.symbolic import evaluate - values = { - 'arrsize': 50, - 'arrsize2': 54, - 'arrsize3': 7, - 'arrsize4': 9 - } + values = {'arrsize': 50, 'arrsize2': 54, 'arrsize3': 7, 'arrsize4': 9} assert len(sdfg.data('d').shape) == 2 assert evaluate(sdfg.data('d').shape[0], values) == 5 assert evaluate(sdfg.data('d').shape[1], values) == 3 - a = np.full([5,3], 42, order="F", dtype=np.float64) + a = np.full([5, 3], 42, order="F", dtype=np.float64) sdfg(d=a, **values) - for i in range(0,5): - for j in range(0,3): + for i in range(0, 5): + for j in range(0, 3): assert a[i, j] == (50 + i) * 2 + if __name__ == "__main__": #test_fortran_frontend_offset_normalizer_1d() diff --git a/tests/fortran/optional_args_test.py b/tests/fortran/optional_args_test.py index 11abc69994..610dc9561e 100644 --- a/tests/fortran/optional_args_test.py +++ b/tests/fortran/optional_args_test.py @@ -6,9 +6,11 @@ from dace.frontend.fortran import fortran_parser from tests.fortran.fortran_test_helper import create_singular_sdfg_from_string, SourceCodeBuilder + def test_fortran_frontend_optional(): - sources, main = 
SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE intrinsic_optional_test INTERFACE @@ -51,9 +53,11 @@ def test_fortran_frontend_optional(): assert res[0] == 5 assert res2[0] == 0 + def test_fortran_frontend_optional_complex(): - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ MODULE intrinsic_optional_test INTERFACE diff --git a/tests/fortran/parent_test.py b/tests/fortran/parent_test.py index 1f66c81311..0000d188f0 100644 --- a/tests/fortran/parent_test.py +++ b/tests/fortran/parent_test.py @@ -10,7 +10,8 @@ def test_fortran_frontend_parent(): """ Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. """ - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ program main implicit none double precision d(4) @@ -58,7 +59,8 @@ def test_fortran_frontend_module(): ! good enough approximation integer, parameter :: pi = 4 end module lib -""").add_file(""" +""").add_file( + """ program main implicit none double precision d(4) diff --git a/tests/fortran/pointer_removal_test.py b/tests/fortran/pointer_removal_test.py index 3e2705c4b3..d359b4fcd9 100644 --- a/tests/fortran/pointer_removal_test.py +++ b/tests/fortran/pointer_removal_test.py @@ -20,6 +20,7 @@ from dace.transformation.passes.lift_struct_views import LiftStructViews from dace.transformation import pass_pipeline as ppl + def test_fortran_frontend_ptr_assignment_removal(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
@@ -48,9 +49,9 @@ def test_fortran_frontend_ptr_assignment_removal(): d(2,1) = max(1.0, tmp) END SUBROUTINE type_in_call_test_function """ - sources={} - sources["type_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources) + sources = {} + sources["type_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) @@ -87,9 +88,12 @@ def test_fortran_frontend_ptr_assignment_removal_array(): d(2,1) = max(1.0, tmp(1,1,1)) END SUBROUTINE type_in_call_test_function """ - sources={} - sources["type_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources,normalize_offsets=True) + sources = {} + sources["type_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type_in_call_test", + sources=sources, + normalize_offsets=True) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) @@ -97,6 +101,7 @@ def test_fortran_frontend_ptr_assignment_removal_array(): assert (a[1, 0] == 11) assert (a[2, 0] == 42) + def test_fortran_frontend_ptr_assignment_removal_array_assumed(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
@@ -135,9 +140,9 @@ def test_fortran_frontend_ptr_assignment_removal_array_assumed(): tmp(2,1,1) = 1410 END SUBROUTINE type_in_call_test_function2 """ - sources={} - sources["type_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources) + sources = {} + sources["type_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) @@ -146,6 +151,7 @@ def test_fortran_frontend_ptr_assignment_removal_array_assumed(): assert (a[1, 0] == 11) assert (a[2, 0] == 1410) + def test_fortran_frontend_ptr_assignment_removal_array_nested(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. @@ -188,9 +194,9 @@ def test_fortran_frontend_ptr_assignment_removal_array_nested(): d(2,1) = tmp(1,1,1) END SUBROUTINE type_in_call_test_function """ - sources={} - sources["type_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources) + sources = {} + sources["type_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) @@ -198,6 +204,7 @@ def test_fortran_frontend_ptr_assignment_removal_array_nested(): assert (a[1, 0] == 11) assert (a[2, 0] == 42) + if __name__ == "__main__": # pointers to non-array fields are broken #test_fortran_frontend_ptr_assignment_removal() diff --git a/tests/fortran/prune_test.py b/tests/fortran/prune_test.py index 46585d2825..84cf4a7168 100644 --- a/tests/fortran/prune_test.py +++ b/tests/fortran/prune_test.py @@ -18,6 +18,7 @@ import dace.frontend.fortran.ast_utils as ast_utils import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + def 
test_fortran_frontend_prune_simple(): test_string = """ PROGRAM init_test @@ -42,7 +43,7 @@ def test_fortran_frontend_prune_simple(): sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) print(a) - sdfg(d=a,outside_init=0) + sdfg(d=a, outside_init=0) print(a) assert (a[0] == 42) assert (a[1] == 42 + 3.14) @@ -90,12 +91,13 @@ def test_fortran_frontend_prune_complex(): sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) print(a) - sdfg(d=a,outside_init=0) + sdfg(d=a, outside_init=0) print(a) assert (a[0] == 42) assert (a[1] == 42 + 3.14) assert (a[2] == 40) + def test_fortran_frontend_prune_actual_param(): # Test we do not remove a variable that is passed along # but not used in the function. @@ -134,12 +136,13 @@ def test_fortran_frontend_prune_actual_param(): sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) print(a) - sdfg(d=a,outside_init=0) + sdfg(d=a, outside_init=0) print(a) assert (a[0] == 42) assert (a[1] == 42) assert (a[2] == 40) + if __name__ == "__main__": test_fortran_frontend_prune_simple() diff --git a/tests/fortran/prune_unused_children_test.py b/tests/fortran/prune_unused_children_test.py index 50d69628ed..cc51ee218d 100644 --- a/tests/fortran/prune_unused_children_test.py +++ b/tests/fortran/prune_unused_children_test.py @@ -6,9 +6,7 @@ from fparser.two.parser import ParserFactory from fparser.two.utils import walk -from dace.frontend.fortran.ast_desugaring import (ENTRY_POINT_OBJECT, - ENTRY_POINT_OBJECT_TYPES, - find_name_of_node, +from dace.frontend.fortran.ast_desugaring import (ENTRY_POINT_OBJECT, ENTRY_POINT_OBJECT_TYPES, find_name_of_node, prune_unused_objects) from dace.frontend.fortran.fortran_parser import recursive_ast_improver from tests.fortran.fortran_test_helper import SourceCodeBuilder diff --git a/tests/fortran/ranges_test.py b/tests/fortran/ranges_test.py index 39363f7412..b4fcbb0ed4 100644 --- a/tests/fortran/ranges_test.py +++ 
b/tests/fortran/ranges_test.py @@ -3,7 +3,6 @@ import numpy as np from dace.frontend.fortran import ast_transforms, fortran_parser - """ We test for the following patterns: * Range 'ALL' @@ -18,6 +17,7 @@ * Assignment with arrays that have no range expression on the right """ + def test_fortran_frontend_multiple_ranges_all(): """ Tests that the generated array map correctly handles offsets. @@ -57,6 +57,7 @@ def test_fortran_frontend_multiple_ranges_all(): for val in res: assert val == 1.0 + def test_fortran_frontend_multiple_ranges_selection(): """ Tests that the generated array map correctly handles offsets. @@ -93,6 +94,7 @@ def test_fortran_frontend_multiple_ranges_selection(): for idx, val in enumerate(res): assert val == idx + 1.0 + def test_fortran_frontend_multiple_ranges_selection_var(): """ Tests that the generated array map correctly handles offsets. @@ -137,6 +139,7 @@ def test_fortran_frontend_multiple_ranges_selection_var(): for idx, val in enumerate(res): assert -val == idx + 1.0 + def test_fortran_frontend_multiple_ranges_subset(): """ Tests that the generated array map correctly handles offsets. @@ -172,6 +175,7 @@ def test_fortran_frontend_multiple_ranges_subset(): for idx, val in enumerate(res): assert val == -3.0 + def test_fortran_frontend_multiple_ranges_subset_var(): """ Tests that the generated array map correctly handles offsets. @@ -202,7 +206,7 @@ def test_fortran_frontend_multiple_ranges_subset_var(): size = 9 input1 = np.full([size], 0, order="F", dtype=np.float64) for i in range(size): - input1[i] = 2 ** i + input1[i] = 2**i pos = np.full([4], 0, order="F", dtype=np.int32) pos[0] = 2 @@ -216,6 +220,7 @@ def test_fortran_frontend_multiple_ranges_subset_var(): for i in range(len(res)): assert res[i] == input1[pos[0] - 1 + i] - input1[pos[2] - 1 + i] + def test_fortran_frontend_multiple_ranges_ecrad_pattern(): """ Tests that the generated array map correctly handles offsets. 
@@ -247,7 +252,7 @@ def test_fortran_frontend_multiple_ranges_ecrad_pattern(): input1 = np.full([size, size], 0, order="F", dtype=np.float64) for i in range(size): for j in range(size): - input1[i, j] = i + 2 ** j + input1[i, j] = i + 2**j pos = np.full([2], 0, order="F", dtype=np.int32) pos[0] = 2 @@ -259,9 +264,10 @@ def test_fortran_frontend_multiple_ranges_ecrad_pattern(): for i in range(size): for j in range(pos[0], pos[1] + 1): - print(i , j, res[i - 1, j - 1], input1[i - 1, j - 1]) + print(i, j, res[i - 1, j - 1], input1[i - 1, j - 1]) assert res[i - 1, j - 1] == input1[i - 1, j - 1] + def test_fortran_frontend_multiple_ranges_ecrad_pattern_complex(): """ Tests that the generated array map correctly handles offsets. @@ -293,7 +299,7 @@ def test_fortran_frontend_multiple_ranges_ecrad_pattern_complex(): input1 = np.full([size, size], 0, order="F", dtype=np.float64) for i in range(size): for j in range(size): - input1[i, j] = i + 2 ** j + input1[i, j] = i + 2**j pos = np.full([6], 0, order="F", dtype=np.int32) pos[0] = 2 @@ -315,6 +321,7 @@ def test_fortran_frontend_multiple_ranges_ecrad_pattern_complex(): for j in range(length): assert res[i - 1, iter_1 + j - 1] == input1[i - 1, iter_2 + j - 1] + input1[i - 1, iter_3 + j - 1] + def test_fortran_frontend_multiple_ranges_ecrad_pattern_complex_offsets(): """ Tests that the generated array map correctly handles offsets. 
@@ -346,7 +353,7 @@ def test_fortran_frontend_multiple_ranges_ecrad_pattern_complex_offsets(): input1 = np.full([size, size], 0, order="F", dtype=np.float64) for i in range(size): for j in range(size): - input1[i, j] = i + 2 ** j + input1[i, j] = i + 2**j pos = np.full([6], 0, order="F", dtype=np.int32) pos[0] = 2 + 30 @@ -368,6 +375,7 @@ def test_fortran_frontend_multiple_ranges_ecrad_pattern_complex_offsets(): for j in range(length): assert res[i - 1, iter_1 + j - 1] == input1[i - 1, iter_2 + j - 1] + input1[i - 1, iter_3 + j - 1] + def test_fortran_frontend_array_assignment(): """ Tests that the generated array map correctly handles offsets. @@ -428,6 +436,7 @@ def test_fortran_frontend_array_assignment(): assert res[i, 3] == input1[i] + input2[i] assert res[i, 4] == input1[i] + input2[i] + def test_fortran_frontend_multiple_ranges_ecrad_bug(): """ Tests that the generated array map correctly handles offsets. @@ -462,7 +471,7 @@ def test_fortran_frontend_multiple_ranges_ecrad_bug(): input1 = np.full([size, size], 0, order="F", dtype=np.float64) for i in range(size): for j in range(size): - input1[i, j] = i + 2 ** j + input1[i, j] = i + 2**j pos = np.full([4], 0, order="F", dtype=np.int32) pos[0] = 2 @@ -484,6 +493,7 @@ def test_fortran_frontend_multiple_ranges_ecrad_bug(): iter_1 += 1 iter_2 += 1 + def test_fortran_frontend_ranges_array_bug(): """ Tests that the generated array map correctly handles offsets. diff --git a/tests/fortran/recursive_ast_improver_test.py b/tests/fortran/recursive_ast_improver_test.py index ef9fbdf5bc..b699ad7d67 100644 --- a/tests/fortran/recursive_ast_improver_test.py +++ b/tests/fortran/recursive_ast_improver_test.py @@ -625,7 +625,8 @@ def test_floaters_are_brought_in(): """ The same simple program, but this time the subroutine is defined inside the main program that calls it. 
""" - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ subroutine fun(z) implicit none real, intent(out) :: z @@ -675,7 +676,8 @@ def test_floaters_can_bring_in_more_modules(): """ The same simple program, but this time the subroutine is defined inside the main program that calls it. """ - sources, main = SourceCodeBuilder().add_file(""" + sources, main = SourceCodeBuilder().add_file( + """ module lib implicit none real, parameter :: zzz = 5.5 diff --git a/tests/fortran/rename_test.py b/tests/fortran/rename_test.py index aa1576cc5b..2f14eb49ad 100644 --- a/tests/fortran/rename_test.py +++ b/tests/fortran/rename_test.py @@ -33,9 +33,9 @@ def test_fortran_frontend_rename(): """ - sources={} - sources["rename_test"]=test_string - sources["rename_test_module_subroutine.f90"]=""" + sources = {} + sources["rename_test"] = test_string + sources["rename_test_module_subroutine.f90"] = """ MODULE rename_test_module_subroutine CONTAINS SUBROUTINE rename_test_function(d) @@ -48,14 +48,14 @@ def test_fortran_frontend_rename(): END SUBROUTINE rename_test_function END MODULE rename_test_module_subroutine """ - sources["rename_test_module.f90"]=""" + sources["rename_test_module.f90"] = """ MODULE rename_test_module IMPLICIT NONE INTEGER, PARAMETER :: pi4 = 9 INTEGER, PARAMETER :: i4 = SELECTED_INT_KIND(pi4) END MODULE rename_test_module """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "rename_test",sources=sources) + sdfg = fortran_parser.create_sdfg_from_string(test_string, "rename_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) sdfg(d=a) @@ -64,7 +64,6 @@ def test_fortran_frontend_rename(): assert (a[2] == 42) - if __name__ == "__main__": test_fortran_frontend_rename() diff --git a/tests/fortran/scope_arrays_test.py b/tests/fortran/scope_arrays_test.py index 5dd5b806a8..a8c18cf524 100644 --- a/tests/fortran/scope_arrays_test.py +++ 
b/tests/fortran/scope_arrays_test.py @@ -42,6 +42,7 @@ def test_fortran_frontend_parent(): assert ('scope_test_function', var) in visitor.scope_vars assert visitor.scope_vars[('scope_test_function', var)].name == var + if __name__ == "__main__": test_fortran_frontend_parent() diff --git a/tests/fortran/struct_test.py b/tests/fortran/struct_test.py index 93606f1964..ee47a7f8c9 100644 --- a/tests/fortran/struct_test.py +++ b/tests/fortran/struct_test.py @@ -5,6 +5,7 @@ from dace.frontend.fortran import fortran_parser + def test_fortran_struct(): test_string = """ PROGRAM struct_test_range @@ -42,7 +43,7 @@ def test_fortran_struct(): END SUBROUTINE struct_test_range2_test_function """ - sources={} + sources = {} sdfg = fortran_parser.create_sdfg_from_string(test_string, "res", False, sources=sources) sdfg.save('before.sdfg') sdfg.simplify(verbose=True) @@ -55,6 +56,7 @@ def test_fortran_struct(): sdfg(res=res, start=2, end=5) print(res) + def test_fortran_struct_lhs(): test_string = """ PROGRAM struct_test_range @@ -97,7 +99,7 @@ def test_fortran_struct_lhs(): END SUBROUTINE struct_test_range2_test_function """ - sources={} + sources = {} sdfg = fortran_parser.create_sdfg_from_string(test_string, "res", False, sources=sources) sdfg.save('before.sdfg') sdfg.simplify(verbose=True) @@ -110,6 +112,7 @@ def test_fortran_struct_lhs(): sdfg(res=res, start=2, end=5) print(res) + if __name__ == "__main__": test_fortran_struct() test_fortran_struct_lhs() diff --git a/tests/fortran/sum_to_loop_offset_test.py b/tests/fortran/sum_to_loop_offset_test.py index e933589e0f..a497271ff6 100644 --- a/tests/fortran/sum_to_loop_offset_test.py +++ b/tests/fortran/sum_to_loop_offset_test.py @@ -4,6 +4,7 @@ from dace.frontend.fortran import ast_transforms, fortran_parser + def test_fortran_frontend_sum2loop_1d_without_offset(): """ Tests that the generated array map correctly handles offsets. 
@@ -41,7 +42,8 @@ def test_fortran_frontend_sum2loop_1d_without_offset(): sdfg(d=d, res=res) assert res[0] == (1 + size) * size / 2 assert res[1] == (1 + size) * size / 2 - assert res[2] == (2 + size - 1) * (size - 2)/ 2 + assert res[2] == (2 + size - 1) * (size - 2) / 2 + def test_fortran_frontend_sum2loop_1d_offset(): """ @@ -82,6 +84,7 @@ def test_fortran_frontend_sum2loop_1d_offset(): assert res[1] == (1 + size) * size / 2 assert res[2] == (2 + size - 1) * (size - 2) / 2 + def test_fortran_frontend_arr2loop_2d(): """ Tests that the generated array map correctly handles offsets. @@ -126,6 +129,7 @@ def test_fortran_frontend_arr2loop_2d(): assert res[2] == 21 assert res[3] == 45 + def test_fortran_frontend_arr2loop_2d_offset(): """ Tests that the generated array map correctly handles offsets. @@ -168,6 +172,7 @@ def test_fortran_frontend_arr2loop_2d_offset(): assert res[1] == 190 assert res[2] == 57 + if __name__ == "__main__": test_fortran_frontend_sum2loop_1d_without_offset() diff --git a/tests/fortran/tasklet_test.py b/tests/fortran/tasklet_test.py index 263c49b922..49a2f5ac79 100644 --- a/tests/fortran/tasklet_test.py +++ b/tests/fortran/tasklet_test.py @@ -5,6 +5,7 @@ from dace.frontend.fortran import fortran_parser + def test_fortran_frontend_tasklet(): test_string = """ PROGRAM tasklet @@ -33,9 +34,9 @@ def test_fortran_frontend_tasklet(): sdfg = fortran_parser.create_sdfg_from_string(test_string, "tasklet", normalize_offsets=True) sdfg.view() sdfg.simplify(verbose=True) - + sdfg.compile() - + input = np.full([2], 42, order="F", dtype=np.float32) res = np.full([2], 42, order="F", dtype=np.float32) sdfg(d=input, res=res) diff --git a/tests/fortran/type_array_test.py b/tests/fortran/type_array_test.py index e54846c096..34407b7790 100644 --- a/tests/fortran/type_array_test.py +++ b/tests/fortran/type_array_test.py @@ -20,6 +20,7 @@ from dace.transformation.passes.lift_struct_views import LiftStructViews from dace.transformation import pass_pipeline as ppl + 
def test_fortran_frontend_type_array(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. @@ -63,14 +64,18 @@ def test_fortran_frontend_type_array(): END SUBROUTINE deepest """ - sources={} - sources["type_array_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_array_test",sources=sources, normalize_offsets=True) + sources = {} + sources["type_array_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type_array_test", + sources=sources, + normalize_offsets=True) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) print(a) + def test_fortran_frontend_type2_array(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. @@ -121,14 +126,18 @@ def test_fortran_frontend_type2_array(): END SUBROUTINE deepest """ - sources={} - sources["type2_array_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type2_array_test",sources=sources, normalize_offsets=True) + sources = {} + sources["type2_array_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type2_array_test", + sources=sources, + normalize_offsets=True) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) print(a) + def test_fortran_frontend_type3_array(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
@@ -206,9 +215,12 @@ def test_fortran_frontend_type3_array(): END SUBROUTINE deepest """ - sources={} - sources["type3_array_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type3_array_test",sources=sources, normalize_offsets=True) + sources = {} + sources["type3_array_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type3_array_test", + sources=sources, + normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() #a = np.full([5, 5], 42, order="F", dtype=np.float32) @@ -216,9 +228,7 @@ def test_fortran_frontend_type3_array(): #print(a) - - if __name__ == "__main__": - - #test_fortran_frontend_type_array() - test_fortran_frontend_type3_array() + + #test_fortran_frontend_type_array() + test_fortran_frontend_type3_array() diff --git a/tests/fortran/type_test.py b/tests/fortran/type_test.py index 6cdc0c06b4..05c5bd55ee 100644 --- a/tests/fortran/type_test.py +++ b/tests/fortran/type_test.py @@ -46,9 +46,9 @@ def test_fortran_frontend_basic_type(): d(2,1) = 5.5 + s%w(1,1,1) END SUBROUTINE type_test_function """ - sources={} - sources["type_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_test",sources=sources) + sources = {} + sources["type_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) @@ -57,7 +57,6 @@ def test_fortran_frontend_basic_type(): assert (a[2, 0] == 42) - def test_fortran_frontend_basic_type2(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
@@ -145,7 +144,9 @@ def test_fortran_frontend_type_symbol(): END SUBROUTINE internal_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_symbol_test",sources={"type_symbol_test":test_string}) + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type_symbol_test", + sources={"type_symbol_test": test_string}) sdfg.validate() sdfg.simplify(verbose=True) a = np.full([4, 5], 42, order="F", dtype=np.float64) @@ -154,6 +155,7 @@ def test_fortran_frontend_type_symbol(): assert (a[1, 0] == 11) assert (a[2, 0] == 42) + def test_fortran_frontend_type_pardecl(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. @@ -194,7 +196,9 @@ def test_fortran_frontend_type_pardecl(): END SUBROUTINE internal_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_pardecl_test",sources={"type_pardecl_test":test_string}) + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type_pardecl_test", + sources={"type_pardecl_test": test_string}) sdfg.validate() sdfg.simplify(verbose=True) a = np.full([4, 5], 42, order="F", dtype=np.float32) @@ -203,6 +207,7 @@ def test_fortran_frontend_type_pardecl(): assert (a[1, 0] == 11) assert (a[2, 0] == 42) + def test_fortran_frontend_type_struct(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
@@ -246,7 +251,9 @@ def test_fortran_frontend_type_struct(): END SUBROUTINE internal_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_struct_test",sources={"type_struct_test":test_string}) + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type_struct_test", + sources={"type_struct_test": test_string}) sdfg.validate() sdfg.simplify(verbose=True) a = np.full([4, 5], 42, order="F", dtype=np.float32) @@ -312,7 +319,6 @@ def test_fortran_frontend_circular_type(): assert (a[2, 0] == 42) - def test_fortran_frontend_type_in_call(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. @@ -340,9 +346,9 @@ def test_fortran_frontend_type_in_call(): d(2,1) = max(1.0, tmp(1,1,1)) END SUBROUTINE type_in_call_test_function """ - sources={} - sources["type_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources) + sources = {} + sources["type_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) @@ -350,6 +356,7 @@ def test_fortran_frontend_type_in_call(): assert (a[1, 0] == 11) assert (a[2, 0] == 42) + def test_fortran_frontend_type_array(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
@@ -388,9 +395,12 @@ def test_fortran_frontend_type_array(): s%name%w(8,10)%a = 42 END SUBROUTINE type_in_call_test_function2 """ - sources={} - sources["type_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources, normalize_offsets=True) + sources = {} + sources["type_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type_in_call_test", + sources=sources, + normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.save('test.sdfg') sdfg.compile() @@ -399,6 +409,7 @@ def test_fortran_frontend_type_array(): sdfg(d=a) print(a) + def test_fortran_frontend_type_array2(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. @@ -443,9 +454,12 @@ def test_fortran_frontend_type_array2(): s%name%wx(8,x(3,3,3)) = 43 END SUBROUTINE type_in_call_test_function2 """ - sources={} - sources["type_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_in_call_test",sources=sources, normalize_offsets=True) + sources = {} + sources["type_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type_in_call_test", + sources=sources, + normalize_offsets=True) sdfg.save("before.sdfg") sdfg.simplify(verbose=True) sdfg.save("after.sdfg") @@ -455,6 +469,7 @@ def test_fortran_frontend_type_array2(): sdfg(d=a) print(a) + def test_fortran_frontend_type_pointer(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
@@ -482,9 +497,9 @@ def test_fortran_frontend_type_pointer(): d(2,1) = max(1.0, tmp(1,1,1)) END SUBROUTINE type_pointer_test_function """ - sources={} - sources["type_pointer_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_pointer_test",sources=sources) + sources = {} + sources["type_pointer_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_pointer_test", sources=sources) sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) sdfg(d=a) @@ -536,9 +551,9 @@ def test_fortran_frontend_type_arg(): END SUBROUTINE deepest """ - sources={} - sources["type_arg_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_arg_test",sources=sources, normalize_offsets=True) + sources = {} + sources["type_arg_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_arg_test", sources=sources, normalize_offsets=True) sdfg.view() sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) @@ -546,7 +561,6 @@ def test_fortran_frontend_type_arg(): print(a) - def test_fortran_frontend_type_arg2(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. 
@@ -587,9 +601,12 @@ def test_fortran_frontend_type_arg2(): END SUBROUTINE deepest """ - sources={} - sources["type_arg2_test"]=test_string - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_arg2_test",sources=sources, normalize_offsets=True) + sources = {} + sources["type_arg2_test"] = test_string + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type_arg2_test", + sources=sources, + normalize_offsets=True) sdfg.save("before.sdfg") sdfg.simplify(verbose=True) a = np.full([5, 5], 42, order="F", dtype=np.float32) @@ -597,7 +614,6 @@ def test_fortran_frontend_type_arg2(): print(a) - def test_fortran_frontend_type_view(): """ Tests that the Fortran frontend can parse the simplest type declaration and make use of it in a computation. @@ -632,7 +648,10 @@ def test_fortran_frontend_type_view(): END SUBROUTINE internal_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, "type_view_test",sources={"type_view_test":test_string},normalize_offsets=True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, + "type_view_test", + sources={"type_view_test": test_string}, + normalize_offsets=True) sdfg.validate() sdfg.simplify(verbose=True) a = np.full([4, 5], 42, order="F", dtype=np.float32) diff --git a/tests/fortran/while_test.py b/tests/fortran/while_test.py index 96a43efef7..5f6d73f513 100644 --- a/tests/fortran/while_test.py +++ b/tests/fortran/while_test.py @@ -5,6 +5,7 @@ from dace.frontend.fortran import fortran_parser + def test_fortran_frontend_while(): test_string = """ PROGRAM while @@ -33,7 +34,7 @@ def test_fortran_frontend_while(): sdfg = fortran_parser.create_sdfg_from_string(test_string, "while", normalize_offsets=True) sdfg.simplify(verbose=True) sdfg.compile() - + input = np.full([2], 42, order="F", dtype=np.float32) res = np.full([2], 42, order="F", dtype=np.float32) sdfg(d=input, res=res) From bd3060665b336af14d0c6bf007959a91f2f0a5c3 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 20 Dec 2024 
10:21:52 +0100 Subject: [PATCH 05/12] Copyright, more formatting --- dace/frontend/fortran/ast_components.py | 2 +- dace/frontend/fortran/ast_desugaring.py | 2 ++ dace/frontend/fortran/ast_internal_classes.py | 2 +- dace/frontend/fortran/ast_transforms.py | 2 +- dace/frontend/fortran/ast_utils.py | 24 +++++++++---------- dace/frontend/fortran/fortran_parser.py | 4 ++-- .../fortran/icon_config_propagation.py | 6 +---- dace/frontend/fortran/intrinsics.py | 8 ++++--- tests/fortran/array_test.py | 3 +-- tests/fortran/empty_test.py | 2 +- 10 files changed, 26 insertions(+), 29 deletions(-) diff --git a/dace/frontend/fortran/ast_components.py b/dace/frontend/fortran/ast_components.py index 49a28a3650..0e62ab91b9 100644 --- a/dace/frontend/fortran/ast_components.py +++ b/dace/frontend/fortran/ast_components.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from typing import Any, List, Optional, Type, TypeVar, Union, overload, TYPE_CHECKING, Dict import networkx as nx diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py index 7d8b3d58e1..bfcc842bc5 100644 --- a/dace/frontend/fortran/ast_desugaring.py +++ b/dace/frontend/fortran/ast_desugaring.py @@ -1,3 +1,5 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + import math import operator import re diff --git a/dace/frontend/fortran/ast_internal_classes.py b/dace/frontend/fortran/ast_internal_classes.py index 54475892b6..643aad802b 100644 --- a/dace/frontend/fortran/ast_internal_classes.py +++ b/dace/frontend/fortran/ast_internal_classes.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from typing import List, Optional, Tuple, Union, Dict, Any # The node class is the base class for all nodes in the AST. 
It provides attributes including the line number and fields. diff --git a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index 8a69e9e2a8..de0ef6c259 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from typing import Dict, List, Optional, Tuple, Set, Union diff --git a/dace/frontend/fortran/ast_utils.py b/dace/frontend/fortran/ast_utils.py index 5251fc10bc..378d647895 100644 --- a/dace/frontend/fortran/ast_utils.py +++ b/dace/frontend/fortran/ast_utils.py @@ -1,26 +1,24 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. -from itertools import chain -from typing import List, Set, Iterator, Type, TypeVar, Dict, Tuple, Iterable, Union, Optional - -import networkx as nx -from fparser.two.Fortran2003 import Module_Stmt, Name, Interface_Block, Subroutine_Stmt, Specification_Part, Module, \ - Derived_Type_Def, Function_Stmt, Interface_Stmt, Function_Body, Type_Name, Rename, Entity_Decl, Kind_Selector, \ - Intrinsic_Type_Spec, Use_Stmt, Declaration_Type_Spec -from fparser.two.Fortran2008 import Type_Declaration_Stmt, Procedure_Stmt +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from typing import (Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union) + +from fparser.two.Fortran2003 import (Derived_Type_Def, Function_Body, Function_Stmt, Interface_Block, Interface_Stmt, + Module, Module_Stmt, Name, Rename, Specification_Part, Subroutine_Stmt, Type_Name, + Use_Stmt) +from fparser.two.Fortran2008 import Procedure_Stmt, Type_Declaration_Stmt from fparser.two.utils import Base from numpy import finfo as finf from numpy import float64 as fl +# dace imports from dace import DebugInfo as di from dace import Language as lang from dace import Memlet from dace import data as dat -from dace import dtypes -# dace imports -from dace import subsets +from dace import dtypes, subsets from dace import symbolic as sym from dace.frontend.fortran import ast_internal_classes -from dace.sdfg import SDFG, SDFGState, InterstateEdge +from dace.sdfg import SDFG, InterstateEdge, SDFGState from dace.sdfg.nodes import Tasklet fortrantypes2dacetypes = { diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 36cc2bbf30..56b61c4d52 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -1,4 +1,4 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
import copy import os @@ -31,7 +31,7 @@ from dace import subsets as subs from dace import symbolic as sym from dace.data import Scalar, Structure -from dace.frontend.fortran.ast_desugaring import (ENTRY_POINT_OBJECT, ENTRY_POINT_OBJECT_TYPES, NAMED_STMTS_OF_INTEREST, +from dace.frontend.fortran.ast_desugaring import (ENTRY_POINT_OBJECT, NAMED_STMTS_OF_INTEREST, SPEC, append_children, consolidate_uses, const_eval_nodes, correct_for_function_calls, deconstruct_associations, deconstruct_enums, deconstruct_interface_calls, diff --git a/dace/frontend/fortran/icon_config_propagation.py b/dace/frontend/fortran/icon_config_propagation.py index 5b978b7b6b..9fc876df02 100644 --- a/dace/frontend/fortran/icon_config_propagation.py +++ b/dace/frontend/fortran/icon_config_propagation.py @@ -1,4 +1,4 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import os import sys @@ -15,10 +15,6 @@ from dace.frontend.fortran import fortran_parser -import dace.frontend.fortran.ast_components as ast_components -import dace.frontend.fortran.ast_transforms as ast_transforms -import dace.frontend.fortran.ast_internal_classes as ast_internal - def find_path_recursive(base_dir): dirs = os.listdir(base_dir) diff --git a/dace/frontend/fortran/intrinsics.py b/dace/frontend/fortran/intrinsics.py index 8488a94837..52a97e2b12 100644 --- a/dace/frontend/fortran/intrinsics.py +++ b/dace/frontend/fortran/intrinsics.py @@ -1,3 +1,5 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+ import copy import math import sys @@ -6,13 +8,13 @@ from typing import Any, List, Optional, Tuple, Union from dace.frontend.fortran import ast_internal_classes -from dace.frontend.fortran.ast_transforms import NodeVisitor, NodeTransformer, ParentScopeAssigner, \ - ScopeVarsDeclarations, TypeInference, par_Decl_Range_Finder, mywalk +from dace.frontend.fortran.ast_transforms import (NodeTransformer, NodeVisitor, ParentScopeAssigner, + ScopeVarsDeclarations, mywalk, par_Decl_Range_Finder) from dace.frontend.fortran.ast_utils import fortrantypes2dacetypes from dace.libraries.blas.nodes.dot import dot_libnode from dace.libraries.blas.nodes.gemm import gemm_libnode from dace.libraries.standard.nodes import Transpose -from dace.sdfg import SDFGState, SDFG, nodes +from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg.graph import OrderedDiGraph from dace.transformation import transformation as xf diff --git a/tests/fortran/array_test.py b/tests/fortran/array_test.py index 61090457d0..e83b8367ba 100644 --- a/tests/fortran/array_test.py +++ b/tests/fortran/array_test.py @@ -1,9 +1,8 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import numpy as np from dace import dtypes, symbolic -from dace.frontend.fortran.fortran_parser import create_sdfg_from_string from dace.sdfg import utils as sdutil from dace.sdfg.nodes import AccessNode from tests.fortran.fortran_test_helper import SourceCodeBuilder, create_singular_sdfg_from_string diff --git a/tests/fortran/empty_test.py b/tests/fortran/empty_test.py index 9ec83f4d47..e9e9382ba1 100644 --- a/tests/fortran/empty_test.py +++ b/tests/fortran/empty_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
import numpy as np From 523c1d7ce6a6abd61f14333597a25cbead6fb14c Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 20 Dec 2024 12:17:47 +0100 Subject: [PATCH 06/12] Adapt frontend to control flow regions, first step --- dace/frontend/fortran/ast_utils.py | 43 +--- dace/frontend/fortran/fortran_parser.py | 290 ++++++++++++------------ 2 files changed, 159 insertions(+), 174 deletions(-) diff --git a/dace/frontend/fortran/ast_utils.py b/dace/frontend/fortran/ast_utils.py index 378d647895..5487f2a891 100644 --- a/dace/frontend/fortran/ast_utils.py +++ b/dace/frontend/fortran/ast_utils.py @@ -1,24 +1,25 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -from typing import (Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union) - -from fparser.two.Fortran2003 import (Derived_Type_Def, Function_Body, Function_Stmt, Interface_Block, Interface_Stmt, - Module, Module_Stmt, Name, Rename, Specification_Part, Subroutine_Stmt, Type_Name, - Use_Stmt) +from typing import (Dict, Iterable, Iterator, List, Optional, + Set, Tuple, Type, TypeVar, Union) + +from fparser.two.Fortran2003 import (Derived_Type_Def, Function_Body, + Function_Stmt, Interface_Block, + Interface_Stmt, Module, Module_Stmt, Name, + Rename, Specification_Part, + Subroutine_Stmt, Type_Name, Use_Stmt) from fparser.two.Fortran2008 import Procedure_Stmt, Type_Declaration_Stmt from fparser.two.utils import Base from numpy import finfo as finf from numpy import float64 as fl # dace imports -from dace import DebugInfo as di -from dace import Language as lang from dace import Memlet from dace import data as dat from dace import dtypes, subsets from dace import symbolic as sym from dace.frontend.fortran import ast_internal_classes -from dace.sdfg import SDFG, InterstateEdge, SDFGState +from dace.sdfg import SDFG, SDFGState from dace.sdfg.nodes import Tasklet fortrantypes2dacetypes = { @@ -32,17 +33,6 @@ } -def add_tasklet(substate: SDFGState, name: str, vars_in: 
Set[str], vars_out: Set[str], code: str, debuginfo: list, - source: str): - tasklet = substate.add_tasklet(name="T" + name, - inputs=vars_in, - outputs=vars_out, - code=code, - debuginfo=di(start_line=debuginfo[0], start_column=debuginfo[1], filename=source), - language=lang.Python) - return tasklet - - def add_memlet_read(substate: SDFGState, var_name: str, tasklet: Tasklet, dest_conn: str, memlet_range: str): found = False if isinstance(substate.parent.arrays[var_name], dat.View): @@ -80,21 +70,6 @@ def add_memlet_write(substate: SDFGState, var_name: str, tasklet: Tasklet, sourc return dst -def add_simple_state_to_sdfg(state: SDFGState, top_sdfg: SDFG, state_name: str): - if state.last_sdfg_states.get(top_sdfg) is not None: - substate = top_sdfg.add_state(state_name) - else: - substate = top_sdfg.add_state(state_name, is_start_state=True) - finish_add_state_to_sdfg(state, top_sdfg, substate) - return substate - - -def finish_add_state_to_sdfg(state: SDFGState, top_sdfg: SDFG, substate: SDFGState): - if state.last_sdfg_states.get(top_sdfg) is not None: - top_sdfg.add_edge(state.last_sdfg_states[top_sdfg], substate, InterstateEdge()) - state.last_sdfg_states[top_sdfg] = substate - - def get_name(node: ast_internal_classes.FNode): if isinstance(node, ast_internal_classes.Actual_Arg_Spec_Node): actual_node = node.arg diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 56b61c4d52..4ffb5ac578 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -42,6 +42,7 @@ from dace.frontend.fortran.ast_utils import children_of_type from dace.frontend.fortran.intrinsics import (IntrinsicSDFGTransformation, NeedsTypeInferenceException) from dace.properties import CodeBlock +from dace.sdfg.state import BreakBlock, ContinueBlock, ControlFlowRegion, LoopRegion global_struct_instance_counter = 0 @@ -294,10 +295,6 @@ def __init__(self, self.unallocated_arrays = [] self.all_array_names = [] 
self.last_sdfg_states = {} - self.last_loop_continues = {} - self.last_loop_continues_stack = {} - self.already_has_edge_back_continue = {} - self.last_loop_breaks = {} self.last_returns = {} self.module_vars = [] self.sdfgs_count = 0 @@ -396,28 +393,54 @@ def get_memlet_range(self, sdfg: SDFG, variables: List[ast_internal_classes.FNod if o_v.name == var_name_tasklet: return ast_utils.generate_memlet(o_v, sdfg, self, self.normalize_offsets) - def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG): + + def _add_tasklet(self, substate: SDFGState, name: str, vars_in: Set[str], vars_out: Set[str], code: str, + debuginfo: list, source: str): + tasklet = substate.add_tasklet(name="T" + name, inputs=vars_in, outputs=vars_out, code=code, + debuginfo=dtypes.DebugInfo(start_line=debuginfo[0], start_column=debuginfo[1], + filename=source), language=dtypes.Language.Python) + return tasklet + + + def _add_simple_state_to_cfg(self, cfg: ControlFlowRegion, state_name: str): + if cfg in self.last_sdfg_states and self.last_sdfg_states[cfg] is not None: + substate = cfg.add_state(state_name) + else: + substate = cfg.add_state(state_name, is_start_state=True) + self._finish_add_state_to_cfg(cfg, substate) + return substate + + + def _finish_add_state_to_cfg(self, cfg: ControlFlowRegion, substate: SDFGState): + if cfg in self.last_sdfg_states and self.last_sdfg_states[cfg] is not None: + cfg.add_edge(self.last_sdfg_states[cfg], substate, InterstateEdge()) + self.last_sdfg_states[cfg] = substate + + + def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating the AST into a SDFG. 
:param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region into which the node should be translated :note: This function is recursive and will call itself for all child nodes :note: This function will call the appropriate function for the node type :note: The dictionary ast_elements, part of the class itself contains all functions that are called for the different node types """ if node.__class__ in self.ast_elements: - self.ast_elements[node.__class__](node, sdfg) + self.ast_elements[node.__class__](node, sdfg, cfg) elif isinstance(node, list): for i in node: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) else: warnings.warn(f"WARNING: {node.__class__.__name__}") - def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): + def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating the Fortran AST into a SDFG. 
:param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region into which the node should be translated :note: This function is recursive and will call itself for all child nodes :note: This function will call the appropriate function for the node type :note: The dictionary ast_elements, part of the class itself contains all functions that are called for the different node types @@ -444,7 +467,7 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): for jj in parse_order: for j in i.specification_part.typedecls: if j.name.name == jj: - self.translate(j, sdfg) + self.translate(j, sdfg, cfg) if j.__class__.__name__ != "Derived_Type_Def_Node": for k in j.vardecl: self.module_vars.append((k.name, i.name)) @@ -455,7 +478,7 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): self.transient_mode = self.do_not_make_internal_variables_argument for j in i.specification_part.symbols: - self.translate(j, sdfg) + self.translate(j, sdfg, cfg) if isinstance(j, ast_internal_classes.Symbol_Array_Decl_Node): self.module_vars.append((j.name, i.name)) elif isinstance(j, ast_internal_classes.Symbol_Decl_Node): @@ -463,13 +486,13 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): else: raise ValueError("Unknown symbol type") for j in i.specification_part.specifications: - self.translate(j, sdfg) + self.translate(j, sdfg, cfg) for k in j.vardecl: self.module_vars.append((k.name, i.name)) # this works with CloudSC # unsure about ICON self.transient_mode = True - ast_utils.add_simple_state_to_sdfg(self, sdfg, "GlobalDefEnd") + self._add_simple_state_to_cfg(cfg, "GlobalDefEnd") if self.startpoint is None: self.startpoint = node.main_program assert self.startpoint is not None, "No main program or start point found" @@ -480,15 +503,15 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): self.transient_mode = 
self.do_not_make_internal_variables_argument for i in self.startpoint.specification_part.typedecls: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) for i in self.startpoint.specification_part.symbols: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) for i in self.startpoint.specification_part.specifications: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) for i in self.startpoint.specification_part.specifications: - ast_utils.add_simple_state_to_sdfg(self, sdfg, "start_struct_size") - assign_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, "assign_struct_sizes") + self._add_simple_state_to_cfg(cfg, "start_struct_size") + assign_state = self._add_simple_state_to_cfg(cfg, "assign_struct_sizes") for decl in i.vardecl: if decl.name in sdfg.symbols: continue @@ -508,13 +531,15 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): arr.transient = False self.transient_mode = True - self.translate(self.startpoint.execution_part.execution, sdfg) + self.translate(self.startpoint.execution_part.execution, sdfg, cfg) - def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_Stmt_Node, sdfg: SDFG): + def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_Stmt_Node, sdfg: SDFG, + cfg: ControlFlowRegion): """ This function is responsible for translating Fortran pointer assignments into a SDFG. 
:param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region into which the node should be translated """ if self.name_mapping[sdfg][node.name_pointer.name] in sdfg.arrays: shapenames = [ @@ -591,11 +616,13 @@ def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_S self.unallocated_arrays.remove(i) self.name_mapping[sdfg][node.name_pointer.name] = self.name_mapping[sdfg][node.name_target.name] - def derivedtypedef2sdfg(self, node: ast_internal_classes.Derived_Type_Def_Node, sdfg: SDFG): + def derivedtypedef2sdfg(self, node: ast_internal_classes.Derived_Type_Def_Node, sdfg: SDFG, + cfg: ControlFlowRegion): """ This function is responsible for registering Fortran derived type declarations into a SDFG as nested data types. :param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region into which the node should be translated """ name = node.name.name if node.component_part is None: @@ -644,21 +671,23 @@ def derivedtypedef2sdfg(self, node: ast_internal_classes.Derived_Type_Def_Node, structure_obj = Structure(dict_setup, name) self.registered_types[name] = structure_obj - def basicblock2sdfg(self, node: ast_internal_classes.Execution_Part_Node, sdfg: SDFG): + def basicblock2sdfg(self, node: ast_internal_classes.Execution_Part_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran basic blocks into a SDFG. 
:param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region into which the node should be translated """ for i in node.execution: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) - def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDFG): + def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran allocate statements into a SDFG. :param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region into which the node should be translated :note: We pair the allocate with a list of unallocated arrays. """ for i in node.allocation_list: @@ -695,20 +724,21 @@ def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDF strides=strides, transient=transient) - def write2sdfg(self, node: ast_internal_classes.Write_Stmt_Node, sdfg: SDFG): + def write2sdfg(self, node: ast_internal_classes.Write_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): # TODO implement print("Uh oh") # raise NotImplementedError("Fortran write statements are not implemented yet") - def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG): + def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran if statements into a SDFG. 
:param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region into which the node should be translated """ name = f"If_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}" - begin_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"Begin{name}") + begin_state = self._add_simple_state_to_cfg(sdfg, f"Begin{name}") guard_substate = sdfg.add_state(f"Guard{name}") sdfg.add_edge(begin_state, guard_substate, InterstateEdge()) @@ -728,7 +758,7 @@ def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG): self.last_returns.get(sdfg), self.already_has_edge_back_continue.get(sdfg) ]: - body_ifend_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"BodyIfEnd{name}") + body_ifend_state = self._add_simple_state_to_cfg(sdfg, f"BodyIfEnd{name}") sdfg.add_edge(body_ifend_state, final_substate, InterstateEdge()) if len(node.body_else.execution) > 0: @@ -736,23 +766,22 @@ def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG): body_elsestart_state = sdfg.add_state("BodyElseStart" + name_else) self.last_sdfg_states[sdfg] = body_elsestart_state self.translate(node.body_else, sdfg) - body_elseend_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"BodyElseEnd{name_else}") + body_elseend_state = self._add_simple_state_to_cfg(sdfg, f"BodyElseEnd{name_else}") sdfg.add_edge(guard_substate, body_elsestart_state, InterstateEdge("not (" + condition + ")")) sdfg.add_edge(body_elseend_state, final_substate, InterstateEdge()) else: sdfg.add_edge(guard_substate, final_substate, InterstateEdge("not (" + condition + ")")) self.last_sdfg_states[sdfg] = final_substate - def whilestmt2sdfg(self, node: ast_internal_classes.While_Stmt_Node, sdfg: SDFG): + def whilestmt2sdfg(self, node: ast_internal_classes.While_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): + """ + This function is responsible for translating Fortran while statements into a SDFG. 
+ :param node: The while statement node to be translated + :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region to which the node should be translated + """ - # raise NotImplementedError("Fortran while statements are not implemented yet") name = "While_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) - begin_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, "Begin" + name) - guard_substate = sdfg.add_state("Guard" + name) - final_substate = sdfg.add_state("Merge" + name) - self.last_sdfg_states[sdfg] = final_substate - - sdfg.add_edge(begin_state, guard_substate, InterstateEdge()) condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping, @@ -760,41 +789,28 @@ def whilestmt2sdfg(self, node: ast_internal_classes.While_Stmt_Node, sdfg: SDFG) placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(node.cond) - begin_loop_state = sdfg.add_state("BeginWhile" + name) - end_loop_state = sdfg.add_state("EndWhile" + name) - self.last_sdfg_states[sdfg] = begin_loop_state - self.last_loop_continues[sdfg] = end_loop_state - if self.last_loop_continues_stack.get(sdfg) is None: - self.last_loop_continues_stack[sdfg] = [] - self.last_loop_continues_stack[sdfg].append(end_loop_state) - self.translate(node.body, sdfg) + loop_region = LoopRegion(name, condition, inverted=False, sdfg=sdfg) - sdfg.add_edge(self.last_sdfg_states[sdfg], end_loop_state, InterstateEdge()) - sdfg.add_edge(guard_substate, begin_loop_state, InterstateEdge(condition)) - sdfg.add_edge(end_loop_state, guard_substate, InterstateEdge()) - sdfg.add_edge(guard_substate, final_substate, InterstateEdge(f"not ({condition})")) - self.last_sdfg_states[sdfg] = final_substate + is_start = cfg not in self.last_sdfg_states or self.last_sdfg_states[cfg] is None + cfg.add_node(loop_region, is_start_block=is_start) + if not is_start: + cfg.add_edge(self.last_sdfg_states[cfg], loop_region, InterstateEdge()) + 
self.last_sdfg_states[cfg] = loop_region + self.last_sdfg_states[loop_region] = loop_region.add_state('BeginLoop_' + name, is_start_block=True) - if len(self.last_loop_continues_stack[sdfg]) > 0: - self.last_loop_continues[sdfg] = self.last_loop_continues_stack[sdfg][-1] - else: - self.last_loop_continues[sdfg] = None + self.translate(node.body, sdfg, loop_region) - def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG): + def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran for statements into a SDFG. - :param node: The node to be translated + :param node: The for statement node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region to which the node should be translated """ - declloop = False - name = "FOR_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) - begin_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, "Begin" + name) - guard_substate = sdfg.add_state("Guard" + name) - final_substate = sdfg.add_state("Merge" + name) - self.last_sdfg_states[sdfg] = final_substate + name = 'FOR_l_' + str(node.line_number[0]) + '_c_' + str(node.line_number[1]) decl_node = node.init - entry = {} + init_expr = None if isinstance(decl_node, ast_internal_classes.BinOp_Node): if sdfg.symbols.get(decl_node.lval.name) is not None: iter_name = decl_node.lval.name @@ -802,13 +818,12 @@ def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG): iter_name = self.name_mapping[sdfg][decl_node.lval.name] else: raise ValueError("Unknown variable " + decl_node.lval.name) - entry[iter_name] = ast_utils.ProcessedWriter(sdfg, - self.name_mapping, - placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, - rename_dict=self.replace_names).write_code(decl_node.rval) - - sdfg.add_edge(begin_state, guard_substate, 
InterstateEdge(assignments=entry)) + init_assignment = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(decl_node.rval) + init_expr = f'{iter_name} = {init_assignment}' condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping, @@ -816,40 +831,32 @@ def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG): placeholders_offsets=self.placeholders_offsets, rename_dict=self.replace_names).write_code(node.cond) - increment = "i+0+1" + increment_expr = 'i+0+1' if isinstance(node.iter, ast_internal_classes.BinOp_Node): - increment = ast_utils.ProcessedWriter(sdfg, - self.name_mapping, - placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, - rename_dict=self.replace_names).write_code(node.iter.rval) - entry = {iter_name: increment} - - begin_loop_state = sdfg.add_state("BeginLoop" + name) - end_loop_state = sdfg.add_state("EndLoop" + name) - self.last_sdfg_states[sdfg] = begin_loop_state - self.last_loop_continues[sdfg] = end_loop_state - if self.last_loop_continues_stack.get(sdfg) is None: - self.last_loop_continues_stack[sdfg] = [] - self.last_loop_continues_stack[sdfg].append(end_loop_state) - self.translate(node.body, sdfg) + increment_expr = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + rename_dict=self.replace_names).write_code(node.iter.rval) + + loop_region = LoopRegion(name, condition, iter_name, init_expr, increment_expr, inverted=False, sdfg=sdfg) + + is_start = cfg not in self.last_sdfg_states or self.last_sdfg_states[cfg] is None + cfg.add_node(loop_region, is_start_block=is_start) + if not is_start: + cfg.add_edge(self.last_sdfg_states[cfg], loop_region, InterstateEdge()) + self.last_sdfg_states[cfg] = loop_region + self.last_sdfg_states[loop_region] = 
loop_region.add_state('BeginLoop_' + name, is_start_block=True) + + self.translate(node.body, sdfg, loop_region) - sdfg.add_edge(self.last_sdfg_states[sdfg], end_loop_state, InterstateEdge()) - sdfg.add_edge(guard_substate, begin_loop_state, InterstateEdge(condition)) - sdfg.add_edge(end_loop_state, guard_substate, InterstateEdge(assignments=entry)) - sdfg.add_edge(guard_substate, final_substate, InterstateEdge(f"not ({condition})")) - self.last_sdfg_states[sdfg] = final_substate - self.last_loop_continues_stack[sdfg].pop() - if len(self.last_loop_continues_stack[sdfg]) > 0: - self.last_loop_continues[sdfg] = self.last_loop_continues_stack[sdfg][-1] - else: - self.last_loop_continues[sdfg] = None - def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG): + def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran symbol declarations into a SDFG. :param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region to which the node should be translated """ if node.name == "modname": return @@ -888,11 +895,11 @@ def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG): datatype = self.get_dace_type(node.type) if node.name not in sdfg.symbols: sdfg.add_symbol(node.name, datatype) - if self.last_sdfg_states.get(sdfg) is None: - bstate = sdfg.add_state("SDFGbegin", is_start_state=True) - self.last_sdfg_states[sdfg] = bstate + if cfg not in self.last_cfg_states or self.last_sdfg_states[cfg] is None: + bstate = cfg.add_state("SDFGbegin", is_start_state=True) + self.last_sdfg_states[cfg] = bstate if node.init is not None: - substate = sdfg.add_state(f"Dummystate_{node.name}") + substate = cfg.add_state(f"Dummystate_{node.name}") increment = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping, @@ -901,20 +908,22 @@ def symbol2sdfg(self, node: 
ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG): rename_dict=self.replace_names).write_code(node.init) entry = {node.name: increment} - sdfg.add_edge(self.last_sdfg_states[sdfg], substate, InterstateEdge(assignments=entry)) - self.last_sdfg_states[sdfg] = substate + cfg.add_edge(self.last_sdfg_states[cfg], substate, InterstateEdge(assignments=entry)) + self.last_sdfg_states[cfg] = substate - def symbolarray2sdfg(self, node: ast_internal_classes.Symbol_Array_Decl_Node, sdfg: SDFG): + def symbolarray2sdfg(self, node: ast_internal_classes.Symbol_Array_Decl_Node, sdfg: SDFG, cfg: ControlFlowRegion): return NotImplementedError( "Symbol_Decl_Node not implemented. This should be done via a transformation that itemizes the constant array." ) - def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, sdfg: SDFG): + def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, sdfg: SDFG, + cfg: ControlFlowRegion): """ This function is responsible for translating Fortran subroutine declarations into a SDFG. 
:param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region to which the node should be translated """ if node.execution_part is None: @@ -941,7 +950,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, self.sdfgs_count += 1 self.actual_offsets_per_sdfg[new_sdfg] = {} self.names_of_object_in_parent_sdfg[new_sdfg] = {} - substate = ast_utils.add_simple_state_to_sdfg(self, sdfg, "state" + my_name_sdfg) + substate = self._add_simple_state_to_cfg(cfg, "state" + my_name_sdfg) variables_in_call = [] if self.last_call_expression.get(sdfg) is not None: @@ -2129,13 +2138,13 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, internal_sdfg.path = self.sdfg_path + new_sdfg.name + ".sdfg" # new_sdfg.save(path.join(self.sdfg_path, new_sdfg.name + ".sdfg")) - def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): + def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ - This parses binary operations to tasklets in a new state or creates - a function call with a nested SDFG if the operation is a function - call rather than a simple assignment. + This parses binary operations to tasklets in a new state or creates a function call with a nested SDFG if the + operation is a function call rather than a simple assignment. 
:param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region to which the node should be translated """ calls = ast_transforms.FindFunctionCalls() @@ -2148,7 +2157,7 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): ]: augmented_call.args.append(node.lval) augmented_call.hasret = True - self.call2sdfg(augmented_call, sdfg) + self.call2sdfg(augmented_call, sdfg, cfg) return outputnodefinder = ast_transforms.FindOutputs(thourough=False) @@ -2181,14 +2190,13 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): input_names.append(mapped_name) input_names_tasklet.append(i.name + "_" + str(count) + "_in") - substate = ast_utils.add_simple_state_to_sdfg( - self, sdfg, "_state_l" + str(node.line_number[0]) + "_c" + str(node.line_number[1])) + substate = self._add_simple_state_to_cfg( + cfg, "_state_l" + str(node.line_number[0]) + "_c" + str(node.line_number[1])) output_names_changed = [o_t + "_out" for o_t in output_names] - tasklet = ast_utils.add_tasklet(substate, "_l" + str(node.line_number[0]) + "_c" + str(node.line_number[1]), - input_names_tasklet, output_names_changed, "text", node.line_number, - self.file_name) + tasklet = self._add_tasklet(substate, "_l" + str(node.line_number[0]) + "_c" + str(node.line_number[1]), + input_names_tasklet, output_names_changed, "text", node.line_number, self.file_name) for i, j in zip(input_names, input_names_tasklet): memlet_range = self.get_memlet_range(sdfg, input_vars, i, j) @@ -2224,12 +2232,13 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): # print(sdfg.name,node.line_number,output_names,output_names_changed,input_names,input_names_tasklet) tasklet.code = CodeBlock(text, lang.Python) - def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG): + def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This parses 
function calls to a nested SDFG or creates a tasklet with an external library call. :param node: The node to be translated :param sdfg: The SDFG to which the node should be translated + :param cfg: The control flow region to which the node should be translated """ self.last_call_expression[sdfg] = node.args @@ -2241,20 +2250,20 @@ def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG): for i in self.top_level.function_definitions: if i.name.name == node.name.name: - self.function2sdfg(i, sdfg) + self.function2sdfg(i, sdfg, cfg) return for i in self.top_level.subroutine_definitions: if i.name.name == node.name.name: - self.subroutine2sdfg(i, sdfg) + self.subroutine2sdfg(i, sdfg, cfg) return for j in self.top_level.modules: for i in j.function_definitions: if i.name.name == node.name.name: - self.function2sdfg(i, sdfg) + self.function2sdfg(i, sdfg, cfg) return for i in j.subroutine_definitions: if i.name.name == node.name.name: - self.subroutine2sdfg(i, sdfg) + self.subroutine2sdfg(i, sdfg, cfg) return else: # This part handles the case that it's an external library call @@ -2321,9 +2330,9 @@ def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG): else: text = tw.write_code(node) - substate = ast_utils.add_simple_state_to_sdfg(self, sdfg, "_state" + str(node.line_number[0])) + substate = self._add_simple_state_to_cfg(cfg, "_state" + str(node.line_number[0])) - tasklet = ast_utils.add_tasklet(substate, str(node.line_number[0]), { + tasklet = self._add_tasklet(substate, str(node.line_number[0]), { **input_names_tasklet, **special_list_in }, output_names_changed + special_list_out, "text", node.line_number, self.file_name) @@ -2356,22 +2365,23 @@ def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG): setattr(tasklet, "code", CodeBlock(text, lang.Python)) - def declstmt2sdfg(self, node: ast_internal_classes.Decl_Stmt_Node, sdfg: SDFG): + def declstmt2sdfg(self, node: ast_internal_classes.Decl_Stmt_Node, sdfg: 
SDFG, cfg: ControlFlowRegion): """ This function translates a variable declaration statement to an access node on the sdfg :param node: The node to translate :param sdfg: The sdfg to attach the access node to + :param cfg: The control flow region to which the node should be translated :note This function is the top level of the declaration, most implementation is in vardecl2sdfg """ for i in node.vardecl: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) - def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): + def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function translates a variable declaration to an access node on the sdfg :param node: The node to translate :param sdfg: The sdfg to attach the access node to - + :param cfg: The control flow region to which the node should be translated """ if node.name == "modname": return @@ -2605,17 +2615,17 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name=node.name, type=node.type), op="=", rval=node.init, - line_number=node.line_number), sdfg) - - def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG): + line_number=node.line_number), sdfg, cfg) - self.last_loop_breaks[sdfg] = self.last_sdfg_states[sdfg] - sdfg.add_edge(self.last_sdfg_states[sdfg], self.last_loop_continues.get(sdfg), InterstateEdge()) + def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG, cfg: ControlFlowRegion): + break_block = BreakBlock(f'Break_l_{node.line_number}') + cfg.add_node(break_block) + cfg.add_edge(self.last_sdfg_states[cfg], break_block, InterstateEdge()) - def continue2sdfg(self, node: ast_internal_classes.Continue_Node, sdfg: SDFG): - # - sdfg.add_edge(self.last_sdfg_states[sdfg], self.last_loop_continues.get(sdfg), InterstateEdge()) - self.already_has_edge_back_continue[sdfg] = self.last_sdfg_states[sdfg] + 
def continue2sdfg(self, node: ast_internal_classes.Continue_Node, sdfg: SDFG, cfg: ControlFlowRegion): + continue_block = ContinueBlock(f'Continue_l_{node.line_number}') + cfg.add_node(continue_block) + cfg.add_edge(self.last_sdfg_states[cfg], continue_block, InterstateEdge()) def create_ast_from_string(source_string: str, @@ -2891,7 +2901,7 @@ def create_sdfg_from_internal_ast(own_ast: ast_components.InternalFortranAst, pr ast2sdfg.actual_offsets_per_sdfg[g] = {} ast2sdfg.top_level = program ast2sdfg.globalsdfg = g - ast2sdfg.translate(program, g) + ast2sdfg.translate(program, g, g) g.apply_transformations(IntrinsicSDFGTransformation) g.expand_library_nodes() gmap[ep] = g @@ -3016,7 +3026,7 @@ def create_sdfg_from_string( ast2sdfg.actual_offsets_per_sdfg[sdfg] = {} ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg - ast2sdfg.translate(program, sdfg) + ast2sdfg.translate(program, sdfg, sdfg) for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, nodes.NestedSDFG): @@ -3064,7 +3074,7 @@ def create_sdfg_from_fortran_file(source_string: str): sdfg = SDFG(source_string) ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg - ast2sdfg.translate(program, sdfg) + ast2sdfg.translate(program, sdfg, sdfg) sdfg.apply_transformations(IntrinsicSDFGTransformation) sdfg.expand_library_nodes() @@ -3621,7 +3631,7 @@ def create_sdfg_from_fortran_file_with_options(cfg: ParseConfig, ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg - ast2sdfg.translate(program, sdfg) + ast2sdfg.translate(program, sdfg, sdfg) print(f'Saving SDFG {os.path.join(sdfgs_dir, sdfg.name + "_raw_before_intrinsics_full.sdfgz")}') sdfg.save(os.path.join(sdfgs_dir, sdfg.name + "_raw_before_intrinsics_full.sdfgz"), compress=True) @@ -3686,7 +3696,7 @@ def create_sdfg_from_fortran_file_with_options(cfg: ParseConfig, ast2sdfg.actual_offsets_per_sdfg[sdfg] = {} ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg - ast2sdfg.translate(program, sdfg) + ast2sdfg.translate(program, sdfg, 
sdfg) sdfg.save(os.path.join(sdfgs_dir, sdfg.name + "_raw_before_intrinsics_full.sdfgz"), compress=True) From 36c629198b79e54c62e01731c61daaa7439428b6 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 23 Dec 2024 09:10:22 +0100 Subject: [PATCH 07/12] Fixes --- dace/frontend/fortran/fortran_parser.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 4ffb5ac578..57bc475c20 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -752,12 +752,7 @@ def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG, cfg: sdfg.add_edge(guard_substate, body_ifstart_state, InterstateEdge(condition)) - if self.last_sdfg_states[sdfg] not in [ - self.last_loop_breaks.get(sdfg), - self.last_loop_continues.get(sdfg), - self.last_returns.get(sdfg), - self.already_has_edge_back_continue.get(sdfg) - ]: + if self.last_sdfg_states[sdfg] not in self.last_returns.get(sdfg): body_ifend_state = self._add_simple_state_to_cfg(sdfg, f"BodyIfEnd{name}") sdfg.add_edge(body_ifend_state, final_substate, InterstateEdge()) From fdc6312fe56f2265ddfe4534c893916d02f6a6ab Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 23 Dec 2024 11:08:15 +0100 Subject: [PATCH 08/12] Add branch generation --- dace/frontend/fortran/ast_utils.py | 7 ++-- dace/frontend/fortran/fortran_parser.py | 47 ++++++++++--------------- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/dace/frontend/fortran/ast_utils.py b/dace/frontend/fortran/ast_utils.py index 5487f2a891..853fabb36e 100644 --- a/dace/frontend/fortran/ast_utils.py +++ b/dace/frontend/fortran/ast_utils.py @@ -1,6 +1,6 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
-from typing import (Dict, Iterable, Iterator, List, Optional, +from typing import (Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union) from fparser.two.Fortran2003 import (Derived_Type_Def, Function_Body, @@ -106,6 +106,8 @@ class TaskletWriter: :return: python code for a tasklet, as a string """ + ast_elements: Dict[ast_internal_classes.FNode, Callable[..., str]] + def __init__(self, outputs: List[str], outputs_changes: List[str], @@ -163,7 +165,7 @@ def arrayconstructor2string(self, node: ast_internal_classes.Array_Constructor_N str_to_return += " ]" return str_to_return - def write_code(self, node: ast_internal_classes.FNode): + def write_code(self, node: ast_internal_classes.FNode) -> str: """ :param node: node to write code for :return: python code for the node, as a string @@ -172,7 +174,6 @@ def write_code(self, node: ast_internal_classes.FNode): :note If the node is not a string, it is checked if it is in the ast_elements dictionary :note If it is, the appropriate function is called with the node as an argument, leading to a recursive traversal of the tree spanned by the node :note If it not, an error is raised - """ self.depth += 1 if node.__class__ in self.ast_elements: diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 57bc475c20..bccaf960f5 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -42,7 +42,7 @@ from dace.frontend.fortran.ast_utils import children_of_type from dace.frontend.fortran.intrinsics import (IntrinsicSDFGTransformation, NeedsTypeInferenceException) from dace.properties import CodeBlock -from dace.sdfg.state import BreakBlock, ContinueBlock, ControlFlowRegion, LoopRegion +from dace.sdfg.state import BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowRegion, LoopRegion global_struct_instance_counter = 0 @@ -295,7 +295,6 @@ def __init__(self, self.unallocated_arrays = [] self.all_array_names = [] 
self.last_sdfg_states = {} - self.last_returns = {} self.module_vars = [] self.sdfgs_count = 0 self.libraries = {} @@ -729,6 +728,7 @@ def write2sdfg(self, node: ast_internal_classes.Write_Stmt_Node, sdfg: SDFG, cfg print("Uh oh") # raise NotImplementedError("Fortran write statements are not implemented yet") + def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran if statements into a SDFG. @@ -736,37 +736,27 @@ def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG, cfg: :param sdfg: The SDFG to which the node should be translated :param cfg: The control flow region into which the node should be translated """ + name = f"Conditional_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}" - name = f"If_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}" - begin_state = self._add_simple_state_to_cfg(sdfg, f"Begin{name}") - guard_substate = sdfg.add_state(f"Guard{name}") - sdfg.add_edge(begin_state, guard_substate, InterstateEdge()) + cond_block = ConditionalBlock(name) + is_start = cfg not in self.last_sdfg_states or self.last_sdfg_states[cfg] is None + cfg.add_node(cond_block, is_start_block=is_start) + if not is_start: + cfg.add_edge(self.last_sdfg_states[cfg], cond_block, InterstateEdge()) + self.last_sdfg_states[cfg] = cond_block condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping, self.placeholders, self.placeholders_offsets, self.replace_names).write_code(node.cond) - body_ifstart_state = sdfg.add_state(f"BodyIfStart{name}") - self.last_sdfg_states[sdfg] = body_ifstart_state - self.translate(node.body, sdfg) - final_substate = sdfg.add_state(f"MergeState{name}") - - sdfg.add_edge(guard_substate, body_ifstart_state, InterstateEdge(condition)) - - if self.last_sdfg_states[sdfg] not in self.last_returns.get(sdfg): - body_ifend_state = self._add_simple_state_to_cfg(sdfg, f"BodyIfEnd{name}") - sdfg.add_edge(body_ifend_state, 
final_substate, InterstateEdge()) + if_body = ControlFlowRegion(name + '_if_body') + cond_block.add_branch(CodeBlock(condition), if_body) + self.translate(node.body, sdfg, if_body) if len(node.body_else.execution) > 0: - name_else = f"Else_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}" - body_elsestart_state = sdfg.add_state("BodyElseStart" + name_else) - self.last_sdfg_states[sdfg] = body_elsestart_state - self.translate(node.body_else, sdfg) - body_elseend_state = self._add_simple_state_to_cfg(sdfg, f"BodyElseEnd{name_else}") - sdfg.add_edge(guard_substate, body_elsestart_state, InterstateEdge("not (" + condition + ")")) - sdfg.add_edge(body_elseend_state, final_substate, InterstateEdge()) - else: - sdfg.add_edge(guard_substate, final_substate, InterstateEdge("not (" + condition + ")")) - self.last_sdfg_states[sdfg] = final_substate + else_body = ControlFlowRegion(name + '_else_body') + cond_block.add_branch(None, else_body) + self.translate(node.body_else, sdfg, else_body) + def whilestmt2sdfg(self, node: ast_internal_classes.While_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ @@ -775,7 +765,6 @@ def whilestmt2sdfg(self, node: ast_internal_classes.While_Stmt_Node, sdfg: SDFG, :param sdfg: The SDFG to which the node should be translated :param cfg: The control flow region to which the node should be translated """ - name = "While_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) condition = ast_utils.ProcessedWriter(sdfg, @@ -795,6 +784,7 @@ def whilestmt2sdfg(self, node: ast_internal_classes.While_Stmt_Node, sdfg: SDFG, self.translate(node.body, sdfg, loop_region) + def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran for statements into a SDFG. 
@@ -802,7 +792,6 @@ def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG, cfg :param sdfg: The SDFG to which the node should be translated :param cfg: The control flow region to which the node should be translated """ - name = 'FOR_l_' + str(node.line_number[0]) + '_c_' + str(node.line_number[1]) decl_node = node.init init_expr = None @@ -890,7 +879,7 @@ def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG, c datatype = self.get_dace_type(node.type) if node.name not in sdfg.symbols: sdfg.add_symbol(node.name, datatype) - if cfg not in self.last_cfg_states or self.last_sdfg_states[cfg] is None: + if cfg not in self.last_sdfg_states or self.last_sdfg_states[cfg] is None: bstate = cfg.add_state("SDFGbegin", is_start_state=True) self.last_sdfg_states[cfg] = bstate if node.init is not None: From e8b94b976ea22cc38ddb6b306707b802c589a159 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 23 Dec 2024 12:46:48 +0100 Subject: [PATCH 09/12] WIP --- dace/frontend/fortran/fortran_parser.py | 20 +++++++++++--------- dace/frontend/fortran/intrinsics.py | 2 ++ tests/fortran/fortran_test_helper.py | 3 ++- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index bccaf960f5..68a174c130 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -817,11 +817,12 @@ def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG, cfg increment_expr = 'i+0+1' if isinstance(node.iter, ast_internal_classes.BinOp_Node): - increment_expr = ast_utils.ProcessedWriter(sdfg, - self.name_mapping, - placeholders=self.placeholders, - placeholders_offsets=self.placeholders_offsets, - rename_dict=self.replace_names).write_code(node.iter.rval) + increment_rhs = ast_utils.ProcessedWriter(sdfg, + self.name_mapping, + placeholders=self.placeholders, + placeholders_offsets=self.placeholders_offsets, + 
rename_dict=self.replace_names).write_code(node.iter.rval) + increment_expr = f'{iter_name} = {increment_rhs}' loop_region = LoopRegion(name, condition, iter_name, init_expr, increment_expr, inverted=False, sdfg=sdfg) @@ -2106,17 +2107,17 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, self.transient_mode = True for j in node.specification_part.symbols: if isinstance(j, ast_internal_classes.Symbol_Decl_Node): - self.symbol2sdfg(j, new_sdfg) + self.symbol2sdfg(j, new_sdfg, new_sdfg) else: raise NotImplementedError("Symbol not implemented") for j in node.specification_part.specifications: - self.declstmt2sdfg(j, new_sdfg) + self.declstmt2sdfg(j, new_sdfg, new_sdfg) self.transient_mode = old_mode for i in assigns: - self.translate(i, new_sdfg) - self.translate(node.execution_part, new_sdfg) + self.translate(i, new_sdfg, new_sdfg) + self.translate(node.execution_part, new_sdfg, new_sdfg) if self.multiple_sdfgs == True: internal_sdfg.path = self.sdfg_path + new_sdfg.name + ".sdfg" @@ -2886,6 +2887,7 @@ def create_sdfg_from_internal_ast(own_ast: ast_components.InternalFortranAst, pr ast2sdfg.top_level = program ast2sdfg.globalsdfg = g ast2sdfg.translate(program, g, g) + g.reset_cfg_list() g.apply_transformations(IntrinsicSDFGTransformation) g.expand_library_nodes() gmap[ep] = g diff --git a/dace/frontend/fortran/intrinsics.py b/dace/frontend/fortran/intrinsics.py index 52a97e2b12..1cdbd84e69 100644 --- a/dace/frontend/fortran/intrinsics.py +++ b/dace/frontend/fortran/intrinsics.py @@ -1329,6 +1329,7 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_ class IntrinsicSDFGTransformation(xf.SingleStateTransformation): + array1 = xf.PatternNode(nodes.AccessNode) array2 = xf.PatternNode(nodes.AccessNode) tasklet = xf.PatternNode(nodes.Tasklet) @@ -1409,6 +1410,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): class MathFunctions(IntrinsicTransformation): + MathTransformation = 
namedtuple("MathTransformation", "function return_type") MathReplacement = namedtuple("MathReplacement", "function replacement_function return_type") diff --git a/tests/fortran/fortran_test_helper.py b/tests/fortran/fortran_test_helper.py index 14ef321f46..6fa3048a34 100644 --- a/tests/fortran/fortran_test_helper.py +++ b/tests/fortran/fortran_test_helper.py @@ -10,6 +10,7 @@ from dace.frontend.fortran.ast_internal_classes import Name_Node from dace.frontend.fortran.fortran_parser import ParseConfig, create_internal_ast, SDFGConfig, \ create_sdfg_from_internal_ast +from dace.sdfg.sdfg import SDFG @dataclass @@ -276,7 +277,7 @@ def NAMED(cls, name: str): return cls(Name_Node, {'name': cls(has_value=name)}) -def create_singular_sdfg_from_string(sources: Dict[str, str], entry_point: str, normalize_offsets: bool = True): +def create_singular_sdfg_from_string(sources: Dict[str, str], entry_point: str, normalize_offsets: bool = True) -> SDFG: entry_point = entry_point.split('.') cfg = ParseConfig(main=sources['main.f90'], sources=sources, entry_points=tuple(entry_point)) From e657f2c4d7ce1fea9c69fe7ddd1631eb95857df8 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 23 Dec 2024 12:58:16 +0100 Subject: [PATCH 10/12] Fixes --- dace/frontend/fortran/fortran_parser.py | 4 ++-- tests/fortran/array_test.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 68a174c130..dd9571636c 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -405,7 +405,7 @@ def _add_simple_state_to_cfg(self, cfg: ControlFlowRegion, state_name: str): if cfg in self.last_sdfg_states and self.last_sdfg_states[cfg] is not None: substate = cfg.add_state(state_name) else: - substate = cfg.add_state(state_name, is_start_state=True) + substate = cfg.add_state(state_name, is_start_block=True) self._finish_add_state_to_cfg(cfg, substate) return 
substate @@ -881,7 +881,7 @@ def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG, c if node.name not in sdfg.symbols: sdfg.add_symbol(node.name, datatype) if cfg not in self.last_sdfg_states or self.last_sdfg_states[cfg] is None: - bstate = cfg.add_state("SDFGbegin", is_start_state=True) + bstate = cfg.add_state("SDFGbegin", is_start_block=True) self.last_sdfg_states[cfg] = bstate if node.init is not None: substate = cfg.add_state(f"Dummystate_{node.name}") diff --git a/tests/fortran/array_test.py b/tests/fortran/array_test.py index e83b8367ba..5ea8fb1a52 100644 --- a/tests/fortran/array_test.py +++ b/tests/fortran/array_test.py @@ -5,6 +5,7 @@ from dace import dtypes, symbolic from dace.sdfg import utils as sdutil from dace.sdfg.nodes import AccessNode +from dace.sdfg.state import LoopRegion from tests.fortran.fortran_test_helper import SourceCodeBuilder, create_singular_sdfg_from_string @@ -163,9 +164,10 @@ def test_fortran_frontend_memlet_in_map_test(): """).check_with_gfortran().get() sdfg = create_singular_sdfg_from_string(sources, 'main') sdfg.simplify() - # Expect that start is begin of for loop -> only one out edge to guard defining iterator variable - assert len(sdfg.out_edges(sdfg.start_state)) == 1 - iter_var = symbolic.symbol(list(sdfg.out_edges(sdfg.start_state)[0].data.assignments.keys())[0]) + # Expect that the start is the outer for loop + loop = sdfg.start_block + assert isinstance(loop, LoopRegion) + iter_var = symbolic.pystr_to_symbolic(loop.loop_variable) for state in sdfg.states(): if len(state.nodes()) > 1: From 2cbfcd54e5262ea789d22b64f5b5edeb431e5e3d Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 24 Dec 2024 10:37:11 +0100 Subject: [PATCH 11/12] Fix symbol and scalar conflict --- dace/frontend/fortran/fortran_parser.py | 23 +++++++++++++++++++++++ tests/fortran/array_attributes_test.py | 8 ++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/dace/frontend/fortran/fortran_parser.py 
b/dace/frontend/fortran/fortran_parser.py index dd9571636c..109f0a9a1f 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -506,7 +506,30 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG, cfg: Con for i in self.startpoint.specification_part.symbols: self.translate(i, sdfg, cfg) + # Sort the specifications to be translated in such a way that scalars are processed last. This means that + # where necessary they can be treated as symbols as opposed to scalars - for instance when they are used to + # describe the size of a data container. + scalar_vardecls = [] + non_scalar_vardecls = [] + other_specifications = [] for i in self.startpoint.specification_part.specifications: + if isinstance(i, ast_internal_classes.Decl_Stmt_Node): + has_array = False + for j in i.vardecl: + if isinstance(j, ast_internal_classes.Var_Decl_Node): + if j.sizes is not None: + has_array = True + if has_array: + non_scalar_vardecls.append(i) + else: + scalar_vardecls.append(i) + else: + other_specifications.append(i) + for i in non_scalar_vardecls: + self.translate(i, sdfg, cfg) + for i in other_specifications: + self.translate(i, sdfg, cfg) + for i in scalar_vardecls: self.translate(i, sdfg, cfg) for i in self.startpoint.specification_part.specifications: self._add_simple_state_to_cfg(cfg, "start_struct_size") diff --git a/tests/fortran/array_attributes_test.py b/tests/fortran/array_attributes_test.py index 6746b78c95..53c9fcaa67 100644 --- a/tests/fortran/array_attributes_test.py +++ b/tests/fortran/array_attributes_test.py @@ -308,10 +308,10 @@ def test_fortran_frontend_array_arbitrary_attribute2(): arrsize2=arrsize2, arrsize3=arrsize3, arrsize4=arrsize4) - assert a[1, 1] == arrsize - assert a[1, 2] == arrsize2 - assert a[1, 3] == arrsize3 - assert a[1, 4] == arrsize4 + assert a[0, 0] == arrsize + assert a[0, 1] == arrsize2 + assert a[0, 2] == arrsize3 + assert a[0, 3] == arrsize4 if __name__ == "__main__": From 
f6a9dc8d462f436706749e7664e794b605bcde09 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 24 Dec 2024 10:57:39 +0100 Subject: [PATCH 12/12] Import some of the offset changes --- dace/sdfg/propagation.py | 10 ++- .../transformation/interstate/sdfg_nesting.py | 80 ++++++++++++++----- .../fortran/non-interactive/function_test.py | 1 - tests/fortran/tasklet_test.py | 1 - 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 2983ec3c63..4ba80b4ea9 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1117,7 +1117,10 @@ def propagate_memlets_nested_sdfg(parent_sdfg: 'SDFG', parent_state: 'SDFGState' if internal_memlet is None: continue try: - iedge.data = unsqueeze_memlet(internal_memlet, iedge.data, True) + ext_desc = parent_sdfg.arrays[iedge.data.data] + int_desc = sdfg.arrays[iedge.dst_conn] + iedge.data = unsqueeze_memlet(internal_memlet, iedge.data, True, internal_offset=int_desc.offset, + external_offset=ext_desc.offset) # If no appropriate memlet found, use array dimension for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[iedge.data.data].shape)): if rng[1] + 1 == s: @@ -1137,7 +1140,10 @@ def propagate_memlets_nested_sdfg(parent_sdfg: 'SDFG', parent_state: 'SDFGState' if internal_memlet is None: continue try: - oedge.data = unsqueeze_memlet(internal_memlet, oedge.data, True) + ext_desc = parent_sdfg.arrays[oedge.data.data] + int_desc = sdfg.arrays[oedge.src_conn] + oedge.data = unsqueeze_memlet(internal_memlet, oedge.data, True, internal_offset=int_desc.offset, + external_offset=ext_desc.offset) # If no appropriate memlet found, use array dimension for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[oedge.data.data].shape)): if rng[1] + 1 == s: diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 31e751bb6a..3ea55e3cab 100644 --- 
a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -509,14 +509,24 @@ def apply(self, state: SDFGState, sdfg: SDFG): if (edge not in modified_edges and edge.data.data == node.data): for e in state.memlet_tree(edge): if e._data.get_dst_subset(e, state): - new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_dst_subset=True) + offset = sdfg.arrays[e.data.data].offset + new_memlet = helpers.unsqueeze_memlet(e.data, + outer_edge.data, + use_dst_subset=True, + internal_offset=offset, + external_offset=offset) e._data.dst_subset = new_memlet.subset # NOTE: Node is source for edge in state.out_edges(node): if (edge not in modified_edges and edge.data.data == node.data): for e in state.memlet_tree(edge): if e._data.get_src_subset(e, state): - new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_src_subset=True) + offset = sdfg.arrays[e.data.data].offset + new_memlet = helpers.unsqueeze_memlet(e.data, + outer_edge.data, + use_src_subset=True, + internal_offset=offset, + external_offset=offset) e._data.src_subset = new_memlet.subset # If source/sink node is not connected to a source/destination access @@ -625,10 +635,17 @@ def _modify_access_to_access(self, state.out_edges_by_connector(nsdfg_node, inner_data)) # Create memlet by unsqueezing both w.r.t. 
src and # dst subsets - in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True) + offset = state.parent.arrays[top_edge.data.data].offset + in_memlet = helpers.unsqueeze_memlet(inner_edge.data, + top_edge.data, + use_src_subset=True, + internal_offset=offset, + external_offset=offset) out_memlet = helpers.unsqueeze_memlet(inner_edge.data, matching_edge.data, - use_dst_subset=True) + use_dst_subset=True, + internal_offset=offset, + external_offset=offset) new_memlet = in_memlet new_memlet.other_subset = out_memlet.subset @@ -651,10 +668,17 @@ def _modify_access_to_access(self, state.out_edges_by_connector(nsdfg_node, inner_data)) # Create memlet by unsqueezing both w.r.t. src and # dst subsets - in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True) + offset = state.parent.arrays[top_edge.data.data].offset + in_memlet = helpers.unsqueeze_memlet(inner_edge.data, + top_edge.data, + use_src_subset=True, + internal_offset=offset, + external_offset=offset) out_memlet = helpers.unsqueeze_memlet(inner_edge.data, matching_edge.data, - use_dst_subset=True) + use_dst_subset=True, + internal_offset=offset, + external_offset=offset) new_memlet = in_memlet new_memlet.other_subset = out_memlet.subset @@ -689,7 +713,11 @@ def _modify_memlet_path( if inner_edge in edges_to_ignore: new_memlet = inner_edge.data else: - new_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data) + offset = state.parent.arrays[top_edge.data.data].offset + new_memlet = helpers.unsqueeze_memlet(inner_edge.data, + top_edge.data, + internal_offset=offset, + external_offset=offset) if inputs: if inner_edge.dst in inner_to_outer: dst = inner_to_outer[inner_edge.dst] @@ -708,15 +736,19 @@ def _modify_memlet_path( mtree = state.memlet_tree(new_edge) # Modify all memlets going forward/backward - def traverse(mtree_node): + def traverse(mtree_node, state, nstate): result.add(mtree_node.edge) - mtree_node.edge._data = 
helpers.unsqueeze_memlet(mtree_node.edge.data, top_edge.data) + offset = state.parent.arrays[top_edge.data.data].offset + mtree_node.edge._data = helpers.unsqueeze_memlet(mtree_node.edge.data, + top_edge.data, + internal_offset=offset, + external_offset=offset) for child in mtree_node.children: - traverse(child) + traverse(child, state, nstate) result.add(new_edge) for child in mtree.children: - traverse(child) + traverse(child, state, nstate) return result @@ -1035,8 +1067,8 @@ def _check_cand(candidates, outer_edges): # If there are any symbols here that are not defined # in "defined_symbols" - missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), - list(indices)) - set(nsdfg.symbol_mapping.keys())) + missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) - + set(nsdfg.symbol_mapping.keys())) if missing_symbols: ignore.add(cname) continue @@ -1045,10 +1077,13 @@ def _check_cand(candidates, outer_edges): _check_cand(out_candidates, state.out_edges_by_connector) # Return result, filtering out the states - return ({k: (dc(v), ind) - for k, (v, _, ind) in in_candidates.items() - if k not in ignore}, {k: (dc(v), ind) - for k, (v, _, ind) in out_candidates.items() if k not in ignore}) + return ({ + k: (dc(v), ind) + for k, (v, _, ind) in in_candidates.items() if k not in ignore + }, { + k: (dc(v), ind) + for k, (v, _, ind) in out_candidates.items() if k not in ignore + }) def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False): nsdfg = self.nsdfg @@ -1071,7 +1106,16 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]], outer_edge = next(iter(outer_edges(nsdfg_node, aname))) except StopIteration: continue - new_memlet = helpers.unsqueeze_memlet(refine, outer_edge.data) + if isinstance(outer_edge.dst, nodes.NestedSDFG): + conn = outer_edge.dst_conn + else: + conn = outer_edge.src_conn + int_desc = nsdfg.arrays[conn] + ext_desc = sdfg.arrays[outer_edge.data.data] + 
new_memlet = helpers.unsqueeze_memlet(refine, + outer_edge.data, + internal_offset=int_desc.offset, + external_offset=ext_desc.offset) outer_edge.data.subset = subsets.Range([ ns if i in indices else os for i, (os, ns) in enumerate(zip(outer_edge.data.subset, new_memlet.subset)) diff --git a/tests/fortran/non-interactive/function_test.py b/tests/fortran/non-interactive/function_test.py index 87cfd260c3..c637de41ad 100644 --- a/tests/fortran/non-interactive/function_test.py +++ b/tests/fortran/non-interactive/function_test.py @@ -267,7 +267,6 @@ def test_fortran_frontend_function_test3(): sdfg.parent_nsdfg_node = None sdfg.reset_sdfg_list() sdfg.simplify(verbose=True) - sdfg.view() sdfg.compile() diff --git a/tests/fortran/tasklet_test.py b/tests/fortran/tasklet_test.py index 49a2f5ac79..5c125f3e0f 100644 --- a/tests/fortran/tasklet_test.py +++ b/tests/fortran/tasklet_test.py @@ -32,7 +32,6 @@ def test_fortran_frontend_tasklet(): """ sdfg = fortran_parser.create_sdfg_from_string(test_string, "tasklet", normalize_offsets=True) - sdfg.view() sdfg.simplify(verbose=True) sdfg.compile()