Skip to content

Commit

Permalink
skip whole function
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Aug 21, 2024
1 parent 51864e8 commit dbd4964
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import warnings
from collections.abc import Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, cast, no_type_check

import dask.config
from awkward.typetracer import touch_data
Expand Down Expand Up @@ -244,6 +244,7 @@ def _mock_output(layer):
return new_layer


@no_type_check
def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph:
"""Smush chains of blockwise layers into a single layer.
Expand Down Expand Up @@ -336,8 +337,8 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
layer0 = cast(Blockwise, dsk.layers[chain[0]])
outlayer = layers[outkey]
numblocks = [nb[0] for nb in layer0.numblocks.values() if nb[0] is not None][0]
deps[outkey] = deps[chain[0]] # type: ignore
[deps.pop(ch) for ch in chain[:-1]] # type: ignore
deps[outkey] = deps[chain[0]]
[deps.pop(ch) for ch in chain[:-1]]

subgraph = layer0.dsk.copy() # mypy: ignore
indices = list(layer0.indices)
Expand All @@ -347,7 +348,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
for chain_member in chain[1:]:
layer = dsk.layers[chain_member]
for k in layer.io_deps: # mypy: ignore
outlayer.io_deps[k] = layer.io_deps[k] # type: ignore
outlayer.io_deps[k] = layer.io_deps[k]
func, *args = layer.dsk[chain_member] # mypy: ignore
args2 = _recursive_replace(args, layer, parent, indices)
subgraph[chain_member] = (func,) + tuple(args2)
Expand Down

0 comments on commit dbd4964

Please sign in to comment.