Skip to content

Commit

Permalink
Merge pull request #475 from martindurant/stage_getitems
Browse files Browse the repository at this point in the history
feat: cache map_partitions ops
  • Loading branch information
lgray committed Mar 6, 2024
2 parents 0e54cef + 6e5a2ee commit 6a7fe2b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 28 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ classifiers = [
dependencies = [
"awkward >=2.5.1",
"dask >=2023.04.0",
"cachetools",
"typing_extensions >=4.8.0",
]
dynamic = ["version"]
Expand Down Expand Up @@ -128,7 +129,8 @@ warn_unreachable = true
"pyarrow.*",
"tlz.*",
"uproot.*",
"cloudpickle.*"
"cloudpickle.*",
"cachetools.*"
]
ignore_missing_imports = true

Expand Down
60 changes: 33 additions & 27 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload

import awkward as ak
import cachetools
import dask.config
import numpy as np
from awkward._do import remove_structure as ak_do_remove_structure
Expand Down Expand Up @@ -753,7 +754,7 @@ def __reduce__(self):
def fields(self) -> list[str]:
if self._meta is None:
raise TypeError("metadata is missing; cannot determine fields.")
return ak.fields(self._meta)
return getattr(self._meta, "fields", None) or []

@property
def layout(self) -> Any:
Expand Down Expand Up @@ -854,6 +855,9 @@ def _finalize_array(results: Sequence[Any]) -> Any:
raise RuntimeError(msg)


dak_cache = cachetools.LRUCache(maxsize=1000)


class Array(DaskMethodsMixin, NDArrayOperatorsMixin):
"""Partitioned, lazy, and parallel Awkward Array Dask collection.
Expand Down Expand Up @@ -1115,7 +1119,7 @@ def mask(self) -> AwkwardMask:
@property
def fields(self) -> list[str]:
"""Record field names (if any)."""
return ak.fields(self._meta)
return getattr(self._meta, "fields", None) or []

@property
def form(self) -> Form:
Expand Down Expand Up @@ -1917,45 +1921,50 @@ def _map_partitions(
will not be traversed to extract all dask collections, except those in
the first dimension of args or kwargs.
"""
token = token or tokenize(fn, *args, meta, **kwargs)
token = token or tokenize(fn, *args, output_divisions, **kwargs)
label = hyphenize(label or funcname(fn))
name = f"{label}-{token}"

deps = [a for a in args if is_dask_collection(a)] + [
v for v in kwargs.values() if is_dask_collection(v)
]

dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps))

lay = partitionwise_layer(
fn,
name,
*args,
**kwargs,
)
if name in dak_cache:
hlg, meta = dak_cache[name]
else:
lay = partitionwise_layer(
fn,
name,
*args,
**kwargs,
)

if meta is None:
meta = map_meta(fn, *args, **kwargs)
if meta is None:
meta = map_meta(fn, *args, **kwargs)

hlg = HighLevelGraph.from_collections(
name,
lay,
dependencies=deps,
)

if len(dak_arrays) == 0:
raise TypeError(
"at least one argument passed to map_partitions "
"should be a dask_awkward.Array collection."
hlg = HighLevelGraph.from_collections(
name,
lay,
dependencies=deps,
)

if len(dak_arrays) == 0:
raise TypeError(
"at least one argument passed to map_partitions "
"should be a dask_awkward.Array collection."
)
dak_cache[name] = hlg, meta
in_npartitions = dak_arrays[0].npartitions
in_divisions = dak_arrays[0].divisions

if output_divisions is not None:
if output_divisions == 1:
new_divisions = dak_arrays[0].divisions
else:
new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))
else:
new_divisions = in_divisions

if output_divisions is not None:
return new_array_object(
hlg,
name=name,
Expand Down Expand Up @@ -2050,10 +2059,7 @@ def map_partitions(
<Array [[1, 4, 9], [16], [5, 12, 21], [32]] type='4 * var * int64'>
This is effectively the same as `d = c * a`
"""
token = token or tokenize(base_fn, *args, meta, **kwargs)
label = hyphenize(label or funcname(base_fn))

opt_touch_all = kwargs.pop("opt_touch_all", None)
if opt_touch_all is not None:
Expand Down

0 comments on commit 6a7fe2b

Please sign in to comment.