From a11da62d0d6f834ff4c8a3bed1fc369484a11d86 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 20 Aug 2024 11:08:08 -0400 Subject: [PATCH] try again --- src/dask_awkward/lib/optimize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index d8691195..d6c1c08a 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -8,7 +8,7 @@ import dask.config from awkward.typetracer import touch_data -from dask.blockwise import fuse_roots, optimize_blockwise +from dask.blockwise import Blockwise, fuse_roots, optimize_blockwise from dask.core import flatten from dask.highlevelgraph import HighLevelGraph from dask.local import get_sync @@ -333,7 +333,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG # outputs are the outputs of chain[-1] # .dsk is composed from the .dsk of each layer outkey = chain[-1] - layer0 = dsk.layers[chain[0]] + layer0 = cast(Blockwise, dsk.layers[chain[0]]) outlayer = layers[outkey] numblocks = [nb[0] for nb in layer0.numblocks.values() if nb[0] is not None][0] deps[outkey] = deps[chain[0]] # type: ignore @@ -347,7 +347,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG for chain_member in chain[1:]: layer = dsk.layers[chain_member] for k in layer.io_deps: - outlayer.io_deps[k] = layer.io_deps[k] + outlayer.io_deps[k] = layer.io_deps[k] # type: ignore func, *args = layer.dsk[chain_member] args2 = _recursive_replace(args, layer, parent, indices) subgraph[chain_member] = (func,) + tuple(args2)