Skip to content

Commit

Permalink
Merge branch 'cartesian/feature/absolute_k_indexation' into int_cast_…
Browse files Browse the repository at this point in the history
…aboslute_k_debug
  • Loading branch information
FlorianDeconinck committed Dec 24, 2024
2 parents ee6845a + 2ec187f commit 32a6c31
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 47 deletions.
6 changes: 5 additions & 1 deletion src/gt4py/cartesian/gtc/dace/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,12 +737,16 @@ class VariableKOffset(common.VariableKOffset[Expr]):
pass


class AbsoluteKIndex(common.AbsoluteKIndex[Expr]):
pass


class IndexAccess(common.FieldAccess, Expr):
offset: Optional[
Union[
common.CartesianOffset,
VariableKOffset,
common.AbsoluteKIndex,
AbsoluteKIndex,
Literal,
ScalarAccess, # For field index
]
Expand Down
67 changes: 35 additions & 32 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,9 @@ def visit_HorizontalRestriction(
def visit_VariableKOffset(self, node: oir.VariableKOffset, **kwargs):
return dcir.VariableKOffset(k=self.visit(node.k, **kwargs))

def visit_AbsoluteKIndex(self, node: oir.AbsoluteKIndex, **kwargs):
return dcir.AbsoluteKIndex(k=self.visit(node.k, **kwargs))

def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> dcir.LocalScalarDecl:
return dcir.LocalScalarDecl(name=node.name, dtype=node.dtype)

Expand Down Expand Up @@ -355,45 +358,33 @@ def visit_FieldAccess(
node.name in targets and node.offset == common.CartesianOffset.zero()
)
name = get_tasklet_symbol(node.name, node.offset, is_target=is_target)
if node.data_index:
if isinstance(node.offset, common.AbsoluteKIndex):
raise RuntimeError("Absolute K indexing cannot work with data index")
res = dcir.IndexAccess(
name=name,
offset=None,
data_index=node.data_index,
dtype=node.dtype,
)
elif node.name in absolute_K_access_fields:
if node.name in absolute_K_access_fields:
# Two cases:
# - we are accessing in absolute K - and need to resolve that index (offset.k)
# - we are NOT accessing in absolute K for this access, but the field will be
# before or after, and we need to revolve it as an IndexAccess rather than
# a scalar access
if isinstance(node.offset, common.AbsoluteKIndex):
offset = self.visit(
node.offset.k,
is_target=is_target,
targets=targets,
var_offset_fields=var_offset_fields,
K_write_with_offset=K_write_with_offset,
absolute_K_access_fields=absolute_K_access_fields,
**kwargs,
)
else:
offset = self.visit(
node.offset,
is_target=is_target,
targets=targets,
var_offset_fields=var_offset_fields,
K_write_with_offset=K_write_with_offset,
absolute_K_access_fields=absolute_K_access_fields,
**kwargs,
)
offset = self.visit(
node.offset,
is_target=is_target,
targets=targets,
var_offset_fields=var_offset_fields,
K_write_with_offset=K_write_with_offset,
absolute_K_access_fields=absolute_K_access_fields,
**kwargs,
)
res = dcir.IndexAccess(
name=name,
offset=offset,
data_index=[],
data_index=node.data_index,
dtype=node.dtype,
)
elif node.data_index:
# No offset - but a data dimension
res = dcir.IndexAccess(
name=name,
offset=None,
data_index=node.data_index,
dtype=node.dtype,
)
else:
Expand Down Expand Up @@ -532,7 +523,19 @@ def visit_HorizontalExecution(
reshape_memlet = False
for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess):
if access_node.data_index and access_node.name == memlet.connector:
access_node.data_index = memlet_data_index + access_node.data_index
# Order matters!
# Seperate between case where K is offset or absolute and
# where it's a regular offset (should be in memlet_data_index)
if isinstance(
access_node.offset, (dcir.VariableKOffset, dcir.AbsoluteKIndex)
):
access_node.data_index = [
*memlet_data_index, # IJ - enforced by the fact offset is on K
access_node.offset.k, # Visitable K offset
*access_node.data_index, # Extra dims
]
else:
access_node.data_index = memlet_data_index + access_node.data_index
assert len(access_node.data_index) == array_ndims
reshape_memlet = True
if reshape_memlet:
Expand Down
56 changes: 42 additions & 14 deletions src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _visit_offset(
def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs):
return self._visit_offset(node, **kwargs)

def visit_VariableKOffset(self, node: common.CartesianOffset, **kwargs):
def visit_VariableKOffset(self, node: common.VariableKOffset, **kwargs):
return self._visit_offset(node, **kwargs)

def visit_AbsoluteKIndex(self, node: common.AbsoluteKIndex, **kwargs):
Expand Down Expand Up @@ -92,22 +92,50 @@ def visit_IndexAccess(
"Memlet connector and tasklet variable mismatch, DaCe IR error."
) from None

# Are we still on grid-point access (I.shape==1)
# or are we considering the array entirely (I.shape > 1)
memlet_accessed_as_full_array = memlet.access_info.shape[0] != 1
index_strs = []
if node.offset is not None:
index_strs.append(
self.visit(
node.offset,
decl=symtable[memlet.field],
access_info=memlet.access_info,
symtable=symtable,
in_idx=True,
**kwargs,
if memlet_accessed_as_full_array:
# Full array access with every dimensions accessed in full
# everything was packed in `data_index` in `DaCeIRBuilder.visit_HorizontalExecution`
# along the `reshape_memlet=True` code path
assert len(node.data_index) == len(sdfg_ctx.sdfg.arrays[memlet.field].shape)
assert len(node.data_index) > 3 # All the cartesian dims, and then some
# IJ resolve
index_strs = [
self.visit(node.data_index[0], in_idx=True, **kwargs),
self.visit(node.data_index[1], in_idx=True, **kwargs),
]
# K resolve (as a relative or absolute indexing)
rel_off = "__k+"
if isinstance(node.offset, dcir.AbsoluteKIndex):
rel_off = ""
index_strs.append(f"{rel_off}{self.visit(node.data_index[2], in_idx=True, **kwargs)}")
# Data dimensions (as absolute index)
index_strs.extend(
self.visit(idx, sdfg_ctx=sdfg_ctx, symtable=symtable, in_idx=True, **kwargs)
for idx in node.data_index[3:]
)
else:
# Grid-point access, I & J are unitary, K can be offseted with variable
# Resolve K offset (also resolves I & J)
if node.offset is not None:
index_strs.append(
self.visit(
node.offset,
decl=symtable[memlet.field],
access_info=memlet.access_info,
symtable=symtable,
in_idx=True,
**kwargs,
)
)
# Add any data dimensions
index_strs.extend(
self.visit(idx, sdfg_ctx=sdfg_ctx, symtable=symtable, in_idx=True, **kwargs)
for idx in node.data_index
)
index_strs.extend(
self.visit(idx, sdfg_ctx=sdfg_ctx, symtable=symtable, in_idx=True, **kwargs)
for idx in node.data_index
)
return f"{node.name}[{','.join(index_strs)}]"

def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs):
Expand Down

0 comments on commit 32a6c31

Please sign in to comment.