Skip to content

Commit

Permalink
fix: support multiple keys
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Dec 2, 2023
1 parent dd3267a commit 5924332
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 26 deletions.
11 changes: 6 additions & 5 deletions src/dask_awkward/lib/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
from dask.base import unpack_collections
from dask.highlevelgraph import HighLevelGraph

from dask_awkward.layers import AwkwardInputLayer

Expand Down Expand Up @@ -81,8 +80,9 @@ def report_necessary_buffers(

name_to_necessary_buffers: dict[str, NecessaryBuffers | None] = {}
for obj in collections:
dsk = obj if isinstance(obj, HighLevelGraph) else obj.dask
projection_data = o._prepare_buffer_projection(dsk)
dsk = obj.__dask_graph__()
keys = obj.__dask_keys__()
projection_data = o._prepare_buffer_projection(dsk, keys)

# If the projection failed, or there are no input layers
if projection_data is None:
Expand Down Expand Up @@ -178,8 +178,9 @@ def report_necessary_columns(

name_to_necessary_columns: dict[str, frozenset | None] = {}
for obj in collections:
dsk = obj if isinstance(obj, HighLevelGraph) else obj.dask
projection_data = o._prepare_buffer_projection(dsk)
dsk = obj.__dask_graph__()
keys = obj.__dask_keys__()
projection_data = o._prepare_buffer_projection(dsk, keys)

# If the projection failed, or there are no input layers
if projection_data is None:
Expand Down
31 changes: 10 additions & 21 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import logging
import warnings
from collections.abc import Hashable, Iterable, Mapping
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast

import dask.config
Expand All @@ -17,6 +17,7 @@

if TYPE_CHECKING:
from awkward._nplikes.typetracer import TypeTracerReport
from dask.typing import Key

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -65,7 +66,7 @@ def all_optimizations(

def optimize(
dsk: HighLevelGraph,
keys: Hashable | list[Hashable] | set[Hashable],
keys: Sequence[Key],
**_: Any,
) -> Mapping:
"""Run optimizations specific to dask-awkward.
Expand All @@ -77,15 +78,15 @@ def optimize(
if dask.config.get("awkward.optimization.enabled"):
which = dask.config.get("awkward.optimization.which")
if "columns" in which:
dsk = optimize_columns(dsk)
dsk = optimize_columns(dsk, keys)
if "layer-chains" in which:
dsk = rewrite_layer_chains(dsk, keys)

return dsk


def _prepare_buffer_projection(
dsk: HighLevelGraph,
dsk: HighLevelGraph, keys: Sequence[Key]
) -> tuple[dict[str, TypeTracerReport], dict[str, Any]] | None:
"""Pair layer names with lists of necessary columns."""
import awkward as ak
Expand Down Expand Up @@ -117,26 +118,14 @@ def _prepare_buffer_projection(

hlg = HighLevelGraph(projection_layers, dsk.dependencies)

# this loop builds up what are the possible final leaf nodes by
# inspecting the dependents dictionary. If something does not have
# a dependent, it must be the end of a graph. These are the things
# we need to compute for; we only use a single partition (the
# first). for a single collection `.compute()` this list will just
# be length 1; but if we are using `dask.compute` to pass in
# multiple collections to be computed simultaneously, this list
# will increase in length.
leaf_layers_keys = [
(k, 0) for k, v in dsk.dependents.items() if isinstance(v, set) and len(v) == 0
]

    # now we try to compute for each of the requested output keys; this
    # will cause the typetracer reports to get correct fields/columns
    # touched. If the result is a record or an array we of course want
    # to touch all of the data/fields.
try:
for layer in hlg.layers.values():
layer.__dict__.pop("_cached_dict", None)
results = get_sync(hlg, leaf_layers_keys)
results = get_sync(hlg, list(keys))
for out in results:
if isinstance(out, (ak.Array, ak.Record)):
touch_data(out)
Expand All @@ -163,7 +152,7 @@ def _prepare_buffer_projection(
return layer_to_reports, layer_to_projection_state


def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
def optimize_columns(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph:
"""Run column projection optimization.
This optimization determines which columns from an
Expand Down Expand Up @@ -192,7 +181,7 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
New, optimized task graph with column-projected ``AwkwardInputLayer``.
"""
projection_data = _prepare_buffer_projection(dsk)
projection_data = _prepare_buffer_projection(dsk, keys)
if projection_data is None:
return dsk

Expand Down Expand Up @@ -258,7 +247,7 @@ def _mock_output(layer):
return new_layer


def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph:
"""Smush chains of blockwise layers into a single layer.
The logic here identifies chains by popping layers (in arbitrary
Expand Down Expand Up @@ -292,7 +281,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
chains = []
deps = copy.copy(dsk.dependencies)

required_layers = {k[0] for k in keys}
required_layers = {k[0] for k in keys if isinstance(k, tuple)}
layers = {}
# find chains; each chain list is at least two keys long
dependents = dsk.dependents
Expand Down

0 comments on commit 5924332

Please sign in to comment.