diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 9c4e4c43..f9053cc8 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -744,6 +744,8 @@ def rechunk(x, chunks, target_store=None): reserved_mem=spec.reserved_mem, target_store=target_store, temp_store=temp_store, + source_array_name=name, + int_array_name=name_int, ) from cubed.array_api import Array diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index cfa261b9..7b024d83 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -81,17 +81,22 @@ def gensym(name="op"): return f"{name}-{sym_counter:03}" -def predecessors(dag, name): - """Return a node's predecessors, with repeats for multiple edges.""" +def predecessors_unordered(dag, name): + """Return a node's predecessors in no particular order, with repeats for multiple edges.""" for pre, _ in dag.in_edges(name): yield pre def predecessor_ops(dag, name): - """Return an op node's op predecessors""" - for input in predecessors(dag, name): - for pre in predecessors(dag, input): - yield pre + """Return an op node's op predecessors in the same order as the input source arrays for the op. + + Note that each input source array is produced by a single predecessor op. + """ + nodes = dict(dag.nodes(data=True)) + for input in nodes[name]["primitive_op"].source_array_names: + pre_list = list(predecessors_unordered(dag, input)) + assert len(pre_list) == 1 # each array is produced by a single op + yield pre_list[0] def is_fusable(node_dict): @@ -135,7 +140,7 @@ def can_fuse_predecessors( # the fused function would be more than an allowed maximum, then don't fuse if len(list(predecessor_ops(dag, name))) > 1: total_source_arrays = sum( - len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1 + len(list(predecessors_unordered(dag, pre))) if is_fusable(nodes[pre]) else 1 for pre in predecessor_ops(dag, name) ) if total_source_arrays > max_total_source_arrays: @@ -203,8 +208,8 @@ def fuse_predecessors( # re-wire dag to remove predecessor nodes that have been fused # 1. update edges to change inputs - for input in predecessors(dag, name): - pre = next(predecessors(dag, input)) + for input in predecessors_unordered(dag, name): + pre = next(predecessors_unordered(dag, input)) if not is_fusable(fused_nodes[pre]): # if a predecessor is not fusable then don't change the edge continue @@ -213,14 +218,14 @@ def fuse_predecessors( if not is_fusable(fused_nodes[pre]): # if a predecessor is not fusable then don't change the edge continue - for input in predecessors(dag, pre): + for input in predecessors_unordered(dag, pre): fused_dag.add_edge(input, name) # 2. remove predecessor nodes with no successors # (ones with successors are needed by other nodes) - for input in predecessors(dag, name): + for input in predecessors_unordered(dag, name): if fused_dag.out_degree(input) == 0: - for pre in list(predecessors(fused_dag, input)): + for pre in list(predecessors_unordered(fused_dag, input)): fused_dag.remove_node(pre) fused_dag.remove_node(input) diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 07a1e85c..502bf9b9 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -462,6 +462,7 @@ def create_zarr_arrays(lazy_zarr_arrays, allowed_mem, reserved_mem): ) return PrimitiveOperation( pipeline=pipeline, + source_array_names=None, target_array=None, projected_mem=projected_mem, allowed_mem=allowed_mem, diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index e2354ba1..c80ee7c5 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -331,6 +331,7 @@ def general_blockwise( ) return PrimitiveOperation( pipeline=pipeline, + source_array_names=array_names, target_array=target_array, projected_mem=projected_mem, allowed_mem=allowed_mem, @@ -483,6 +484,7 @@ def fused_func(*args): write_proxy, ) + source_array_names = primitive_op1.source_array_names target_array = primitive_op2.target_array projected_mem = max(primitive_op1.projected_mem, primitive_op2.projected_mem) allowed_mem = primitive_op2.allowed_mem @@ -497,6 +499,7 @@ def fused_func(*args): ) return PrimitiveOperation( pipeline=pipeline, + source_array_names=source_array_names, target_array=target_array, projected_mem=projected_mem, allowed_mem=allowed_mem, @@ -607,6 +610,12 @@ def fused_func(*args): write_proxy, ) + source_array_names = [] + for i, p in enumerate(predecessor_primitive_ops): + if p is None: + source_array_names.append(primitive_op.source_array_names[i]) + else: + source_array_names.extend(p.source_array_names) target_array = primitive_op.target_array projected_mem = max( primitive_op.projected_mem, @@ -624,6 +633,7 @@ def fused_func(*args): ) return PrimitiveOperation( pipeline=fused_pipeline, + source_array_names=source_array_names, target_array=target_array, projected_mem=projected_mem, allowed_mem=allowed_mem, diff --git a/cubed/primitive/rechunk.py b/cubed/primitive/rechunk.py index 4bd5f44d..b22c97d2 100644 --- a/cubed/primitive/rechunk.py +++ b/cubed/primitive/rechunk.py @@ -27,6 +27,8 @@ def rechunk( reserved_mem: int, target_store: T_Store, temp_store: Optional[T_Store] = None, + source_array_name: Optional[str] = None, + int_array_name: Optional[str] = None, ) -> List[PrimitiveOperation]: """Change the chunking of an array, without changing its shape or dtype. @@ -72,7 +74,13 @@ def rechunk( num_tasks = total_chunks(write_proxy.array.shape, write_proxy.chunks) return [ spec_to_primitive_op( - copy_spec, target, projected_mem, allowed_mem, reserved_mem, num_tasks + copy_spec, + source_array_name, + target, + projected_mem, + allowed_mem, + reserved_mem, + num_tasks, ) ] @@ -82,6 +90,7 @@ def rechunk( num_tasks = total_chunks(copy_spec1.write.array.shape, copy_spec1.write.chunks) op1 = spec_to_primitive_op( copy_spec1, + source_array_name, intermediate, projected_mem, allowed_mem, @@ -92,7 +101,13 @@ def rechunk( copy_spec2 = CubedCopySpec(int_proxy, write_proxy) num_tasks = total_chunks(copy_spec2.write.array.shape, copy_spec2.write.chunks) op2 = spec_to_primitive_op( - copy_spec2, target, projected_mem, allowed_mem, reserved_mem, num_tasks + copy_spec2, + int_array_name, + target, + projected_mem, + allowed_mem, + reserved_mem, + num_tasks, ) return [op1, op2] @@ -184,6 +199,7 @@ def copy_read_to_write(chunk_key: Sequence[slice], *, config: CubedCopySpec) -> def spec_to_primitive_op( spec: CubedCopySpec, + source_array_name: Optional[str], target_array: Any, projected_mem: int, allowed_mem: int, @@ -200,6 +216,7 @@ def spec_to_primitive_op( ) return PrimitiveOperation( pipeline=pipeline, + source_array_names=[source_array_name], target_array=target_array, projected_mem=projected_mem, allowed_mem=allowed_mem, diff --git a/cubed/primitive/types.py b/cubed/primitive/types.py index c9a3316c..b2e94827 100644 --- a/cubed/primitive/types.py +++ b/cubed/primitive/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, List, Optional import zarr @@ -15,6 +15,9 @@ class PrimitiveOperation: pipeline: CubedPipeline """The pipeline that runs this operation.""" + source_array_names: Optional[List[str]] + """The names of the arrays which are inputs to this operation.""" + target_array: Any """The array being computed by this operation.""" diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index 3f8e9fc7..5611e447 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -192,7 +192,8 @@ def add_placeholder_op(dag, inputs, outputs): def structurally_equivalent(dag1, dag2): # compare structure, and node labels for values but not operators since they are placeholders - # draw_dag(dag1) # uncomment for debugging + # draw_dag(dag1, "dag1") # uncomment for debugging + # draw_dag(dag2, "dag2") # uncomment for debugging labelled_dag1 = nx.convert_node_labels_to_integers(dag1, label_attribute="label") labelled_dag2 = nx.convert_node_labels_to_integers(dag2, label_attribute="label") @@ -209,6 +210,10 @@ def nm(node_attrs1, node_attrs2): def draw_dag(dag, name="dag"): + dag = dag.copy() + for _, d in dag.nodes(data=True): + if "name" in d: # pydot already has name + del d["name"] gv = nx.drawing.nx_pydot.to_pydot(dag) format = "svg" full_filename = f"{name}.{format}" @@ -429,6 +434,31 @@ def test_fuse_mixed_levels_and_diamond(spec): assert_array_equal(result, 2 * np.ones((2, 2))) +# derived from a bug found by array_api_tests/test_manipulation_functions.py::test_expand_dims +# a b -> a b +# \ / |\ /| +# c | d | +# /| | | | +# d | | e | +# | | \|/ +# e | f +# \| +# f +def test_fuse_mixed_levels_and_diamond_complex(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.ones((2, 2), chunks=(2, 2), spec=spec) + c = xp.add(a, b) + d = xp.positive(c) + e = d[1:, :] # operation can't be fused + f = xp.add(e, c) # this order exposed a bug in argument ordering + + opt_fn = multiple_inputs_optimize_dag + + f.visualize(optimize_function=opt_fn) + result = f.compute(optimize_function=opt_fn) + assert_array_equal(result, 4 * np.ones((2, 2))) + + # repeated argument # from https://github.com/cubed-dev/cubed/issues/65 #