From 5606da8592d4beb5682dc50670373731d4e721ac Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 8 Mar 2024 09:01:35 +0000 Subject: [PATCH] Track source array names in PrimitiveOperation to fix a bug with argument ordering The bug was due to a mismatch between the order of the input source arrays in the DAG and the function argument order for the operation. Since the DAG order is not stable, the operation now records the names of the input source arrays in the source_array_names variable. --- cubed/core/ops.py | 2 ++ cubed/core/optimization.py | 29 ++++++++++++++---------- cubed/core/plan.py | 1 + cubed/primitive/blockwise.py | 10 +++++++++ cubed/primitive/rechunk.py | 22 ++++++++++++++++-- cubed/primitive/types.py | 5 ++++- cubed/tests/primitive/test_rechunk.py | 4 ++++ cubed/tests/test_optimization.py | 32 ++++++++++++++++++++++++++- 8 files changed, 89 insertions(+), 16 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 9c4e4c43..18a4d56e 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -739,6 +739,8 @@ def rechunk(x, chunks, target_store=None): temp_store = new_temp_path(name=name_int, spec=spec) ops = primitive_rechunk( x.zarray_maybe_lazy, + source_array_name=name, + int_array_name=name_int, target_chunks=target_chunks, allowed_mem=spec.allowed_mem, reserved_mem=spec.reserved_mem, diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index cfa261b9..7b024d83 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -81,17 +81,22 @@ def gensym(name="op"): return f"{name}-{sym_counter:03}" -def predecessors(dag, name): - """Return a node's predecessors, with repeats for multiple edges.""" +def predecessors_unordered(dag, name): + """Return a node's predecessors in no particular order, with repeats for multiple edges.""" for pre, _ in dag.in_edges(name): yield pre def predecessor_ops(dag, name): - """Return an op node's op predecessors""" - for input in predecessors(dag, name): - for pre in predecessors(dag, input): - yield pre + """Return an op node's op predecessors in the same order as the input source arrays for the op. + + Note that each input source array is produced by a single predecessor op. + """ + nodes = dict(dag.nodes(data=True)) + for input in nodes[name]["primitive_op"].source_array_names: + pre_list = list(predecessors_unordered(dag, input)) + assert len(pre_list) == 1 # each array is produced by a single op + yield pre_list[0] def is_fusable(node_dict): @@ -135,7 +140,7 @@ def can_fuse_predecessors( # the fused function would be more than an allowed maximum, then don't fuse if len(list(predecessor_ops(dag, name))) > 1: total_source_arrays = sum( - len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1 + len(list(predecessors_unordered(dag, pre))) if is_fusable(nodes[pre]) else 1 for pre in predecessor_ops(dag, name) ) if total_source_arrays > max_total_source_arrays: @@ -203,8 +208,8 @@ def fuse_predecessors( # re-wire dag to remove predecessor nodes that have been fused # 1. update edges to change inputs - for input in predecessors(dag, name): - pre = next(predecessors(dag, input)) + for input in predecessors_unordered(dag, name): + pre = next(predecessors_unordered(dag, input)) if not is_fusable(fused_nodes[pre]): # if a predecessor is not fusable then don't change the edge continue @@ -213,14 +218,14 @@ def fuse_predecessors( if not is_fusable(fused_nodes[pre]): # if a predecessor is not fusable then don't change the edge continue - for input in predecessors(dag, pre): + for input in predecessors_unordered(dag, pre): fused_dag.add_edge(input, name) # 2. remove predecessor nodes with no successors # (ones with successors are needed by other nodes) - for input in predecessors(dag, name): + for input in predecessors_unordered(dag, name): if fused_dag.out_degree(input) == 0: - for pre in list(predecessors(fused_dag, input)): + for pre in list(predecessors_unordered(fused_dag, input)): fused_dag.remove_node(pre) fused_dag.remove_node(input) diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 07a1e85c..69ae70e4 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -462,6 +462,7 @@ def create_zarr_arrays(lazy_zarr_arrays, allowed_mem, reserved_mem): ) return PrimitiveOperation( pipeline=pipeline, + source_array_names=[], target_array=None, projected_mem=projected_mem, allowed_mem=allowed_mem, diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index e2354ba1..c80ee7c5 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -331,6 +331,7 @@ def general_blockwise( ) return PrimitiveOperation( pipeline=pipeline, + source_array_names=array_names, target_array=target_array, projected_mem=projected_mem, allowed_mem=allowed_mem, @@ -483,6 +484,7 @@ def fused_func(*args): write_proxy, ) + source_array_names = primitive_op1.source_array_names target_array = primitive_op2.target_array projected_mem = max(primitive_op1.projected_mem, primitive_op2.projected_mem) allowed_mem = primitive_op2.allowed_mem @@ -497,6 +499,7 @@ def fused_func(*args): ) return PrimitiveOperation( pipeline=pipeline, + source_array_names=source_array_names, target_array=target_array, projected_mem=projected_mem, allowed_mem=allowed_mem, @@ -607,6 +610,12 @@ def fused_func(*args): write_proxy, ) + source_array_names = [] + for i, p in enumerate(predecessor_primitive_ops): + if p is None: + source_array_names.append(primitive_op.source_array_names[i]) + else: + source_array_names.extend(p.source_array_names) target_array = primitive_op.target_array projected_mem = max( primitive_op.projected_mem, @@ -624,6 +633,7 @@ def fused_func(*args): ) return PrimitiveOperation( pipeline=fused_pipeline, + source_array_names=source_array_names, target_array=target_array, projected_mem=projected_mem, allowed_mem=allowed_mem, diff --git a/cubed/primitive/rechunk.py b/cubed/primitive/rechunk.py index 4bd5f44d..70ead917 100644 --- a/cubed/primitive/rechunk.py +++ b/cubed/primitive/rechunk.py @@ -22,6 +22,8 @@ def gensym(name: str) -> str: def rechunk( source: T_ZarrArray, + source_array_name: str, + int_array_name: str, target_chunks: T_RegularChunks, allowed_mem: int, reserved_mem: int, @@ -72,7 +74,13 @@ def rechunk( num_tasks = total_chunks(write_proxy.array.shape, write_proxy.chunks) return [ spec_to_primitive_op( - copy_spec, target, projected_mem, allowed_mem, reserved_mem, num_tasks + copy_spec, + source_array_name, + target, + projected_mem, + allowed_mem, + reserved_mem, + num_tasks, ) ] @@ -82,6 +90,7 @@ def rechunk( num_tasks = total_chunks(copy_spec1.write.array.shape, copy_spec1.write.chunks) op1 = spec_to_primitive_op( copy_spec1, + source_array_name, intermediate, projected_mem, allowed_mem, @@ -92,7 +101,13 @@ def rechunk( copy_spec2 = CubedCopySpec(int_proxy, write_proxy) num_tasks = total_chunks(copy_spec2.write.array.shape, copy_spec2.write.chunks) op2 = spec_to_primitive_op( - copy_spec2, target, projected_mem, allowed_mem, reserved_mem, num_tasks + copy_spec2, + int_array_name, + target, + projected_mem, + allowed_mem, + reserved_mem, + num_tasks, ) return [op1, op2] @@ -184,6 +199,7 @@ def copy_read_to_write(chunk_key: Sequence[slice], *, config: CubedCopySpec) -> def spec_to_primitive_op( spec: CubedCopySpec, + source_array_name: str, target_array: Any, projected_mem: int, allowed_mem: int, @@ -198,8 +214,10 @@ def spec_to_primitive_op( ChunkKeys(shape, spec.write.chunks), spec, ) + source_array_names = [source_array_name] return PrimitiveOperation( pipeline=pipeline, + source_array_names=source_array_names, target_array=target_array, projected_mem=projected_mem, allowed_mem=allowed_mem, diff --git a/cubed/primitive/types.py b/cubed/primitive/types.py index c9a3316c..0bf33f0e 100644 --- a/cubed/primitive/types.py +++ b/cubed/primitive/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, List, Optional import zarr @@ -15,6 +15,9 @@ class PrimitiveOperation: pipeline: CubedPipeline """The pipeline that runs this operation.""" + source_array_names: List[str] + """The names of the arrays which are inputs to this operation.""" + target_array: Any """The array being computed by this operation.""" diff --git a/cubed/tests/primitive/test_rechunk.py b/cubed/tests/primitive/test_rechunk.py index 34d524e1..cf86cec2 100644 --- a/cubed/tests/primitive/test_rechunk.py +++ b/cubed/tests/primitive/test_rechunk.py @@ -65,6 +65,8 @@ def test_rechunk( ops = rechunk( source, + source_array_name="source-array", + int_array_name="int-array", target_chunks=target_chunks, allowed_mem=allowed_mem, reserved_mem=reserved_mem, @@ -109,6 +111,8 @@ def test_rechunk_allowed_mem_exceeded(tmp_path): ): rechunk( source, + source_array_name="source-array", + int_array_name="int-array", target_chunks=(4, 1), allowed_mem=allowed_mem, reserved_mem=0, diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index 3f8e9fc7..5611e447 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -192,7 +192,8 @@ def add_placeholder_op(dag, inputs, outputs): def structurally_equivalent(dag1, dag2): # compare structure, and node labels for values but not operators since they are placeholders - # draw_dag(dag1) # uncomment for debugging + # draw_dag(dag1, "dag1") # uncomment for debugging + # draw_dag(dag2, "dag2") # uncomment for debugging labelled_dag1 = nx.convert_node_labels_to_integers(dag1, label_attribute="label") labelled_dag2 = nx.convert_node_labels_to_integers(dag2, label_attribute="label") @@ -209,6 +210,10 @@ def nm(node_attrs1, node_attrs2): def draw_dag(dag, name="dag"): + dag = dag.copy() + for _, d in dag.nodes(data=True): + if "name" in d: # pydot already has name + del d["name"] gv = nx.drawing.nx_pydot.to_pydot(dag) format = "svg" full_filename = f"{name}.{format}" @@ -429,6 +434,31 @@ def test_fuse_mixed_levels_and_diamond(spec): assert_array_equal(result, 2 * np.ones((2, 2))) +# derived from a bug found by array_api_tests/test_manipulation_functions.py::test_expand_dims +# a b -> a b +# \ / |\ /| +# c | d | +# /| | | | +# d | | e | +# | | \|/ +# e | f +# \| +# f +def test_fuse_mixed_levels_and_diamond_complex(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.ones((2, 2), chunks=(2, 2), spec=spec) + c = xp.add(a, b) + d = xp.positive(c) + e = d[1:, :] # operation can't be fused + f = xp.add(e, c) # this order exposed a bug in argument ordering + + opt_fn = multiple_inputs_optimize_dag + + f.visualize(optimize_function=opt_fn) + result = f.compute(optimize_function=opt_fn) + assert_array_equal(result, 4 * np.ones((2, 2))) + + # repeated argument # from https://github.com/cubed-dev/cubed/issues/65 #