From 18b5c7b9d1206132611ddc8d4c484f64c540c76a Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Tue, 12 Dec 2023 13:53:00 +0000 Subject: [PATCH] refactor: use walrus for readability --- src/dask_awkward/lib/optimize.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index 58c5866b..5d6f4915 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -298,13 +298,14 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG current_layer_key = layer_key while ( len(children) == 1 - and dsk.dependencies[list(children)[0]] == {current_layer_key} - and isinstance(dsk.layers[list(children)[0]], AwkwardBlockwiseLayer) - and len(dsk.layers[current_layer_key]) == len(dsk.layers[list(children)[0]]) + and dsk.dependencies[first_child := next(iter(children))] + == {current_layer_key} + and isinstance(dsk.layers[first_child], AwkwardBlockwiseLayer) + and len(dsk.layers[current_layer_key]) == len(dsk.layers[first_child]) and current_layer_key not in required_layers ): # walk forwards - current_layer_key = list(children)[0] + current_layer_key = first_child chain.append(current_layer_key) all_layers.remove(current_layer_key) children = dependents[current_layer_key] @@ -312,13 +313,13 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG parents = dsk.dependencies[layer_key] while ( len(parents) == 1 - and dependents[list(parents)[0]] == {layer_key} - and isinstance(dsk.layers[list(parents)[0]], AwkwardBlockwiseLayer) - and len(dsk.layers[layer_key]) == len(dsk.layers[list(parents)[0]]) - and list(parents)[0] not in required_layers + and dependents[first_parent := next(iter(parents))] == {layer_key} + and isinstance(dsk.layers[first_parent], AwkwardBlockwiseLayer) + and len(dsk.layers[layer_key]) == len(dsk.layers[first_parent]) + and next(iter(parents)) not in required_layers ): # walk backwards - layer_key = list(parents)[0] + layer_key = first_parent chain.insert(0, layer_key) all_layers.remove(layer_key) parents = dsk.dependencies[layer_key]