fix: pass keys to column optimisation (#430)
Make sure the `keys` argument propagates completely through optimization; otherwise we can miss some necessary buffers.

---------

Co-authored-by: Doug Davis <[email protected]>
agoose77 and douglasdavis committed Jan 18, 2024
1 parent 6adec1a commit 816eeba
Showing 4 changed files with 128 additions and 63 deletions.
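
For context, a minimal sketch of the failure mode this commit fixes (the dataset path and collection names are illustrative, not from the diff): previously the column optimizer inferred the outputs from the graph's leaf nodes, so a collection requested alongside one of its own dependents is not a leaf and its buffers could be projected away.

import dask
import dask_awkward as dak

ds = dak.from_parquet("/path/to/data")  # hypothetical dataset
points = ds.points                      # requested below, but *not* a leaf
shifted = points.x + 1                  # depends on `points`

# `points` is a requested output here, yet it has a dependent (`shifted`),
# so leaf-node inference missed it; passing the real keys through the
# optimization fixes that.
dask.compute(shifted, points)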
11 changes: 6 additions & 5 deletions src/dask_awkward/lib/inspect.py
@@ -4,7 +4,6 @@
 
 import numpy as np
 from dask.base import unpack_collections
-from dask.highlevelgraph import HighLevelGraph
 
 from dask_awkward.layers import AwkwardInputLayer
 
@@ -81,8 +80,9 @@ def report_necessary_buffers(
 
     name_to_necessary_buffers: dict[str, NecessaryBuffers | None] = {}
     for obj in collections:
-        dsk = obj if isinstance(obj, HighLevelGraph) else obj.dask
-        projection_data = o._prepare_buffer_projection(dsk)
+        dsk = obj.__dask_graph__()
+        keys = obj.__dask_keys__()
+        projection_data = o._prepare_buffer_projection(dsk, keys)
 
         # If the projection failed, or there are no input layers
         if projection_data is None:
@@ -178,8 +178,9 @@ def report_necessary_columns(
 
     name_to_necessary_columns: dict[str, frozenset | None] = {}
     for obj in collections:
-        dsk = obj if isinstance(obj, HighLevelGraph) else obj.dask
-        projection_data = o._prepare_buffer_projection(dsk)
+        dsk = obj.__dask_graph__()
+        keys = obj.__dask_keys__()
+        projection_data = o._prepare_buffer_projection(dsk, keys)
 
         # If the projection failed, or there are no input layers
         if projection_data is None:
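
Both report helpers now go through the dask collection protocol instead of accepting either a raw HighLevelGraph or a collection. A sketch of the two protocol calls (the layer name in the comment is illustrative):

import dask_awkward as dak

ds = dak.from_parquet("/path/to/data")  # hypothetical path
dsk = ds.__dask_graph__()   # the collection's HighLevelGraph
keys = ds.__dask_keys__()   # e.g. [("from-parquet-abc123", 0), ("from-parquet-abc123", 1)]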
101 changes: 44 additions & 57 deletions src/dask_awkward/lib/optimize.py
@@ -3,7 +3,7 @@
 import copy
 import logging
 import warnings
-from collections.abc import Hashable, Iterable, Mapping
+from collections.abc import Iterable, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, cast
 
 import dask.config
@@ -14,9 +14,11 @@
 from dask.local import get_sync
 
 from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
+from dask_awkward.utils import first
 
 if TYPE_CHECKING:
     from awkward._nplikes.typetracer import TypeTracerReport
+    from dask.typing import Key
 
 log = logging.getLogger(__name__)
 
@@ -30,19 +32,13 @@
 """
 
 
-def all_optimizations(
-    dsk: Mapping,
-    keys: Hashable | list[Hashable] | set[Hashable],
-    **_: Any,
-) -> Mapping:
+def all_optimizations(dsk: Mapping, keys: Sequence[Key], **_: Any) -> Mapping:
     """Run all optimizations that benefit dask-awkward computations.
 
     This function will run both dask-awkward specific and upstream
     general optimizations from core dask.
 
     """
-    if not isinstance(keys, (list, set)):
-        keys = (keys,)  # pragma: no cover
     keys = tuple(flatten(keys))
 
     if not isinstance(dsk, HighLevelGraph):
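
`keys` is flattened because `__dask_keys__` returns a nested list per collection. A sketch of the normalization, assuming `flatten` here is dask's `dask.core.flatten`:

from dask.core import flatten

nested = [[("a", 0), ("a", 1)], [("b", 0)]]  # two collections' keys
assert tuple(flatten(nested)) == (("a", 0), ("a", 1), ("b", 0))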
@@ -63,11 +59,7 @@ def all_optimizations(
     return dsk
 
 
-def optimize(
-    dsk: HighLevelGraph,
-    keys: Hashable | list[Hashable] | set[Hashable],
-    **_: Any,
-) -> Mapping:
+def optimize(dsk: HighLevelGraph, keys: Sequence[Key], **_: Any) -> Mapping:
     """Run optimizations specific to dask-awkward.
 
     This is currently limited to determining the necessary columns for
@@ -77,15 +69,15 @@ def optimize(
     if dask.config.get("awkward.optimization.enabled"):
         which = dask.config.get("awkward.optimization.which")
         if "columns" in which:
-            dsk = optimize_columns(dsk)
+            dsk = optimize_columns(dsk, keys)
         if "layer-chains" in which:
             dsk = rewrite_layer_chains(dsk, keys)
 
     return dsk
 
 
 def _prepare_buffer_projection(
-    dsk: HighLevelGraph,
+    dsk: HighLevelGraph, keys: Sequence[Key]
 ) -> tuple[dict[str, TypeTracerReport], dict[str, Any]] | None:
     """Pair layer names with lists of necessary columns."""
     import awkward as ak
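
The dispatch in `optimize` above is gated by dask config. A usage sketch for enabling only the column projection pass (the config keys are the ones read above; the dataset is hypothetical):

import dask
import dask_awkward as dak

ds = dak.from_parquet("/path/to/data")  # hypothetical dataset
with dask.config.set({"awkward.optimization.which": ["columns"]}):
    ds.points.x.compute()  # column projection runs, layer-chain fusion does not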
@@ -117,17 +109,12 @@ def _prepare_buffer_projection(
 
     hlg = HighLevelGraph(projection_layers, dsk.dependencies)
 
-    # this loop builds up what are the possible final leaf nodes by
-    # inspecting the dependents dictionary. If something does not have
-    # a dependent, it must be the end of a graph. These are the things
-    # we need to compute for; we only use a single partition (the
-    # first). for a single collection `.compute()` this list will just
-    # be length 1; but if we are using `dask.compute` to pass in
-    # multiple collections to be computed simultaneously, this list
-    # will increase in length.
-    leaf_layers_keys = [
-        (k, 0) for k, v in dsk.dependents.items() if isinstance(v, set) and len(v) == 0
-    ]
+    minimal_keys: set[Key] = set()
+    for k in keys:
+        if isinstance(k, tuple) and len(k) == 2:
+            minimal_keys.add((k[0], 0))
+        else:
+            minimal_keys.add(k)
 
     # now we try to compute for each possible output layer key (leaf
     # node on partition 0); this will cause the typetracer reports to
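
The replacement logic above is self-contained: each partitioned key collapses onto partition 0, so the typetracer pass computes a single partition per requested output layer. A standalone sketch with illustrative key names:

keys = [("points-xyz", 3), ("points-xyz", 7), "some-scalar-key"]
minimal_keys = set()
for k in keys:
    if isinstance(k, tuple) and len(k) == 2:
        minimal_keys.add((k[0], 0))
    else:
        minimal_keys.add(k)
assert minimal_keys == {("points-xyz", 0), "some-scalar-key"}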
@@ -136,7 +123,7 @@ def _prepare_buffer_projection(
     try:
         for layer in hlg.layers.values():
             layer.__dict__.pop("_cached_dict", None)
-        results = get_sync(hlg, leaf_layers_keys)
+        results = get_sync(hlg, list(minimal_keys))
         for out in results:
             if isinstance(out, (ak.Array, ak.Record)):
                 touch_data(out)
@@ -163,7 +150,7 @@ def _prepare_buffer_projection(
     return layer_to_reports, layer_to_projection_state
 
 
-def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
+def optimize_columns(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph:
     """Run column projection optimization.
 
     This optimization determines which columns from an
@@ -192,7 +179,7 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
        New, optimized task graph with column-projected ``AwkwardInputLayer``.
 
     """
-    projection_data = _prepare_buffer_projection(dsk)
+    projection_data = _prepare_buffer_projection(dsk, keys)
     if projection_data is None:
         return dsk
 
@@ -258,7 +245,7 @@ def _mock_output(layer):
     return new_layer
 
 
-def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
+def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph:
     """Smush chains of blockwise layers into a single layer.
 
     The logic here identifies chains by popping layers (in arbitrary
@@ -292,54 +279,54 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
     chains = []
     deps = copy.copy(dsk.dependencies)
 
-    required_layers = {k[0] for k in keys}
+    required_layers = {k[0] for k in keys if isinstance(k, tuple)}
     layers = {}
     # find chains; each chain list is at least two keys long
     dependents = dsk.dependents
     all_layers = set(dsk.layers)
     while all_layers:
-        lay = all_layers.pop()
-        val = dsk.layers[lay]
-        if not isinstance(val, AwkwardBlockwiseLayer):
+        layer_key = all_layers.pop()
+        layer = dsk.layers[layer_key]
+        if not isinstance(layer, AwkwardBlockwiseLayer):
             # shortcut to avoid making comparisons
-            layers[lay] = val  # passthrough unchanged
+            layers[layer_key] = layer  # passthrough unchanged
             continue
-        children = dependents[lay]
-        chain = [lay]
-        lay0 = lay
+        children = dependents[layer_key]
+        chain = [layer_key]
+        current_layer_key = layer_key
         while (
             len(children) == 1
-            and dsk.dependencies[list(children)[0]] == {lay}
-            and isinstance(dsk.layers[list(children)[0]], AwkwardBlockwiseLayer)
-            and len(dsk.layers[lay]) == len(dsk.layers[list(children)[0]])
-            and lay not in required_layers
+            and dsk.dependencies[first(children)] == {current_layer_key}
+            and isinstance(dsk.layers[first(children)], AwkwardBlockwiseLayer)
+            and len(dsk.layers[current_layer_key]) == len(dsk.layers[first(children)])
+            and current_layer_key not in required_layers
         ):
             # walk forwards
-            lay = list(children)[0]
-            chain.append(lay)
-            all_layers.remove(lay)
-            children = dependents[lay]
-        lay = lay0
-        parents = dsk.dependencies[lay]
+            current_layer_key = first(children)
+            chain.append(current_layer_key)
+            all_layers.remove(current_layer_key)
+            children = dependents[current_layer_key]
+
+        parents = dsk.dependencies[layer_key]
         while (
             len(parents) == 1
-            and dependents[list(parents)[0]] == {lay}
-            and isinstance(dsk.layers[list(parents)[0]], AwkwardBlockwiseLayer)
-            and len(dsk.layers[lay]) == len(dsk.layers[list(parents)[0]])
-            and list(parents)[0] not in required_layers
+            and dependents[first(parents)] == {layer_key}
+            and isinstance(dsk.layers[first(parents)], AwkwardBlockwiseLayer)
+            and len(dsk.layers[layer_key]) == len(dsk.layers[first(parents)])
+            and next(iter(parents)) not in required_layers
         ):
             # walk backwards
-            lay = list(parents)[0]
-            chain.insert(0, lay)
-            all_layers.remove(lay)
-            parents = dsk.dependencies[lay]
+            layer_key = first(parents)
+            chain.insert(0, layer_key)
+            all_layers.remove(layer_key)
+            parents = dsk.dependencies[layer_key]
         if len(chain) > 1:
             chains.append(chain)
             layers[chain[-1]] = copy.copy(
                 dsk.layers[chain[-1]]
             )  # shallow copy to be mutated
         else:
-            layers[lay] = val  # passthrough unchanged
+            layers[layer_key] = layer  # passthrough unchanged
 
 # do rewrite
 for chain in chains:
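
The refactor also replaces `list(children)[0]` with a `first` helper imported from `dask_awkward.utils`. A sketch of what `first` is assumed to do:

from collections.abc import Iterable
from typing import TypeVar

T = TypeVar("T")

def first(seq: Iterable[T]) -> T:
    # take an arbitrary element without materializing a list,
    # unlike list(seq)[0]; raises StopIteration if `seq` is empty
    return next(iter(seq))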
2 changes: 1 addition & 1 deletion tests/test_io_json.py
@@ -91,7 +91,7 @@ def input_layer_array_partition0(collection: Array) -> ak.Array:
     """
     with dask.config.set({"awkward.optimization.which": ["columns"]}):
-        optimized_hlg = dak_optimize(collection.dask, [])
+        optimized_hlg = dak_optimize(collection.dask, collection.keys)  # type: ignore
         layers = list(optimized_hlg.layers)  # type: ignore
         layer_name = [name for name in layers if name.startswith("from-json")][0]
         sgc, arg = optimized_hlg[(layer_name, 0)]
 
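
The test now passes the collection's actual output keys rather than an empty list, mirroring what the optimizer receives during a real compute. A sketch (assuming the `keys` property simply forwards to the collection protocol):

import dask_awkward as dak

collection = dak.from_json("/path/to/records.json")  # hypothetical source
assert collection.keys == collection.__dask_keys__()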
77 changes: 77 additions & 0 deletions tests/test_optimize.py
@@ -2,6 +2,7 @@
 
 import awkward as ak
 import dask
+import pytest
 
 import dask_awkward as dak
 from dask_awkward.lib.testutils import assert_eq
@@ -76,3 +77,79 @@ def test_multiple_computes_multiple_incapsulated(daa, caa):
     (opt4_alone,) = dask.optimize(dstep4)
     assert len(opt4_alone.dask.layers) == 1
     assert_eq(opt4_alone, opt4)
+
+
+def test_optimization_runs_on_multiple_collections_gh430(tmp_path_factory):
+    pytest.importorskip("pyarrow")
+    d = tmp_path_factory.mktemp("opt")
+    array1 = ak.Array(
+        [
+            {
+                "points": [
+                    {"x": 3.0, "y": 4.0, "z": 1.0},
+                    {"x": 2.0, "y": 5.0, "z": 2.0},
+                ],
+            },
+            {
+                "points": [
+                    {"x": 2.0, "y": 5.0, "z": 2.0},
+                    {"x": 3.0, "y": 4.0, "z": 1.0},
+                ],
+            },
+            {
+                "points": [],
+            },
+            {
+                "points": [
+                    {"x": 2.0, "y": 6.0, "z": 2.0},
+                ],
+            },
+        ]
+    )
+    ak.to_parquet(array1, d / "p0.parquet", extensionarray=False)
+
+    array2 = ak.Array(
+        [
+            {
+                "points": [
+                    {"x": 1.0, "y": 4.0, "z": 1.0},
+                ],
+            },
+            {
+                "points": [
+                    {"x": 7.0, "y": 5.0, "z": 2.0},
+                    {"x": 3.0, "y": 4.0, "z": 1.0},
+                    {"x": 5.0, "y": 5.0, "z": 2.0},
+                ],
+            },
+            {
+                "points": [
+                    {"x": 2.0, "y": 6.0, "z": 2.0},
+                ],
+            },
+            {
+                "points": [],
+            },
+        ]
+    )
+    ak.to_parquet(array2, d / "p1.parquet", extensionarray=False)
+
+    ds = dak.from_parquet(d)
+    a1 = ds.partitions[0].points
+    a2 = ds.partitions[1].points
+    a, b = ak.unzip(ak.cartesian([a1, a2], axis=1, nested=True))
+
+    def something(j, k):
+        return j.x + k.x
+
+    a_compute = something(a, b)
+    nc1 = dak.necessary_columns(a_compute)
+    assert sorted(list(nc1.items())[0][1]) == ["points.x"]
+
+    nc2 = dak.necessary_columns(a_compute, (a, b))
+    assert sorted(list(nc2.items())[0][1]) == ["points.x", "points.y", "points.z"]
+
+    x, (y, z) = dask.compute(a_compute, (a, b))
+    assert str(x)
+    assert str(y)
+    assert str(z)
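
Note: the regression test above exercises both sides of the fix. `necessary_columns(a_compute)` reports only `points.x`, while `necessary_columns(a_compute, (a, b))` widens the report to `points.x`, `points.y`, and `points.z`, because the extra collections' keys now reach the buffer projection.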
