fix: pass keys to column optimisation #430

Merged Jan 18, 2024 (10 commits)
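Context for the change: when dask computes a collection, it hands the collection's optimizer both the merged task graph and the concrete keys being computed. Before this PR, dask-awkward's column optimization ignored the keys and re-derived output layers from the graph's dependents, which went wrong when several collections shared one graph. A minimal sketch of the relevant protocol hooks (the class and bodies are illustrative, not from this PR):

from collections.abc import Mapping, Sequence
from typing import Any


class SketchCollection:
    # sketch of the dask collection protocol relevant to this PR

    def __dask_graph__(self) -> Mapping:
        ...  # the full task graph (a HighLevelGraph for dask-awkward)

    def __dask_keys__(self) -> list:
        ...  # output keys, one per partition, e.g. [("layer-name", 0), ("layer-name", 1)]

    @staticmethod
    def __dask_optimize__(dsk: Mapping, keys: Sequence[Any], **kwargs: Any) -> Mapping:
        # dask calls this with the keys it is actually about to compute;
        # dask-awkward routes it to all_optimizations(dsk, keys)
        return dsk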
11 changes: 6 additions & 5 deletions src/dask_awkward/lib/inspect.py
@@ -4,7 +4,6 @@
 
 import numpy as np
 from dask.base import unpack_collections
-from dask.highlevelgraph import HighLevelGraph
 
 from dask_awkward.layers import AwkwardInputLayer
 
@@ -81,8 +80,9 @@ def report_necessary_buffers(
 
     name_to_necessary_buffers: dict[str, NecessaryBuffers | None] = {}
     for obj in collections:
-        dsk = obj if isinstance(obj, HighLevelGraph) else obj.dask
-        projection_data = o._prepare_buffer_projection(dsk)
+        dsk = obj.__dask_graph__()
+        keys = obj.__dask_keys__()
+        projection_data = o._prepare_buffer_projection(dsk, keys)
 
         # If the projection failed, or there are no input layers
         if projection_data is None:
@@ -178,8 +178,9 @@ def report_necessary_columns(
 
     name_to_necessary_columns: dict[str, frozenset | None] = {}
     for obj in collections:
-        dsk = obj if isinstance(obj, HighLevelGraph) else obj.dask
-        projection_data = o._prepare_buffer_projection(dsk)
+        dsk = obj.__dask_graph__()
+        keys = obj.__dask_keys__()
+        projection_data = o._prepare_buffer_projection(dsk, keys)
 
         # If the projection failed, or there are no input layers
         if projection_data is None:
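The inspection helpers above now go through the standard collection accessors instead of special-casing HighLevelGraph inputs. For a concrete sense of what those accessors return, a usage sketch (the array contents are made up):

import awkward as ak
import dask_awkward as dak

arr = dak.from_awkward(ak.Array([{"x": 1.0}, {"x": 2.0}]), npartitions=2)
dsk = arr.__dask_graph__()   # HighLevelGraph containing every layer
keys = arr.__dask_keys__()   # [(arr.name, 0), (arr.name, 1)]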
101 changes: 44 additions & 57 deletions src/dask_awkward/lib/optimize.py
@@ -3,7 +3,7 @@
 import copy
 import logging
 import warnings
-from collections.abc import Hashable, Iterable, Mapping
+from collections.abc import Iterable, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, cast
 
 import dask.config
@@ -14,9 +14,11 @@
 from dask.local import get_sync
 
 from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
+from dask_awkward.utils import first
 
 if TYPE_CHECKING:
     from awkward._nplikes.typetracer import TypeTracerReport
+    from dask.typing import Key
 
 log = logging.getLogger(__name__)
 
@@ -30,19 +32,13 @@
 """
 
 
-def all_optimizations(
-    dsk: Mapping,
-    keys: Hashable | list[Hashable] | set[Hashable],
-    **_: Any,
-) -> Mapping:
+def all_optimizations(dsk: Mapping, keys: Sequence[Key], **_: Any) -> Mapping:
     """Run all optimizations that benefit dask-awkward computations.
 
     This function will run both dask-awkward specific and upstream
    general optimizations from core dask.
 
     """
-    if not isinstance(keys, (list, set)):
-        keys = (keys,)  # pragma: no cover
     keys = tuple(flatten(keys))
 
     if not isinstance(dsk, HighLevelGraph):
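The `keys = tuple(flatten(keys))` line relies on dask's `flatten` to normalize whatever key nesting the scheduler hands over, whether a single list of keys or one list per collection. A small illustration (assuming `flatten` here is `dask.core.flatten`):

from dask.core import flatten

# dask.compute may pass one (possibly nested) list of keys per collection;
# flatten yields the individual keys while leaving key tuples intact
nested = [[("a", 0), ("a", 1)], [("b", 0)]]
assert tuple(flatten(nested)) == (("a", 0), ("a", 1), ("b", 0))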
@@ -63,11 +59,7 @@ def all_optimizations(
     return dsk
 
 
-def optimize(
-    dsk: HighLevelGraph,
-    keys: Hashable | list[Hashable] | set[Hashable],
-    **_: Any,
-) -> Mapping:
+def optimize(dsk: HighLevelGraph, keys: Sequence[Key], **_: Any) -> Mapping:
     """Run optimizations specific to dask-awkward.
 
     This is currently limited to determining the necessary columns for
@@ -77,15 +69,15 @@ def optimize(
     if dask.config.get("awkward.optimization.enabled"):
         which = dask.config.get("awkward.optimization.which")
         if "columns" in which:
-            dsk = optimize_columns(dsk)
+            dsk = optimize_columns(dsk, keys)
         if "layer-chains" in which:
             dsk = rewrite_layer_chains(dsk, keys)
 
     return dsk
 
 
 def _prepare_buffer_projection(
-    dsk: HighLevelGraph,
+    dsk: HighLevelGraph, keys: Sequence[Key]
 ) -> tuple[dict[str, TypeTracerReport], dict[str, Any]] | None:
     """Pair layer names with lists of necessary columns."""
     import awkward as ak
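For reference, the passes gated above can be toggled through dask's config system; something like this (a usage sketch, with `arr` standing in for any dask-awkward collection):

import dask

# run only column projection, skipping layer-chain fusion
with dask.config.set({"awkward.optimization.which": ["columns"]}):
    result = arr.compute()

# or disable dask-awkward's optimizations entirely
with dask.config.set({"awkward.optimization.enabled": False}):
    result = arr.compute()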
@@ -117,17 +109,12 @@ def _prepare_buffer_projection(
 
     hlg = HighLevelGraph(projection_layers, dsk.dependencies)
 
-    # this loop builds up what are the possible final leaf nodes by
-    # inspecting the dependents dictionary. If something does not have
-    # a dependent, it must be the end of a graph. These are the things
-    # we need to compute for; we only use a single partition (the
-    # first). for a single collection `.compute()` this list will just
-    # be length 1; but if we are using `dask.compute` to pass in
-    # multiple collections to be computed simultaneously, this list
-    # will increase in length.
-    leaf_layers_keys = [
-        (k, 0) for k, v in dsk.dependents.items() if isinstance(v, set) and len(v) == 0
-    ]
+    minimal_keys: set[Key] = set()
+    for k in keys:
+        if isinstance(k, tuple) and len(k) == 2:
+            minimal_keys.add((k[0], 0))
+        else:
+            minimal_keys.add(k)
 
     # now we try to compute for each possible output layer key (leaf
     # node on partition 0); this will cause the typetracer reports to
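The new block replaces leaf-sniffing with a direct normalization of the requested keys: every `(name, partition)` key collapses onto partition 0, since touching a single typetracer partition is enough to record which buffers a layer needs. Worked on concrete data (the layer names are hypothetical):

# keys as dask might pass them for two collections with 3 and 2 partitions
keys = [("mul-layer", 0), ("mul-layer", 1), ("mul-layer", 2),
        ("add-layer", 0), ("add-layer", 1)]

minimal_keys = set()
for k in keys:
    if isinstance(k, tuple) and len(k) == 2:
        minimal_keys.add((k[0], 0))
    else:
        minimal_keys.add(k)

# only one cheap typetracer partition per requested output remains
assert minimal_keys == {("mul-layer", 0), ("add-layer", 0)}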
@@ -136,7 +123,7 @@
     try:
         for layer in hlg.layers.values():
             layer.__dict__.pop("_cached_dict", None)
-        results = get_sync(hlg, leaf_layers_keys)
+        results = get_sync(hlg, list(minimal_keys))
         for out in results:
             if isinstance(out, (ak.Array, ak.Record)):
                 touch_data(out)
@@ -163,7 +150,7 @@
     return layer_to_reports, layer_to_projection_state
 
 
-def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
+def optimize_columns(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph:
     """Run column projection optimization.
 
     This optimization determines which columns from an
@@ -192,7 +179,7 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
         New, optimized task graph with column-projected ``AwkwardInputLayer``.
 
     """
-    projection_data = _prepare_buffer_projection(dsk)
+    projection_data = _prepare_buffer_projection(dsk, keys)
     if projection_data is None:
         return dsk
 
@@ -258,7 +245,7 @@ def _mock_output(layer):
     return new_layer
 
 
-def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
+def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph:
     """Smush chains of blockwise layers into a single layer.
 
     The logic here identifies chains by popping layers (in arbitrary
@@ -292,54 +279,54 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
     chains = []
     deps = copy.copy(dsk.dependencies)
 
-    required_layers = {k[0] for k in keys}
+    required_layers = {k[0] for k in keys if isinstance(k, tuple)}
     layers = {}
     # find chains; each chain list is at least two keys long
     dependents = dsk.dependents
     all_layers = set(dsk.layers)
     while all_layers:
-        lay = all_layers.pop()
-        val = dsk.layers[lay]
-        if not isinstance(val, AwkwardBlockwiseLayer):
+        layer_key = all_layers.pop()
+        layer = dsk.layers[layer_key]
+        if not isinstance(layer, AwkwardBlockwiseLayer):
             # shortcut to avoid making comparisons
-            layers[lay] = val  # passthrough unchanged
+            layers[layer_key] = layer  # passthrough unchanged
             continue
-        children = dependents[lay]
-        chain = [lay]
-        lay0 = lay
+        children = dependents[layer_key]
+        chain = [layer_key]
+        current_layer_key = layer_key
         while (
             len(children) == 1
-            and dsk.dependencies[list(children)[0]] == {lay}
-            and isinstance(dsk.layers[list(children)[0]], AwkwardBlockwiseLayer)
-            and len(dsk.layers[lay]) == len(dsk.layers[list(children)[0]])
-            and lay not in required_layers
+            and dsk.dependencies[first(children)] == {current_layer_key}
+            and isinstance(dsk.layers[first(children)], AwkwardBlockwiseLayer)
+            and len(dsk.layers[current_layer_key]) == len(dsk.layers[first(children)])
+            and current_layer_key not in required_layers
         ):
             # walk forwards
-            lay = list(children)[0]
-            chain.append(lay)
-            all_layers.remove(lay)
-            children = dependents[lay]
-        lay = lay0
-        parents = dsk.dependencies[lay]
+            current_layer_key = first(children)
+            chain.append(current_layer_key)
+            all_layers.remove(current_layer_key)
+            children = dependents[current_layer_key]
+
+        parents = dsk.dependencies[layer_key]
         while (
             len(parents) == 1
-            and dependents[list(parents)[0]] == {lay}
-            and isinstance(dsk.layers[list(parents)[0]], AwkwardBlockwiseLayer)
-            and len(dsk.layers[lay]) == len(dsk.layers[list(parents)[0]])
-            and list(parents)[0] not in required_layers
+            and dependents[first(parents)] == {layer_key}
+            and isinstance(dsk.layers[first(parents)], AwkwardBlockwiseLayer)
+            and len(dsk.layers[layer_key]) == len(dsk.layers[first(parents)])
+            and next(iter(parents)) not in required_layers
         ):
             # walk backwards
-            lay = list(parents)[0]
-            chain.insert(0, lay)
-            all_layers.remove(lay)
-            parents = dsk.dependencies[lay]
+            layer_key = first(parents)
+            chain.insert(0, layer_key)
+            all_layers.remove(layer_key)
+            parents = dsk.dependencies[layer_key]
         if len(chain) > 1:
             chains.append(chain)
             layers[chain[-1]] = copy.copy(
                 dsk.layers[chain[-1]]
             )  # shallow copy to be mutated
         else:
-            layers[lay] = val  # passthrough unchanged
+            layers[layer_key] = layer  # passthrough unchanged
 
     # do rewrite
     for chain in chains:
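The `first` helper imported from `dask_awkward.utils` replaces the repeated `list(children)[0]` pattern, avoiding a full list copy just to peek at a one-element set. Presumably (an assumption, since the helper's definition is not part of this diff) it is the usual one-liner:

from collections.abc import Iterable
from typing import TypeVar

T = TypeVar("T")


def first(seq: Iterable[T]) -> T:
    # return the first element of any iterable without building a list
    return next(iter(seq))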
2 changes: 1 addition & 1 deletion tests/test_io_json.py
@@ -91,7 +91,7 @@ def input_layer_array_partition0(collection: Array) -> ak.Array:
 
     """
     with dask.config.set({"awkward.optimization.which": ["columns"]}):
-        optimized_hlg = dak_optimize(collection.dask, [])
+        optimized_hlg = dak_optimize(collection.dask, collection.keys)  # type: ignore
     layers = list(optimized_hlg.layers)  # type: ignore
     layer_name = [name for name in layers if name.startswith("from-json")][0]
     sgc, arg = optimized_hlg[(layer_name, 0)]
77 changes: 77 additions & 0 deletions tests/test_optimize.py
@@ -2,6 +2,7 @@
 
 import awkward as ak
 import dask
+import pytest
 
 import dask_awkward as dak
 from dask_awkward.lib.testutils import assert_eq
@@ -76,3 +77,79 @@ def test_multiple_computes_multiple_incapsulated(daa, caa):
     (opt4_alone,) = dask.optimize(dstep4)
     assert len(opt4_alone.dask.layers) == 1
     assert_eq(opt4_alone, opt4)
+
+
+def test_optimization_runs_on_multiple_collections_gh430(tmp_path_factory):
+    pytest.importorskip("pyarrow")
+    d = tmp_path_factory.mktemp("opt")
+    array1 = ak.Array(
+        [
+            {
+                "points": [
+                    {"x": 3.0, "y": 4.0, "z": 1.0},
+                    {"x": 2.0, "y": 5.0, "z": 2.0},
+                ],
+            },
+            {
+                "points": [
+                    {"x": 2.0, "y": 5.0, "z": 2.0},
+                    {"x": 3.0, "y": 4.0, "z": 1.0},
+                ],
+            },
+            {
+                "points": [],
+            },
+            {
+                "points": [
+                    {"x": 2.0, "y": 6.0, "z": 2.0},
+                ],
+            },
+        ]
+    )
+    ak.to_parquet(array1, d / "p0.parquet", extensionarray=False)
+
+    array2 = ak.Array(
+        [
+            {
+                "points": [
+                    {"x": 1.0, "y": 4.0, "z": 1.0},
+                ],
+            },
+            {
+                "points": [
+                    {"x": 7.0, "y": 5.0, "z": 2.0},
+                    {"x": 3.0, "y": 4.0, "z": 1.0},
+                    {"x": 5.0, "y": 5.0, "z": 2.0},
+                ],
+            },
+            {
+                "points": [
+                    {"x": 2.0, "y": 6.0, "z": 2.0},
+                ],
+            },
+            {
+                "points": [],
+            },
+        ]
+    )
+    ak.to_parquet(array2, d / "p1.parquet", extensionarray=False)
+
+    ds = dak.from_parquet(d)
+    a1 = ds.partitions[0].points
+    a2 = ds.partitions[1].points
+    a, b = ak.unzip(ak.cartesian([a1, a2], axis=1, nested=True))
+
+    def something(j, k):
+        return j.x + k.x
+
+    a_compute = something(a, b)
+    nc1 = dak.necessary_columns(a_compute)
+    assert sorted(list(nc1.items())[0][1]) == ["points.x"]
+
+    nc2 = dak.necessary_columns(a_compute, (a, b))
+    assert sorted(list(nc2.items())[0][1]) == ["points.x", "points.y", "points.z"]
+
+    x, (y, z) = dask.compute(a_compute, (a, b))
+    assert str(x)
+    assert str(y)
+    assert str(z)
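The two assertions in the new test differ because the second call also marks `a` and `b` as outputs: once they must be materialized, the full cartesian records, and hence all three point fields, become necessary rather than just `x`. In short (a usage sketch mirroring the test):

import dask_awkward as dak

# only a_compute is an output: projection narrows the read to points.x
dak.necessary_columns(a_compute)

# a and b are outputs too: their full records must survive projection
dak.necessary_columns(a_compute, (a, b))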