Skip to content

Commit

Permalink
Fix for multiple_inputs_optimize_dag
Browse files Browse the repository at this point in the history
Fix for simple_optimize_dag
  • Loading branch information
tomwhite committed Sep 2, 2024
1 parent 381aa7d commit 5901054
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ def can_fuse(n):
if dag.in_degree(op2) != 1:
return False

# if input is used by another node then don't fuse
# if input is one of the arrays being computed then don't fuse
op2_input = next(dag.predecessors(op2))
if array_names is not None and op2_input in array_names:
return False

# if input is used by another node then don't fuse
if dag.out_degree(op2_input) != 1:
return False

Expand Down Expand Up @@ -143,6 +147,7 @@ def can_fuse_predecessors(
dag,
name,
*,
array_names=None,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
always_fuse=None,
Expand All @@ -164,6 +169,20 @@ def can_fuse_predecessors(
logger.debug("can't fuse %s since no predecessor ops can be fused", name)
return False

# if a predecessor op produces one of the arrays being computed, then don't fuse
if array_names is not None:
predecessor_array_names = set(
array_name for _, array_name, _ in predecessor_ops_and_arrays(dag, name)
)
array_names_intersect = set(array_names) & predecessor_array_names
if len(array_names_intersect) > 0:
logger.debug(
"can't fuse %s since predecessor ops produce one or more arrays being computed %s",
name,
array_names_intersect,
)
return False

# if node is in never_fuse or always_fuse list then it overrides logic below
if never_fuse is not None and name in never_fuse:
logger.debug("can't fuse %s since it is in 'never_fuse'", name)
Expand Down Expand Up @@ -217,6 +236,7 @@ def fuse_predecessors(
if not can_fuse_predecessors(
dag,
name,
array_names=array_names,
max_total_source_arrays=max_total_source_arrays,
max_total_num_input_blocks=max_total_num_input_blocks,
always_fuse=always_fuse,
Expand Down

0 comments on commit 5901054

Please sign in to comment.