diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py index e6885338..de989c8d 100644 --- a/dask_expr/array/core.py +++ b/dask_expr/array/core.py @@ -1,13 +1,19 @@ +import functools import operator +from itertools import product from typing import Union import dask.array as da import numpy as np +from dask import istask +from dask.array.core import slices_from_chunks from dask.base import DaskMethodsMixin, named_schedulers -from dask.utils import cached_cumsum, cached_property +from dask.core import flatten +from dask.utils import SerializableLock, cached_cumsum, cached_property, key_split from toolz import reduce from dask_expr import _core as core +from dask_expr._util import _tokenize_deterministic T_IntOrNaN = Union[int, float] # Should be Union[int, Literal[np.nan]] @@ -24,7 +30,13 @@ def __dask_postcompute__(self): return da.core.finalize, () def __dask_postpersist__(self): - return FromGraph, (self._meta, self.chunks, self._name) + state = self.lower_completely() + return FromGraph, ( + state._meta, + state.chunks, + list(flatten(state.__dask_keys__())), + key_split(state._name), + ) def compute(self, **kwargs): return DaskMethodsMixin.compute(self.simplify(), **kwargs) @@ -425,7 +437,7 @@ class IO(Array): class FromArray(IO): - _parameters = ["array", "chunks"] + _parameters = ["array", "chunks", "lock"] @property def chunks(self): @@ -438,9 +450,28 @@ def _meta(self): return self.array[tuple(slice(0, 0) for _ in range(self.array.ndim))] def _layer(self): - dsk = da.core.graph_from_arraylike( - self.array, chunks=self.chunks, shape=self.array.shape, name=self._name - ) + lock = self.operand("lock") + if lock is True: + lock = SerializableLock() + + is_ndarray = type(self.array) in (np.ndarray, np.ma.core.MaskedArray) + is_single_block = all(len(c) == 1 for c in self.chunks) + # Always use the getter for h5py etc. Not using isinstance(x, np.ndarray) + # because np.matrix is a subclass of np.ndarray. + if is_ndarray and not is_single_block and not lock: + # eagerly slice numpy arrays to prevent memory blowup + # GH5367, GH5601 + slices = slices_from_chunks(self.chunks) + keys = product([self._name], *(range(len(bds)) for bds in self.chunks)) + values = [self.array[slc] for slc in slices] + dsk = dict(zip(keys, values)) + elif is_ndarray and is_single_block: + # No slicing needed + dsk = {(self._name,) + (0,) * self.array.ndim: self.array} + else: + dsk = da.core.graph_from_arraylike( + self.array, chunks=self.chunks, shape=self.array.shape, name=self._name + ) return dict(dsk) # this comes as a legacy HLG for now def __str__(self): @@ -448,26 +479,36 @@ def __str__(self): class FromGraph(Array): - _parameters = ["layer", "_meta", "chunks", "_name"] + _parameters = ["layer", "_meta", "chunks", "keys", "name_prefix"] @property def _meta(self): return self.operand("_meta") - @property - def chunks(self): - return self.operand("chunks") - - @property + @functools.cached_property def _name(self): - return self.operand("_name") + return ( + self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands) + ) def _layer(self): - return dict(self.operand("layer")) + dsk = dict(self.operand("layer")) + # The name may not actually match the layers name therefore rewrite this + # using an alias + for k in self.operand("keys"): + if not isinstance(k, tuple): + raise TypeError(f"Expected tuple, got {type(k)}") + orig = dsk[k] + if not istask(orig): + del dsk[k] + dsk[(self._name, *k[1:])] = orig + else: + dsk[(self._name, *k[1:])] = k + return dsk -def from_array(x, chunks="auto"): - return FromArray(x, chunks) +def from_array(x, chunks="auto", lock=None): + return FromArray(x, chunks, lock=lock) from dask_expr.array.blockwise import Transpose, elemwise