diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index efa3e7d9..14238278 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -357,12 +357,14 @@ def can_fuse_multiple_primitive_ops( if is_fuse_candidate(primitive_op) and all( is_fuse_candidate(p) for p in predecessor_primitive_ops ): - # if the peak projected memory for running all the predecessor ops in order is - # larger than allowed_mem then we can't fuse + # If the peak projected memory for running all the predecessor ops in + # order is larger than allowed_mem then we can't fuse. if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem: return False - # if the number of input blocks for each input is not uniform, then we can't fuse - # (this should never happen since all operations are currently uniform) + # If the number of input blocks for each input is not uniform, then we + # can't fuse. (This should never happen since all operations are + # currently uniform, and fused operations are too if fuse is applied in + # topological order.) num_input_blocks = primitive_op.pipeline.config.num_input_blocks if not all(num_input_blocks[0] == n for n in num_input_blocks): return False diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index a9096b14..4cf69b01 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -233,14 +233,10 @@ def test_fuse_unary_op(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(c.plan.dag, c.name) == (1,) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (1,) + assert get_num_input_blocks(optimized_dag, c.name) == (1,) num_created_arrays = 2 # b, c assert c.plan.num_tasks(optimize_graph=False) == num_created_arrays + 2 @@ -278,13 +274,10 @@ def test_fuse_binary_op(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (a, b), (e,)) - assert structurally_equivalent( - e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) - assert get_num_input_blocks( - e.plan.optimize(optimize_function=opt_fn).dag, e.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, e.name) == (1, 1) num_created_arrays = 3 # c, d, e assert e.plan.num_tasks(optimize_graph=False) == num_created_arrays + 3 @@ -324,13 +317,10 @@ def test_fuse_unary_and_binary_op(spec): add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (), (c,)) add_placeholder_op(expected_fused_dag, (a, b, c), (f,)) - assert structurally_equivalent( - f.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = f.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(f.plan.dag, f.name) == (1, 1) - assert get_num_input_blocks( - f.plan.optimize(optimize_function=opt_fn).dag, f.name - ) == (1, 1, 1) + assert get_num_input_blocks(optimized_dag, f.name) == (1, 1, 1) result = f.compute(optimize_function=opt_fn) assert_array_equal(result, np.ones((2, 2))) @@ -361,13 +351,10 @@ def test_fuse_mixed_levels(spec): add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (), (c,)) add_placeholder_op(expected_fused_dag, (a, b, c), (e,)) - assert structurally_equivalent( - e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) - assert get_num_input_blocks( - e.plan.optimize(optimize_function=opt_fn).dag, e.name - ) == (1, 1, 1) + assert get_num_input_blocks(optimized_dag, e.name) == (1, 1, 1) result = e.compute(optimize_function=opt_fn) assert_array_equal(result, 3 * np.ones((2, 2))) @@ -395,13 +382,10 @@ def test_fuse_diamond(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a, a), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1) - assert get_num_input_blocks( - d.plan.optimize(optimize_function=opt_fn).dag, d.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, d.name) == (1, 1) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((2, 2))) @@ -433,13 +417,10 @@ def test_fuse_mixed_levels_and_diamond(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (b,)) add_placeholder_op(expected_fused_dag, (a, b), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1) - assert get_num_input_blocks( - d.plan.optimize(optimize_function=opt_fn).dag, d.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, d.name) == (1, 1) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((2, 2))) @@ -467,13 +448,10 @@ def test_fuse_repeated_argument(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a, a), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, c.name) == (1, 1) result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -2 * np.ones((2, 2))) @@ -506,13 +484,10 @@ def test_fuse_other_dependents(spec): add_placeholder_op(expected_fused_dag, (a,), (c,)) add_placeholder_op(expected_fused_dag, (b,), (d,)) plan = arrays_to_plan(c, d) - assert structurally_equivalent( - plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(c.plan.dag, c.name) == (1,) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (1,) + assert get_num_input_blocks(optimized_dag, c.name) == (1,) c_result, d_result = cubed.compute(c, d, optimize_function=opt_fn) assert_array_equal(c_result, np.ones((2, 2))) @@ -576,14 +551,10 @@ def stack_add(*a): ), (j,), ) - assert structurally_equivalent( - j.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = j.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(j.plan.dag, j.name) == (1,) - assert ( - get_num_input_blocks(j.plan.optimize(optimize_function=opt_fn).dag, j.name) - == (1,) * 8 - ) + assert get_num_input_blocks(optimized_dag, j.name) == (1,) * 8 result = j.compute(optimize_function=opt_fn) assert_array_equal(result, -8 * np.ones((2, 2))) @@ -639,13 +610,10 @@ def test_fuse_large_fan_in_default(spec): add_placeholder_op(expected_fused_dag, (a, b, c, d), (n,)) add_placeholder_op(expected_fused_dag, (e, f, g, h), (o,)) add_placeholder_op(expected_fused_dag, (n, o), (p,)) - assert structurally_equivalent( - p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1) - assert get_num_input_blocks( - p.plan.optimize(optimize_function=opt_fn).dag, p.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, p.name) == (1, 1) result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) @@ -712,14 +680,10 @@ def test_fuse_large_fan_in_override(spec): ), (p,), ) - assert structurally_equivalent( - p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1) - assert ( - get_num_input_blocks(p.plan.optimize(optimize_function=opt_fn).dag, p.name) - == (1,) * 8 - ) + assert get_num_input_blocks(optimized_dag, p.name) == (1,) * 8 result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) @@ -727,21 +691,19 @@ def test_fuse_large_fan_in_override(spec): # now force everything to be fused with fuse_all_optimize_dag # note that max_total_source_arrays is *not* set opt_fn = fuse_all_optimize_dag - - assert structurally_equivalent( - p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) -# merge chunks with same number of tasks +# merge chunks with same number of tasks (unary) # # a -> a -# | | +# | 3 | 3 # b c -# | +# | 1 # c # def test_fuse_with_merge_chunks_unary(spec): @@ -757,26 +719,55 @@ def test_fuse_with_merge_chunks_unary(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(b.plan.dag, b.name) == (3,) assert get_num_input_blocks(c.plan.dag, c.name) == (1,) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (3,) + assert get_num_input_blocks(optimized_dag, c.name) == (3,) result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((3, 2))) +# merge chunks with same number of tasks (binary) +# +# a b -> a b +# 3 | | 1 3 \ / 1 +# c d e +# 1 \ / 1 +# e +# +def test_fuse_with_merge_chunks_binary(spec): + a = xp.ones((3, 2), chunks=(1, 2), spec=spec) + b = xp.ones((3, 2), chunks=(3, 2), spec=spec) + c = merge_chunks_new(a, chunks=(3, 2)) + d = xp.negative(b) + e = xp.add(c, d) + + opt_fn = fuse_one_level(e) + + e.visualize(optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + add_placeholder_op(expected_fused_dag, (a, b), (e,)) + optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, e.name) == (3, 1) + + result = e.compute(optimize_function=opt_fn) + assert_array_equal(result, np.zeros((3, 2))) + + # merge chunks with different number of tasks (b has more tasks than c) # # a -> a -# | | +# | 1 | 3 # b c -# | +# | 3 # c # def test_fuse_merge_chunks_unary(spec): @@ -794,15 +785,11 @@ def test_fuse_merge_chunks_unary(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(b.plan.dag, b.name) == (1,) assert get_num_input_blocks(c.plan.dag, c.name) == (3,) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (3,) + assert get_num_input_blocks(optimized_dag, c.name) == (3,) result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((3, 2))) @@ -811,9 +798,9 @@ def test_fuse_merge_chunks_unary(spec): # merge chunks with different number of tasks (c has more tasks than d) # # a b -> a b -# \ / \ / +# 1 \ / 1 3 \ / 3 # c d -# | +# | 3 # d # def test_fuse_merge_chunks_binary(spec): @@ -833,14 +820,11 @@ def test_fuse_merge_chunks_binary(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (a, b), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1) assert get_num_input_blocks(d.plan.dag, d.name) == (3,) - assert get_num_input_blocks( - d.plan.optimize(optimize_function=opt_fn).dag, d.name - ) == (3, 3) + assert get_num_input_blocks(optimized_dag, d.name) == (3, 3) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((3, 2))) @@ -864,14 +848,10 @@ def test_fuse_only_optimize_dag(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (b,)) add_placeholder_op(expected_fused_dag, (b,), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(d.plan.dag, d.name) == (1,) - assert get_num_input_blocks( - d.plan.optimize(optimize_function=opt_fn).dag, d.name - ) == (1,) + assert get_num_input_blocks(optimized_dag, d.name) == (1,) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((2, 2)))