Skip to content

Commit

Permalink
try again
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Aug 20, 2024
1 parent 838ba31 commit a11da62
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import dask.config
from awkward.typetracer import touch_data
from dask.blockwise import fuse_roots, optimize_blockwise
from dask.blockwise import Blockwise, fuse_roots, optimize_blockwise
from dask.core import flatten
from dask.highlevelgraph import HighLevelGraph
from dask.local import get_sync
Expand Down Expand Up @@ -333,7 +333,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
# outputs are the outputs of chain[-1]
# .dsk is composed from the .dsk of each layer
outkey = chain[-1]
layer0 = dsk.layers[chain[0]]
layer0 = cast(Blockwise, dsk.layers[chain[0]])
outlayer = layers[outkey]
numblocks = [nb[0] for nb in layer0.numblocks.values() if nb[0] is not None][0]
deps[outkey] = deps[chain[0]] # type: ignore
Expand All @@ -347,7 +347,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
for chain_member in chain[1:]:
layer = dsk.layers[chain_member]
for k in layer.io_deps:
outlayer.io_deps[k] = layer.io_deps[k]
outlayer.io_deps[k] = layer.io_deps[k] # type: ignore
func, *args = layer.dsk[chain_member]
args2 = _recursive_replace(args, layer, parent, indices)
subgraph[chain_member] = (func,) + tuple(args2)
Expand Down

0 comments on commit a11da62

Please sign in to comment.