Merge branch 'main' into one-pass
martindurant committed Jul 23, 2024
2 parents e29e929 + 8f3fb27, commit 63ddd47
Showing 10 changed files with 222 additions and 122 deletions.
9 changes: 5 additions & 4 deletions .pre-commit-config.yaml
@@ -24,7 +24,7 @@ repos:
- --target-version=py312

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.3
rev: v0.5.2
hooks:
- id: ruff

@@ -35,7 +35,7 @@ repos:
language_version: python3

- repo: https://github.com/asottile/pyupgrade
rev: v3.15.2
rev: v3.16.0
hooks:
- id: pyupgrade
args:
@@ -52,16 +52,17 @@ repos:
- id: yesqa

- repo: https://github.com/adamchainz/blacken-docs
rev: 1.16.0
rev: 1.18.0
hooks:
- id: blacken-docs
additional_dependencies:
- black

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.0
rev: v1.10.1
hooks:
- id: mypy
files: "src/"
args: [--ignore-missing-imports]
additional_dependencies:
- dask
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -111,7 +111,8 @@ src_paths = ["src", "tests"]

[tool.mypy]
python_version = "3.9"
files = ["src", "tests"]
files = ["src"]
exclude = ["tests/"]
strict = false
warn_unused_configs = true
show_error_codes = true
1 change: 1 addition & 0 deletions src/dask_awkward/__init__.py
@@ -94,6 +94,7 @@
with_field,
with_name,
with_parameter,
without_field,
without_parameters,
zeros_like,
zip,
1 change: 1 addition & 0 deletions src/dask_awkward/lib/__init__.py
@@ -83,6 +83,7 @@
with_field,
with_name,
with_parameter,
without_field,
without_parameters,
zeros_like,
zip,
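Both export lists above gain without_field. For context, here is a minimal usage sketch (not part of this commit), assuming the new export mirrors ak.without_field and lazily drops a named field from each record; the data and field name below are invented for illustration.

# Hedged sketch: assumes dak.without_field(array, field) removes the named
# field from each record and stays lazy until compute() is called.
import awkward as ak
import dask_awkward as dak

records = ak.Array([{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}])
arr = dak.from_awkward(records, npartitions=1)

trimmed = dak.without_field(arr, "y")  # builds a task graph; nothing computed yet
print(trimmed.fields)                  # expected: ["x"]
print(trimmed.compute().tolist())      # expected: [{"x": 1}, {"x": 2}]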
70 changes: 14 additions & 56 deletions src/dask_awkward/lib/core.py
@@ -27,7 +27,6 @@
TypeTracerArray,
create_unknown_scalar,
is_unknown_scalar,
touch_data,
)
from dask.base import (
DaskMethodsMixin,
@@ -48,7 +47,6 @@

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer
from dask_awkward.lib.optimize import all_optimizations
from dask_awkward.lib.utils import commit_to_reports
from dask_awkward.utils import (
DaskAwkwardNotImplemented,
IncompatiblePartitions,
@@ -400,10 +398,6 @@ def name(self) -> str:
def key(self) -> Key:
return (self._name, 0)

@property
def report(self):
return getattr(self._meta, "_report", set())

def _check_meta(self, m):
if isinstance(m, MaybeNone):
return ak.Array(m.content)
@@ -524,7 +518,6 @@ def f(self, other):
meta = op(self._meta, other._meta)
else:
meta = op(self._meta, other)
commit_to_reports(name, self.report)
return new_scalar_object(graph, name, meta=meta)

return f
@@ -720,9 +713,7 @@ def _check_meta(self, m: Any | None) -> Any | None:
def __getitem__(self, where):
token = tokenize(self, where)
new_name = f"{where}-{token}"
report = self.report
new_meta = self._meta[where]
commit_to_reports(new_name, report)

# first check for array type return
if isinstance(new_meta, ak.Array):
@@ -732,8 +723,6 @@ def __getitem__(self, where):
graphlayer,
dependencies=[self],
)
new_meta._report = report
hlg.layers[new_name].meta = new_meta
return new_array_object(hlg, new_name, meta=new_meta, npartitions=1)

# then check for scalar (or record) type
@@ -744,8 +733,6 @@ def __getitem__(self, where):
dependencies=[self],
)
if isinstance(new_meta, ak.Record):
new_meta._report = report
hlg.layers[new_name].meta = new_meta
return new_record_object(hlg, new_name, meta=new_meta)
else:
return new_scalar_object(hlg, new_name, meta=new_meta)
@@ -819,7 +806,7 @@ def new_record_object(dsk: HighLevelGraph, name: str, *, meta: Any) -> Record:
raise TypeError(
f"meta Record must have a typetracer backend, not {ak.backend(meta)}"
)
return out
return Record(dsk, name, meta)


def _is_numpy_or_cupy_like(arr: Any) -> bool:
@@ -950,10 +937,6 @@ def reset_meta(self) -> None:
"""Assign an empty typetracer array as the collection metadata."""
self._meta = empty_typetracer()

@property
def report(self):
return getattr(self._meta, "_report", set())

def repartition(
self,
npartitions: int | None = None,
@@ -989,7 +972,6 @@ def repartition(
new_graph = HighLevelGraph.from_collections(
key, new_layer, dependencies=(self,)
)
commit_to_reports(key, self.report)
return new_array_object(
new_graph,
key,
@@ -1175,13 +1157,11 @@ def _partitions(self, index: Any) -> Array:
name = f"partitions-{token}"
new_keys = self.keys_array[index].tolist()
dsk = {(name, i): tuple(key) for i, key in enumerate(new_keys)}
layer = AwkwardMaterializedLayer(dsk, previous_layer_names=[self.name])
graph = HighLevelGraph.from_collections(
name,
layer,
AwkwardMaterializedLayer(dsk, previous_layer_names=[self.name]),
dependencies=(self,),
)
layer.meta = self._meta

# if a single partition was requested we trivially know the new divisions.
if len(raw) == 1 and isinstance(raw[0], int) and self.known_divisions:
@@ -1193,7 +1173,7 @@ def _partitions(self, index: Any) -> Array:
# otherwise nullify the known divisions
else:
new_divisions = (None,) * (len(new_keys) + 1) # type: ignore
commit_to_reports(name, self.report)

return new_array_object(
graph, name, meta=self._meta, divisions=tuple(new_divisions)
)
@@ -1415,7 +1395,6 @@ def _getitem_slice_on_zero(self, where):
AwkwardMaterializedLayer(dask, previous_layer_names=[self.name]),
dependencies=[self],
)
commit_to_reports(name, self.report)
return new_array_object(
hlg,
name,
@@ -1526,14 +1505,9 @@ def __getitem__(self, where):
raise RuntimeError("Lists containing integers are not supported.")

if isinstance(where, tuple):
out = self._getitem_tuple(where)
else:
out = self._getitem_single(where)
if self.report:
commit_to_reports(out.name, self.report)
out._meta._report = self._meta._report
out.dask.layers[out.name].meta = out._meta
return out
return self._getitem_tuple(where)

return self._getitem_single(where)

def _is_method_heuristic(self, resolved: Any) -> bool:
return callable(resolved)
@@ -1860,12 +1834,10 @@ def partitionwise_layer(
"""
pairs: list[Any] = []
numblocks: dict[str, tuple[int, ...]] = {}
reps = set()
for arg in args:
if isinstance(arg, Array):
pairs.extend([arg.name, "i"])
numblocks[arg.name] = (arg.npartitions,)
reps.update(arg.report)
elif isinstance(arg, BlockwiseDep):
if len(arg.numblocks) == 1:
pairs.extend([arg, "i"])
@@ -1885,8 +1857,6 @@
)
else:
pairs.extend([arg, None])
commit_to_reports(name, reps)

layer = dask_blockwise(
func,
name,
@@ -1970,23 +1940,8 @@ def _map_partitions(
**kwargs,
)

reps = set()
try:
if meta is None:
meta = map_meta(fn, *args, **kwargs)
else:
# To do any touching??
map_meta(fn, *args, **kwargs)
meta._report = reps
lay.meta = meta
except (AssertionError, TypeError, NotImplementedError):
[touch_data(_._meta) for _ in dak_arrays]

for dep in dak_arrays:
for rep in dep.report:
if rep not in reps:
rep.commit(name)
reps.add(rep)
if meta is None:
meta = map_meta(fn, *args, **kwargs)

hlg = HighLevelGraph.from_collections(
name,
@@ -2009,6 +1964,7 @@ def _map_partitions(
new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))
else:
new_divisions = in_divisions

if output_divisions is not None:
return new_array_object(
hlg,
@@ -2239,6 +2195,10 @@ def non_trivial_reduction(
if combiner is None:
combiner = reducer

# is_positional == True is not implemented
# if is_positional:
# assert combiner is reducer

# For `axis=None`, we prepare each array to have the following structure:
# [[[ ... [x1 x2 x3 ... xN] ... ]]] (length-1 outer lists)
# This makes the subsequent reductions an `axis=-1` reduction
Expand Down Expand Up @@ -2313,16 +2273,14 @@ def non_trivial_reduction(
)

graph = HighLevelGraph.from_collections(name_finalize, trl, dependencies=(chunked,))

meta = reducer(
array._meta,
axis=axis,
keepdims=keepdims,
mask_identity=mask_identity,
)
trl.meta = meta
commit_to_reports(name_finalize, array.report)
if isinstance(meta, ak.highlevel.Array):
meta._report = array.report
return new_array_object(graph, name_finalize, meta=meta, npartitions=1)
else:
return new_scalar_object(graph, name_finalize, meta=meta)

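The core.py hunks above touch _map_partitions and non_trivial_reduction, which sit behind the public map_partitions helper and reducers such as dak.sum. Here is a minimal usage sketch of those entry points (not part of this commit), assuming current dask-awkward behavior; the input data is invented for illustration.

# Hedged sketch: dak.map_partitions applies an awkward function per partition
# (meta is inferred from the typetracer when not supplied), and an axis=None
# reduction goes through non_trivial_reduction and yields a lazy Scalar.
import awkward as ak
import dask_awkward as dak

arr = dak.from_awkward(ak.Array([[1, 2, 3], [], [4, 5]]), npartitions=2)

counts = dak.map_partitions(ak.num, arr, axis=1)  # dak.num also exists; this just shows the helper
total = dak.sum(arr, axis=None)                   # lazy Scalar until computed

print(counts.compute().tolist())  # expected: [3, 0, 2]
print(total.compute())            # expected: 15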