Skip to content

Commit

Permalink
Add another test, and improve code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Feb 1, 2024
1 parent 263d89c commit e54aa1b
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 111 deletions.
10 changes: 6 additions & 4 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,14 @@ def can_fuse_multiple_primitive_ops(
if is_fuse_candidate(primitive_op) and all(
is_fuse_candidate(p) for p in predecessor_primitive_ops
):
# if the peak projected memory for running all the predecessor ops in order is
# larger than allowed_mem then we can't fuse
# If the peak projected memory for running all the predecessor ops in
# order is larger than allowed_mem then we can't fuse.
if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem:
return False
# if the number of input blocks for each input is not uniform, then we can't fuse
# (this should never happen since all operations are currently uniform)
# If the number of input blocks for each input is not uniform, then we
# can't fuse. (This should never happen since all operations are
# currently uniform, and fused operations are too if fuse is applied in
# topological order.)
num_input_blocks = primitive_op.pipeline.config.num_input_blocks
if not all(num_input_blocks[0] == n for n in num_input_blocks):
return False
Expand Down
194 changes: 87 additions & 107 deletions cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,10 @@ def test_fuse_unary_op(spec):
expected_fused_dag = create_dag()
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (a,), (c,))
assert structurally_equivalent(
c.plan.optimize(optimize_function=opt_fn).dag,
expected_fused_dag,
)
optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(c.plan.dag, c.name) == (1,)
assert get_num_input_blocks(
c.plan.optimize(optimize_function=opt_fn).dag, c.name
) == (1,)
assert get_num_input_blocks(optimized_dag, c.name) == (1,)

num_created_arrays = 2 # b, c
assert c.plan.num_tasks(optimize_graph=False) == num_created_arrays + 2
Expand Down Expand Up @@ -278,13 +274,10 @@ def test_fuse_binary_op(spec):
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (), (b,))
add_placeholder_op(expected_fused_dag, (a, b), (e,))
assert structurally_equivalent(
e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1)
assert get_num_input_blocks(
e.plan.optimize(optimize_function=opt_fn).dag, e.name
) == (1, 1)
assert get_num_input_blocks(optimized_dag, e.name) == (1, 1)

num_created_arrays = 3 # c, d, e
assert e.plan.num_tasks(optimize_graph=False) == num_created_arrays + 3
Expand Down Expand Up @@ -324,13 +317,10 @@ def test_fuse_unary_and_binary_op(spec):
add_placeholder_op(expected_fused_dag, (), (b,))
add_placeholder_op(expected_fused_dag, (), (c,))
add_placeholder_op(expected_fused_dag, (a, b, c), (f,))
assert structurally_equivalent(
f.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = f.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(f.plan.dag, f.name) == (1, 1)
assert get_num_input_blocks(
f.plan.optimize(optimize_function=opt_fn).dag, f.name
) == (1, 1, 1)
assert get_num_input_blocks(optimized_dag, f.name) == (1, 1, 1)

result = f.compute(optimize_function=opt_fn)
assert_array_equal(result, np.ones((2, 2)))
Expand Down Expand Up @@ -361,13 +351,10 @@ def test_fuse_mixed_levels(spec):
add_placeholder_op(expected_fused_dag, (), (b,))
add_placeholder_op(expected_fused_dag, (), (c,))
add_placeholder_op(expected_fused_dag, (a, b, c), (e,))
assert structurally_equivalent(
e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1)
assert get_num_input_blocks(
e.plan.optimize(optimize_function=opt_fn).dag, e.name
) == (1, 1, 1)
assert get_num_input_blocks(optimized_dag, e.name) == (1, 1, 1)

result = e.compute(optimize_function=opt_fn)
assert_array_equal(result, 3 * np.ones((2, 2)))
Expand Down Expand Up @@ -395,13 +382,10 @@ def test_fuse_diamond(spec):
expected_fused_dag = create_dag()
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (a, a), (d,))
assert structurally_equivalent(
d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1)
assert get_num_input_blocks(
d.plan.optimize(optimize_function=opt_fn).dag, d.name
) == (1, 1)
assert get_num_input_blocks(optimized_dag, d.name) == (1, 1)

result = d.compute(optimize_function=opt_fn)
assert_array_equal(result, 2 * np.ones((2, 2)))
Expand Down Expand Up @@ -433,13 +417,10 @@ def test_fuse_mixed_levels_and_diamond(spec):
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (a,), (b,))
add_placeholder_op(expected_fused_dag, (a, b), (d,))
assert structurally_equivalent(
d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1)
assert get_num_input_blocks(
d.plan.optimize(optimize_function=opt_fn).dag, d.name
) == (1, 1)
assert get_num_input_blocks(optimized_dag, d.name) == (1, 1)

result = d.compute(optimize_function=opt_fn)
assert_array_equal(result, 2 * np.ones((2, 2)))
Expand Down Expand Up @@ -467,13 +448,10 @@ def test_fuse_repeated_argument(spec):
expected_fused_dag = create_dag()
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (a, a), (c,))
assert structurally_equivalent(
c.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1)
assert get_num_input_blocks(
c.plan.optimize(optimize_function=opt_fn).dag, c.name
) == (1, 1)
assert get_num_input_blocks(optimized_dag, c.name) == (1, 1)

result = c.compute(optimize_function=opt_fn)
assert_array_equal(result, -2 * np.ones((2, 2)))
Expand Down Expand Up @@ -506,13 +484,10 @@ def test_fuse_other_dependents(spec):
add_placeholder_op(expected_fused_dag, (a,), (c,))
add_placeholder_op(expected_fused_dag, (b,), (d,))
plan = arrays_to_plan(c, d)
assert structurally_equivalent(
plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(c.plan.dag, c.name) == (1,)
assert get_num_input_blocks(
c.plan.optimize(optimize_function=opt_fn).dag, c.name
) == (1,)
assert get_num_input_blocks(optimized_dag, c.name) == (1,)

c_result, d_result = cubed.compute(c, d, optimize_function=opt_fn)
assert_array_equal(c_result, np.ones((2, 2)))
Expand Down Expand Up @@ -576,14 +551,10 @@ def stack_add(*a):
),
(j,),
)
assert structurally_equivalent(
j.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = j.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(j.plan.dag, j.name) == (1,)
assert (
get_num_input_blocks(j.plan.optimize(optimize_function=opt_fn).dag, j.name)
== (1,) * 8
)
assert get_num_input_blocks(optimized_dag, j.name) == (1,) * 8

result = j.compute(optimize_function=opt_fn)
assert_array_equal(result, -8 * np.ones((2, 2)))
Expand Down Expand Up @@ -639,13 +610,10 @@ def test_fuse_large_fan_in_default(spec):
add_placeholder_op(expected_fused_dag, (a, b, c, d), (n,))
add_placeholder_op(expected_fused_dag, (e, f, g, h), (o,))
add_placeholder_op(expected_fused_dag, (n, o), (p,))
assert structurally_equivalent(
p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1)
assert get_num_input_blocks(
p.plan.optimize(optimize_function=opt_fn).dag, p.name
) == (1, 1)
assert get_num_input_blocks(optimized_dag, p.name) == (1, 1)

result = p.compute(optimize_function=opt_fn)
assert_array_equal(result, 8 * np.ones((2, 2)))
Expand Down Expand Up @@ -712,36 +680,30 @@ def test_fuse_large_fan_in_override(spec):
),
(p,),
)
assert structurally_equivalent(
p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1)
assert (
get_num_input_blocks(p.plan.optimize(optimize_function=opt_fn).dag, p.name)
== (1,) * 8
)
assert get_num_input_blocks(optimized_dag, p.name) == (1,) * 8

result = p.compute(optimize_function=opt_fn)
assert_array_equal(result, 8 * np.ones((2, 2)))

# now force everything to be fused with fuse_all_optimize_dag
# note that max_total_source_arrays is *not* set
opt_fn = fuse_all_optimize_dag

assert structurally_equivalent(
p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)

result = p.compute(optimize_function=opt_fn)
assert_array_equal(result, 8 * np.ones((2, 2)))


# merge chunks with same number of tasks
# merge chunks with same number of tasks (unary)
#
# a -> a
# | |
# | 3 | 3
# b c
# |
# | 1
# c
#
def test_fuse_with_merge_chunks_unary(spec):
Expand All @@ -757,26 +719,55 @@ def test_fuse_with_merge_chunks_unary(spec):
expected_fused_dag = create_dag()
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (a,), (c,))
assert structurally_equivalent(
c.plan.optimize(optimize_function=opt_fn).dag,
expected_fused_dag,
)
optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(b.plan.dag, b.name) == (3,)
assert get_num_input_blocks(c.plan.dag, c.name) == (1,)
assert get_num_input_blocks(
c.plan.optimize(optimize_function=opt_fn).dag, c.name
) == (3,)
assert get_num_input_blocks(optimized_dag, c.name) == (3,)

result = c.compute(optimize_function=opt_fn)
assert_array_equal(result, -np.ones((3, 2)))


# merge chunks with same number of tasks (binary)
#
# a b -> a b
# 3 | | 1 3 \ / 1
# c d e
# 1 \ / 1
# e
#
def test_fuse_with_merge_chunks_binary(spec):
a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
b = xp.ones((3, 2), chunks=(3, 2), spec=spec)
c = merge_chunks_new(a, chunks=(3, 2))
d = xp.negative(b)
e = xp.add(c, d)

opt_fn = fuse_one_level(e)

e.visualize(optimize_function=opt_fn)

# check structure of optimized dag
expected_fused_dag = create_dag()
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (), (b,))
add_placeholder_op(expected_fused_dag, (a, b), (e,))
optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1)
assert get_num_input_blocks(optimized_dag, e.name) == (3, 1)

result = e.compute(optimize_function=opt_fn)
assert_array_equal(result, np.zeros((3, 2)))


# merge chunks with different number of tasks (b has more tasks than c)
#
# a -> a
# | |
# | 1 | 3
# b c
# |
# | 3
# c
#
def test_fuse_merge_chunks_unary(spec):
Expand All @@ -794,15 +785,11 @@ def test_fuse_merge_chunks_unary(spec):
expected_fused_dag = create_dag()
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (a,), (c,))
assert structurally_equivalent(
c.plan.optimize(optimize_function=opt_fn).dag,
expected_fused_dag,
)
optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(b.plan.dag, b.name) == (1,)
assert get_num_input_blocks(c.plan.dag, c.name) == (3,)
assert get_num_input_blocks(
c.plan.optimize(optimize_function=opt_fn).dag, c.name
) == (3,)
assert get_num_input_blocks(optimized_dag, c.name) == (3,)

result = c.compute(optimize_function=opt_fn)
assert_array_equal(result, -np.ones((3, 2)))
Expand All @@ -811,9 +798,9 @@ def test_fuse_merge_chunks_unary(spec):
# merge chunks with different number of tasks (c has more tasks than d)
#
# a b -> a b
# \ / \ /
# 1 \ / 1 3 \ / 3
# c d
# |
# | 3
# d
#
def test_fuse_merge_chunks_binary(spec):
Expand All @@ -833,14 +820,11 @@ def test_fuse_merge_chunks_binary(spec):
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (), (b,))
add_placeholder_op(expected_fused_dag, (a, b), (d,))
assert structurally_equivalent(
d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
)
optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1)
assert get_num_input_blocks(d.plan.dag, d.name) == (3,)
assert get_num_input_blocks(
d.plan.optimize(optimize_function=opt_fn).dag, d.name
) == (3, 3)
assert get_num_input_blocks(optimized_dag, d.name) == (3, 3)

result = d.compute(optimize_function=opt_fn)
assert_array_equal(result, 2 * np.ones((3, 2)))
Expand All @@ -864,14 +848,10 @@ def test_fuse_only_optimize_dag(spec):
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (a,), (b,))
add_placeholder_op(expected_fused_dag, (b,), (d,))
assert structurally_equivalent(
d.plan.optimize(optimize_function=opt_fn).dag,
expected_fused_dag,
)
optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(d.plan.dag, d.name) == (1,)
assert get_num_input_blocks(
d.plan.optimize(optimize_function=opt_fn).dag, d.name
) == (1,)
assert get_num_input_blocks(optimized_dag, d.name) == (1,)

result = d.compute(optimize_function=opt_fn)
assert_array_equal(result, -np.ones((2, 2)))

0 comments on commit e54aa1b

Please sign in to comment.