Skip to content

Commit

Permalink
refactor: use clearer names
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 authored and douglasdavis committed Jan 18, 2024
1 parent 7c9ef44 commit 1661bbb
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,48 +287,48 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
dependents = dsk.dependents
all_layers = set(dsk.layers)
while all_layers:
lay = all_layers.pop()
val = dsk.layers[lay]
if not isinstance(val, AwkwardBlockwiseLayer):
layer_key = all_layers.pop()
layer = dsk.layers[layer_key]
if not isinstance(layer, AwkwardBlockwiseLayer):
# shortcut to avoid making comparisons
layers[lay] = val # passthrough unchanged
layers[layer_key] = layer # passthrough unchanged
continue
children = dependents[lay]
chain = [lay]
lay0 = lay
children = dependents[layer_key]
chain = [layer_key]
current_layer_key = layer_key
while (
len(children) == 1
and dsk.dependencies[list(children)[0]] == {lay}
and dsk.dependencies[list(children)[0]] == {current_layer_key}
and isinstance(dsk.layers[list(children)[0]], AwkwardBlockwiseLayer)
and len(dsk.layers[lay]) == len(dsk.layers[list(children)[0]])
and lay not in required_layers
and len(dsk.layers[current_layer_key]) == len(dsk.layers[list(children)[0]])
and current_layer_key not in required_layers
):
# walk forwards
lay = list(children)[0]
chain.append(lay)
all_layers.remove(lay)
children = dependents[lay]
lay = lay0
parents = dsk.dependencies[lay]
current_layer_key = list(children)[0]
chain.append(current_layer_key)
all_layers.remove(current_layer_key)
children = dependents[current_layer_key]

parents = dsk.dependencies[layer_key]
while (
len(parents) == 1
and dependents[list(parents)[0]] == {lay}
and dependents[list(parents)[0]] == {layer_key}
and isinstance(dsk.layers[list(parents)[0]], AwkwardBlockwiseLayer)
and len(dsk.layers[lay]) == len(dsk.layers[list(parents)[0]])
and len(dsk.layers[layer_key]) == len(dsk.layers[list(parents)[0]])
and list(parents)[0] not in required_layers
):
# walk backwards
lay = list(parents)[0]
chain.insert(0, lay)
all_layers.remove(lay)
parents = dsk.dependencies[lay]
layer_key = list(parents)[0]
chain.insert(0, layer_key)
all_layers.remove(layer_key)
parents = dsk.dependencies[layer_key]
if len(chain) > 1:
chains.append(chain)
layers[chain[-1]] = copy.copy(
dsk.layers[chain[-1]]
) # shallow copy to be mutated
else:
layers[lay] = val # passthrough unchanged
layers[layer_key] = layer # passthrough unchanged

# do rewrite
for chain in chains:
Expand Down

0 comments on commit 1661bbb

Please sign in to comment.