diff --git a/test/test_transform.py b/test/test_transform.py index 8060d2038..4ad1971fb 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -1704,6 +1704,35 @@ def test_precompute_lets_inner_length1_inames_live(): == parse("(e_inner + e_outer*16) / i_0")) +def test_duplicate_iname_not_read_only_nested(ctx_factory): + # See + ctx = ctx_factory() + t_unit = lp.make_kernel( + "{[i, j]: 0<=i,j<10}", + """ + for i + <> acc = 0 {id=init, tags=foo} + for j + acc = acc + A[i, j] * x[i, j] {id=update, tags=foo} + end + y[i] = acc {id=assign, tags=foo} + end + """, + [lp.GlobalArg("A,x,y", shape=lp.auto, dtype=np.float32), + ...], + ) + ref_t_unit = t_unit + + t_unit = lp.duplicate_inames( + t_unit, + inames="i", within="tag:foo", new_inames="irow") + print(t_unit) + assert (t_unit.default_entrypoint.id_to_insn["init"].within_inames + == frozenset({"irow"})) + + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])