From 7e4bc3da7154164ca5b1264a86dbd1cf921e6260 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 13 Nov 2024 12:16:30 +0100 Subject: [PATCH] Fix loop symbol type inference and loop to map --- dace/codegen/targets/framecode.py | 17 ++++++++++------- dace/transformation/interstate/loop_to_map.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index c0e08cfba7..11a198f119 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -936,13 +936,16 @@ def generate_code(self, if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: if not cfr.loop_variable in interstate_symbols: - l_end = loop_analysis.get_loop_end(cfr) - l_start = loop_analysis.get_init_assignment(cfr) - l_step = loop_analysis.get_loop_stride(cfr) - sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), - infer_expr_type(l_step, global_symbols), - infer_expr_type(l_end, global_symbols)) - interstate_symbols[cfr.loop_variable] = sym_type + if cfr.loop_variable in global_symbols: + interstate_symbols[cfr.loop_variable] = global_symbols[cfr.loop_variable] + else: + l_end = loop_analysis.get_loop_end(cfr) + l_start = loop_analysis.get_init_assignment(cfr) + l_step = loop_analysis.get_loop_stride(cfr) + sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), + infer_expr_type(l_step, global_symbols), + infer_expr_type(l_end, global_symbols)) + interstate_symbols[cfr.loop_variable] = sym_type if not cfr.loop_variable in global_symbols: global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 55327af5fb..9f487f561a 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -6,7 +6,8 @@ import sympy as sp from typing import Dict, List, Set -from dace import data as dt, memlet, nodes, sdfg as sd, symbolic, subsets, properties +from dace import data as dt, dtypes, memlet, nodes, sdfg as sd, symbolic, subsets, properties +from dace.codegen.tools.type_inference import infer_expr_type from dace.sdfg import graph as gr, nodes from dace.sdfg import SDFG, SDFGState from dace.sdfg import utils as sdutil @@ -94,6 +95,15 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False): if start is None or end is None or step is None or itervar is None: return False + sset = {} + sset.update(sdfg.symbols) + sset.update(sdfg.arrays) + t = dtypes.result_type_of(infer_expr_type(start, sset), infer_expr_type(step, sset), infer_expr_type(end, sset)) + # We may only convert something to map if the bounds are all integer-derived types. Otherwise most map schedules + # except for sequential would be invalid. + if not t in dtypes.INTEGER_TYPES: + return False + # Loops containing break, continue, or returns may not be turned into a map. for blk in self.loop.all_control_flow_blocks(): if isinstance(blk, (BreakBlock, ContinueBlock, ReturnBlock)):