Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build[next]: switch dace version to main branch from git repo #1835

Merged
merged 36 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
c4121a3
switch gt4py-next to dace main from git repo
edopao Jan 29, 2025
fd77ede
minor edit
edopao Jan 29, 2025
c451117
use dace fork to avoid submodule issue
edopao Jan 29, 2025
d6dc39a
enable lowering with control flow regions
edopao Jan 29, 2025
824f0cb
update uv lock
edopao Jan 29, 2025
7eeef16
update dace version
edopao Jan 29, 2025
dab9508
update uv-pre-commit version
edopao Jan 29, 2025
97841ee
switch to spcl repo
edopao Jan 30, 2025
b7357a7
rename extra
edopao Jan 30, 2025
d2c204e
minor edit
edopao Jan 30, 2025
1e279fd
Merge remote-tracking branch 'origin/main' into dace-main
edopao Jan 31, 2025
773a6b2
fix transformations
edopao Jan 31, 2025
15921bc
make if exclusive, always
edopao Jan 31, 2025
babdea5
Merge remote-tracking branch 'origin/main' into dace-main
edopao Feb 3, 2025
718c619
disable usage of ConditionalBlock
edopao Feb 3, 2025
cffdcd4
Revert "disable usage of ConditionalBlock"
edopao Feb 3, 2025
a35ad07
Updated the transformations to conform (more) with DaCe main.
philip-paul-mueller Feb 3, 2025
e2327e8
skip preprocess in gt_inline_nested_sdfg as temporary workaround
edopao Feb 3, 2025
c9391b5
Revert "skip preprocess in gt_inline_nested_sdfg as temporary workaro…
edopao Feb 3, 2025
059fae7
use fork of dace to fix PruneConnectors
edopao Feb 3, 2025
053bc4e
This should fix the map fusion fix.
philip-paul-mueller Feb 4, 2025
bb99900
Fixed another issue in this super outdated map fusion.
philip-paul-mueller Feb 4, 2025
4e5ae2c
fix pre-commit
edopao Feb 4, 2025
5a19e59
Fixed a new bug in the `is_accessed_downstream()` function.
philip-paul-mueller Feb 4, 2025
7e622f9
fix for pattern matching issue with LoopRegion
edopao Feb 4, 2025
5090388
re-enable check for exclusive if
edopao Feb 3, 2025
dabe459
set using_explicit_control_flow on scan nsdfg
edopao Feb 4, 2025
e2aa018
fix gpu_utils
edopao Feb 4, 2025
71fdeac
Restructured Edoardo's fix, it is now in multiple functions.
philip-paul-mueller Feb 4, 2025
73d5b15
switch to dace main
edopao Feb 4, 2025
39a3049
Merge remote-tracking branch 'origin/main' into dace-main
edopao Feb 4, 2025
8774462
update uv lock
edopao Feb 4, 2025
1a513be
Updated the changes for the optimizer a bit.
philip-paul-mueller Feb 5, 2025
bc8d575
code comment on using_explicit_control_flow
edopao Feb 5, 2025
b46fd98
Update src/gt4py/next/program_processors/runners/dace/transformations…
philip-paul-mueller Feb 5, 2025
c2c933c
Update src/gt4py/next/program_processors/runners/dace/transformations…
philip-paul-mueller Feb 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:

- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.5.10
rev: 0.5.25
hooks:
- id: uv-lock

Expand Down
6 changes: 5 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
"internal": {"extras": [], "markers": ["not requires_dace"]},
"dace": {"extras": ["dace"], "markers": ["requires_dace"]},
}
# Use dace-next for GT4Py-next, to install a different dace version than in cartesian
CodeGenNextTestSettings = CodeGenTestSettings | {
"dace": {"extras": ["dace-next"], "markers": ["requires_dace"]},
}


# -- nox sessions --
Expand Down Expand Up @@ -158,7 +162,7 @@ def test_next(
) -> None:
"""Run selected 'gt4py.next' tests."""

codegen_settings = CodeGenTestSettings[codegen]
codegen_settings = CodeGenNextTestSettings[codegen]
device_settings = DeviceTestSettings[device]
groups: list[str] = ["test"]
mesh_markers: list[str] = []
Expand Down
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ all = ['gt4py[dace,formatting,jax,performance,testing]']
cuda11 = ['cupy-cuda11x>=12.0']
cuda12 = ['cupy-cuda12x>=12.0']
# features
dace = ['dace>=1.0.0,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4
dace = ['dace>=1.0.1,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4
dace-next = ['dace'] # pull dace latest version from the git repository
formatting = ['clang-format>=9.0']
jax = ['jax>=0.4.26']
jax-cuda12 = ['jax[cuda12_local]>=0.4.26', 'gt4py[cuda12]']
Expand Down Expand Up @@ -438,6 +439,14 @@ conflicts = [
{extra = 'jax-cuda12'},
{extra = 'rocm4_3'},
{extra = 'rocm5_0'}
],
[
{extra = 'dace'},
{extra = 'dace-next'}
],
[
{extra = 'all'},
{extra = 'dace-next'}
]
]

Expand All @@ -448,3 +457,4 @@ url = 'https://test.pypi.org/simple/'

[tool.uv.sources]
atlas4py = {index = "test.pypi"}
dace = {git = "https://github.com/spcl/dace", branch = "main", extra = "dace-next"}
37 changes: 13 additions & 24 deletions src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class IteratorExpr:

field: dace.nodes.AccessNode
gt_dtype: ts.ListType | ts.ScalarType
field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]]
field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType]]
indices: dict[gtx_common.Dimension, DataExpr]

def get_field_type(self) -> ts.FieldType:
Expand Down Expand Up @@ -767,9 +767,6 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp

assert len(node.args) == 3

# TODO(edopao): enable once supported in next DaCe release
use_conditional_block: Final[bool] = False

# evaluate the if-condition that will write to a boolean scalar node
condition_value = self.visit(node.args[0])
assert (
Expand All @@ -785,26 +782,18 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp
nsdfg.debuginfo = gtir_sdfg_utils.debug_info(node, default=self.sdfg.debuginfo)

# create states inside the nested SDFG for the if-branches
if use_conditional_block:
if_region = dace.sdfg.state.ConditionalBlock("if")
nsdfg.add_node(if_region)
entry_state = nsdfg.add_state("entry", is_start_block=True)
nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge())

then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg)
tstate = then_body.add_state("true_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body)

else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg)
fstate = else_body.add_state("false_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body)

else:
entry_state = nsdfg.add_state("entry", is_start_block=True)
tstate = nsdfg.add_state("true_branch")
nsdfg.add_edge(entry_state, tstate, dace.InterstateEdge(condition="__cond"))
fstate = nsdfg.add_state("false_branch")
nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)"))
if_region = dace.sdfg.state.ConditionalBlock("if")
nsdfg.add_node(if_region)
entry_state = nsdfg.add_state("entry", is_start_block=True)
nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge())

then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg)
tstate = then_body.add_state("true_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body)

else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg)
fstate = else_body.add_state("false_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body)

input_memlets: dict[str, MemletExpr | ValueExpr] = {}
nsdfg_symbols_mapping: Optional[dict[str, dace.symbol]] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ def _lower_lambda_to_nested_sdfg(
# the lambda expression, i.e. body of the scan, will be created inside a nested SDFG.
nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan"))
nsdfg.debuginfo = gtir_sdfg_utils.debug_info(lambda_node, default=sdfg.debuginfo)
# We set `using_explicit_control_flow=True` because the vertical scan is lowered to a `LoopRegion`.
# This property is used by pattern matching in SDFG transformation framework
# to skip those transformations that do not yet support control flow blocks.
nsdfg.using_explicit_control_flow = True
edopao marked this conversation as resolved.
Show resolved Hide resolved
lambda_translator = sdfg_builder.setup_nested_context(lambda_node, nsdfg, lambda_symbols)

# use the vertical dimension in the domain as scan dimension
Expand Down
4 changes: 0 additions & 4 deletions src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union

import dace
from dace.sdfg import utils as dace_sdfg_utils

from gt4py import eve
from gt4py.eve import concepts
Expand Down Expand Up @@ -999,9 +998,6 @@ def build_sdfg_from_gtir(
sdfg = sdfg_genenerator.visit(ir)
assert isinstance(sdfg, dace.SDFG)

# TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct
dace_sdfg_utils.inline_loop_blocks(sdfg)

if disable_field_origin_on_program_arguments:
_remove_field_origin_symbols(ir, sdfg)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def gt_auto_optimize(
# For compatibility with DaCe (and until we found out why) the GT4Py
# auto optimizer will emulate this behaviour.
for state in sdfg.states():
assert isinstance(state, dace.SDFGState)
for edge in state.edges():
edge.data.wcr_nonatomic = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,62 +160,9 @@ def gt_gpu_transform_non_standard_memlet(
correct loop order.
- This function should be called after `gt_set_iteration_order()` has run.
"""
new_maps: set[dace_nodes.MapEntry] = set()

# This code is copied from DaCe's code generator.
for e, state in list(sdfg.all_edges_recursive()):
nsdfg = state.parent
if (
isinstance(e.src, dace_nodes.AccessNode)
and isinstance(e.dst, dace_nodes.AccessNode)
and e.src.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global
and e.dst.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global
):
a: dace_nodes.AccessNode = e.src
b: dace_nodes.AccessNode = e.dst

copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides(
None, nsdfg, state, e, a, b
)
dims = len(copy_shape)
if dims == 1:
continue
elif dims == 2:
if src_strides[-1] != 1 or dst_strides[-1] != 1:
try:
is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1]
is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1]
except (TypeError, ValueError):
is_src_cont = False
is_dst_cont = False
if is_src_cont and is_dst_cont:
continue
else:
continue
elif dims > 2:
if not (src_strides[-1] != 1 or dst_strides[-1] != 1):
continue

# For identifying the new map, we first store all neighbors of `a`.
old_neighbors_of_a: list[dace_nodes.AccessNode] = [
edge.dst for edge in state.out_edges(a)
]

# Turn unsupported copy to a map
try:
dace_transformation.dataflow.CopyToMap.apply_to(
nsdfg, save=False, annotate=False, a=a, b=b
)
except ValueError: # If transformation doesn't match, continue normally
continue

# We find the new map by comparing the new neighborhood of `a` with the old one.
new_nodes: set[dace_nodes.MapEntry] = {
edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a
}
assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes)
assert len(new_nodes) == 1
new_maps.update(new_nodes)
# Expand all non standard memlets and get the new MapEntries.
new_maps: set[dace_nodes.MapEntry] = _gt_expand_non_standard_memlets(sdfg)

# If there are no Memlets that are translated to copy-Maps, then we have nothing to do.
if len(new_maps) == 0:
Expand Down Expand Up @@ -283,6 +230,88 @@ def restrict_fusion_to_newly_created_maps(
return sdfg


def _gt_expand_non_standard_memlets(
    sdfg: dace.SDFG,
) -> set[dace_nodes.MapEntry]:
    """Find and expand all non-standard Memlets in the SDFG, recursively.

    Used by `gt_gpu_transform_non_standard_memlet()`. Every Memlet that cannot
    be expressed as a plain `memcpy()` is turned into a copy Map; the function
    returns the `MapEntry` nodes of all Maps that were created this way.

    The whole SDFG tree is processed, i.e. every nested SDFG is visited as well.
    """
    expanded_maps: set[dace_nodes.MapEntry] = set()
    for inner_sdfg in sdfg.all_sdfgs_recursive():
        expanded_maps |= _gt_expand_non_standard_memlets_sdfg(inner_sdfg)
    return expanded_maps


def _gt_expand_non_standard_memlets_sdfg(
    sdfg: dace.SDFG,
) -> set[dace_nodes.MapEntry]:
    """Implementation of `_gt_expand_non_standard_memlets()` that processes a single SDFG.

    Scans every state of `sdfg` (non-recursively) for direct AccessNode-to-AccessNode
    copies between `GPU_Global` arrays that can not be expressed as a (strided)
    `memcpy()`, rewrites each of them into a copy Map via DaCe's `CopyToMap`
    transformation, and returns the `MapEntry` nodes of all Maps created this way.
    """
    new_maps: set[dace_nodes.MapEntry] = set()
    # The implementation is based on DaCe's code generator.
    for state in sdfg.states():
        for e in state.edges():
            # We are only interested in edges that connects two access nodes of GPU memory.
            if not (
                isinstance(e.src, dace_nodes.AccessNode)
                and isinstance(e.dst, dace_nodes.AccessNode)
                and e.src.desc(sdfg).storage == dace_dtypes.StorageType.GPU_Global
                and e.dst.desc(sdfg).storage == dace_dtypes.StorageType.GPU_Global
            ):
                continue

            a: dace_nodes.AccessNode = e.src
            b: dace_nodes.AccessNode = e.dst
            # Resolve the copy into an absolute shape and per-side strides; this is the
            # same information DaCe's code generator uses to decide how to emit the copy.
            copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides(
                None, sdfg, state, e, a, b
            )
            dims = len(copy_shape)
            # Mirrors the decision logic in DaCe's GPU code generator: copies that it can
            # already emit as (strided) memcpys are skipped; only the remaining ones are
            # expanded into Maps below.
            if dims == 1:
                # 1D copies are always expressible as a memcpy.
                continue
            elif dims == 2:
                if src_strides[-1] != 1 or dst_strides[-1] != 1:
                    # Non-unit innermost stride: still fine if both sides are contiguous
                    # when the two dimensions are collapsed (stride ratio equals the
                    # inner extent). Symbolic strides may raise here, in which case we
                    # conservatively treat the copy as non-contiguous.
                    try:
                        is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1]
                        is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1]
                    except (TypeError, ValueError):
                        is_src_cont = False
                        is_dst_cont = False
                    if is_src_cont and is_dst_cont:
                        continue
                else:
                    # Unit innermost strides: emitted as a 2D strided memcpy.
                    continue
            elif dims > 2:
                if not (src_strides[-1] != 1 or dst_strides[-1] != 1):
                    # Higher-dimensional copies with unit innermost strides are
                    # handled by the code generator as well.
                    continue

            # For identifying the new map, we first store all neighbors of `a`.
            old_neighbors_of_a: list[dace_nodes.AccessNode] = [
                edge.dst for edge in state.out_edges(a)
            ]

            # Turn unsupported copy to a map
            try:
                dace_transformation.dataflow.CopyToMap.apply_to(
                    sdfg, save=False, annotate=False, a=a, b=b
                )
            except ValueError:  # If transformation doesn't match, continue normally
                continue

            # We find the new map by comparing the new neighborhood of `a` with the old one.
            new_nodes: set[dace_nodes.MapEntry] = {
                edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a
            }
            # `CopyToMap` is expected to add exactly one Map between `a` and `b`.
            assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes)
            assert len(new_nodes) == 1
            new_maps.update(new_nodes)
    return new_maps


def gt_set_gpu_blocksize(
sdfg: dace.SDFG,
block_size: Optional[Sequence[int | str] | str],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def gt_create_local_double_buffering(
it is not needed that the whole data is stored, but only the working set
of a single thread.
"""

processed_maps = 0
for nsdfg in sdfg.all_sdfgs_recursive():
processed_maps += _create_local_double_buffering_non_recursive(nsdfg)
Expand All @@ -60,6 +59,7 @@ def _create_local_double_buffering_non_recursive(

processed_maps = 0
for state in sdfg.states():
assert isinstance(state, dace.SDFGState)
scope_dict = state.scope_dict()
for node in state.nodes():
if not isinstance(node, dace_nodes.MapEntry):
Expand Down
Loading