Fix pure reduce expansion for squeezed output memlets. (#1709)
The pure expansion was producing wrong indices for the initialization kernel, which
made some simple, valid SDFGs fail (see the demo in the test).
pratyai authored Nov 6, 2024
1 parent 72ee732 commit 68dadd7
Showing 2 changed files with 62 additions and 14 deletions.
2 changes: 1 addition & 1 deletion dace/libraries/standard/nodes/reduce.py
@@ -103,7 +103,7 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
'reduce_init', {'_o%d' % i: '0:%s' % symstr(d)
for i, d in enumerate(outedge.data.subset.size())}, {},
'__out = %s' % node.identity,
-{'__out': dace.Memlet.simple('_out', ','.join(['_o%d' % i for i in range(output_dims)]))},
+{'__out': dace.Memlet.simple('_out', ','.join(['_o%d' % i for i in osqdim]))},
external_edges=True)
else:
nstate = nsdfg.add_state()
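Below is a minimal sketch, not the library code, of why this one-line change matters. It assumes that osqdim (defined elsewhere in reduce.py, outside this hunk) lists the output-subset dimensions that survive squeezing, that the initialization map's parameters are numbered by their position in the unsqueezed output subset, and that '_out' inside the nested SDFG has the squeezed shape; the names and shapes below are illustrative only.

# Illustrative only: output memlet C[0:N, 5, 5, 0:C_out] from the new test.
outer_sizes = ['N', 1, 1, 'C_out']                          # unsqueezed output subset sizes
osqdim = [i for i, s in enumerate(outer_sizes) if s != 1]   # surviving dims: [0, 3]
output_dims = len(osqdim)                                   # rank of the squeezed '_out'

# Init-map parameters are named after the unsqueezed dimension index.
map_params = {'_o%d' % i: '0:%s' % s for i, s in enumerate(outer_sizes)}
# -> {'_o0': '0:N', '_o1': '0:1', '_o2': '0:1', '_o3': '0:C_out'}

old = ','.join('_o%d' % i for i in range(output_dims))  # '_o0,_o1': uses a degenerate 0:1 dim
new = ','.join('_o%d' % i for i in osqdim)              # '_o0,_o3': matches the map parameters
print(map_params, old, new)

Under these assumptions, the old subscript ties the second output dimension to the degenerate _o1 range and initializes only a single slice of '_out'; the fixed subscript reuses the surviving dimension ids and covers the whole output.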
74 changes: 61 additions & 13 deletions tests/library/reduce_test.py
@@ -1,35 +1,82 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np
import pytest

import dace
import dace.libraries.standard as std
from dace import SDFG, Memlet

C_in, C_out, H, K, N, W = (dace.symbol(s, dace.int64) for s in ('C_in', 'C_out', 'H', 'K', 'N', 'W'))


def make_sdfg():
g = SDFG('prog')
g.add_array('A', (N, 1, 1, C_in, C_out), dace.float32,
strides=(C_in * C_out, C_in * C_out, C_in * C_out, C_out, 1))
g.add_array('C', (N, H, W, C_out), dace.float32,
strides=(C_out * H * W, C_out * W, C_out, 1))

st0 = g.add_state('st0', is_start_block=True)
st = st0

A = st.add_access('A')
C = st.add_access('C')
R = st.add_reduce('lambda x, y: x + y', [1, 2, 3], 0)
st.add_nedge(A, R, Memlet(expr='A[0:N, 0, 0, 0:C_in, 0:C_out]'))
st.add_nedge(R, C, Memlet(expr='C[0:N, 5, 5, 0:C_out]'))

return g, R


def test_library_node_expand_reduce_pure():
n, cin, cout = 7, 7, 7
h, k, w = 25, 35, 45
A = np.ones((n, 1, 1, cin, cout), np.float32)

g, R = make_sdfg()
R.implementation = 'pure-seq'
g.validate()
g.compile()

wantC = np.ones((n, h, w, cout), np.float32) * 42
g(A=A, C=wantC, N=n, C_in=cin, C_out=cout, H=h, K=k, W=w)

g, R = make_sdfg()
R.implementation = 'pure'
g.validate()
g.compile()

gotC = np.ones((n, h, w, cout), np.float32) * 42
g(A=A, C=gotC, N=n, C_in=cin, C_out=cout, H=h, K=k, W=w)
assert np.allclose(wantC, gotC)


_params = ['pure', 'CUDA (device)', 'pure-seq', 'GPUAuto']



@pytest.mark.gpu
@pytest.mark.parametrize('impl', _params)
def test_multidim_gpu(impl):

test_cases = [([1, 64, 60, 60], (0, 2, 3), [64], np.float32),
-([8, 512, 4096], (0,1), [4096], np.float32),
-([8, 512, 4096], (0,1), [4096], np.float64),
-([1024, 8], (0), [8], np.float32),
-([111, 111, 111], (0,1), [111], np.float64),
-([111, 111, 111], (1,2), [111], np.float64),
-([1000000], (0), [1], np.float64),
-([1111111], (0), [1], np.float64),
-([123,21,26,8], (1,2), [123,8], np.float32),
-([2, 512, 2], (0,2), [512], np.float32),
-([512, 555, 257], (0,2), [555], np.float64)]
+([8, 512, 4096], (0, 1), [4096], np.float32),
+([8, 512, 4096], (0, 1), [4096], np.float64),
+([1024, 8], (0), [8], np.float32),
+([111, 111, 111], (0, 1), [111], np.float64),
+([111, 111, 111], (1, 2), [111], np.float64),
+([1000000], (0), [1], np.float64),
+([1111111], (0), [1], np.float64),
+([123, 21, 26, 8], (1, 2), [123, 8], np.float32),
+([2, 512, 2], (0, 2), [512], np.float32),
+([512, 555, 257], (0, 2), [555], np.float64)]

for in_shape, ax, out_shape, dtype in test_cases:
print(in_shape, ax, out_shape, dtype)
axes = ax

@dace.program
def multidimred(a, b):
b[:] = np.sum(a, axis=axes)

a = np.random.rand(*in_shape).astype(dtype)
b = np.random.rand(*out_shape).astype(dtype)
sdfg = multidimred.to_sdfg(a, b)
@@ -45,3 +92,4 @@ def multidimred(a, b):
if __name__ == '__main__':
for p in _params:
test_multidim_gpu(p)
test_library_node_expand_reduce_pure()
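For reference, a quick way to run just the new CPU regression test, assuming a DaCe development checkout and that this is executed from the repository root:

# Hypothetical standalone runner; the test node id follows the file path and
# function name added in this commit.
import pytest

exit_code = pytest.main(['-q', 'tests/library/reduce_test.py::test_library_node_expand_reduce_pure'])
print('pytest exit code:', exit_code)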
