diff --git a/tests/memlet_propagation_test.py b/tests/memlet_propagation_test.py index f90834cbb7..efd1cde9d1 100644 --- a/tests/memlet_propagation_test.py +++ b/tests/memlet_propagation_test.py @@ -103,9 +103,37 @@ def sparse(A: dace.float32[M, N], ind: dace.int32[M, N]): raise RuntimeError('Expected subset of outer out memlet to be [0:M, 0:N], found ' + str(outer_out.subset)) +def test_nested_conditional_in_loop_in_map(): + N = dace.symbol('N') + M = dace.symbol('M') + + @dace.program + def nested_conditional_in_loop_in_map(A: dace.float64[M, N]): + for i in dace.map[0:M]: + for j in range(2, N, 1): + if A[0][0]: + A[i, j] = 1 + else: + A[i, j] = 2 + A[i, j] = A[i, j] * A[i, j] + + sdfg = nested_conditional_in_loop_in_map.to_sdfg(simplify=True) + + N = 20 + M = 20 + a_test = np.zeros((M, N), dtype=np.float64) + sdfg(a_test, M=M, N=N) + a_valid = np.zeros((M, N), dtype=np.float64) + for i in range(M): + for j in range(2, N, 1): + a_valid[i, j] = 4.0 + + assert np.allclose(a_test, a_valid) + if __name__ == '__main__': test_conditional() test_conditional_nested() test_runtime_conditional() test_nsdfg_memlet_propagation_with_one_sparse_dimension() + test_nested_conditional_in_loop_in_map()