Skip to content

Commit

Permalink
Fix loop symbol type inference and loop to map
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Nov 13, 2024
1 parent 8c488de commit 7e4bc3d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
17 changes: 10 additions & 7 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,13 +936,16 @@ def generate_code(self,

if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None:
if not cfr.loop_variable in interstate_symbols:
l_end = loop_analysis.get_loop_end(cfr)
l_start = loop_analysis.get_init_assignment(cfr)
l_step = loop_analysis.get_loop_stride(cfr)
sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols),
infer_expr_type(l_step, global_symbols),
infer_expr_type(l_end, global_symbols))
interstate_symbols[cfr.loop_variable] = sym_type
if cfr.loop_variable in global_symbols:
interstate_symbols[cfr.loop_variable] = global_symbols[cfr.loop_variable]
else:
l_end = loop_analysis.get_loop_end(cfr)
l_start = loop_analysis.get_init_assignment(cfr)
l_step = loop_analysis.get_loop_stride(cfr)
sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols),
infer_expr_type(l_step, global_symbols),
infer_expr_type(l_end, global_symbols))
interstate_symbols[cfr.loop_variable] = sym_type
if not cfr.loop_variable in global_symbols:
global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable]

Expand Down
12 changes: 11 additions & 1 deletion dace/transformation/interstate/loop_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import sympy as sp
from typing import Dict, List, Set

from dace import data as dt, memlet, nodes, sdfg as sd, symbolic, subsets, properties
from dace import data as dt, dtypes, memlet, nodes, sdfg as sd, symbolic, subsets, properties
from dace.codegen.tools.type_inference import infer_expr_type
from dace.sdfg import graph as gr, nodes
from dace.sdfg import SDFG, SDFGState
from dace.sdfg import utils as sdutil
Expand Down Expand Up @@ -94,6 +95,15 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False):
if start is None or end is None or step is None or itervar is None:
return False

sset = {}
sset.update(sdfg.symbols)
sset.update(sdfg.arrays)
t = dtypes.result_type_of(infer_expr_type(start, sset), infer_expr_type(step, sset), infer_expr_type(end, sset))
# We may only convert something to map if the bounds are all integer-derived types. Otherwise most map schedules
# except for sequential would be invalid.
if not t in dtypes.INTEGER_TYPES:
return False

# Loops containing break, continue, or returns may not be turned into a map.
for blk in self.loop.all_control_flow_blocks():
if isinstance(blk, (BreakBlock, ContinueBlock, ReturnBlock)):
Expand Down

0 comments on commit 7e4bc3d

Please sign in to comment.