Skip to content

Commit

Permalink
rebase stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Jun 21, 2024
1 parent e399f40 commit 88b9d6c
Showing 1 changed file with 57 additions and 16 deletions.
73 changes: 57 additions & 16 deletions dask_expr/array/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import functools
import operator
from itertools import product
from typing import Union

import dask.array as da
import numpy as np
from dask import istask
from dask.array.core import slices_from_chunks
from dask.base import DaskMethodsMixin, named_schedulers
from dask.utils import cached_cumsum, cached_property
from dask.core import flatten
from dask.utils import SerializableLock, cached_cumsum, cached_property, key_split
from toolz import reduce

from dask_expr import _core as core
from dask_expr._util import _tokenize_deterministic

T_IntOrNaN = Union[int, float] # Should be Union[int, Literal[np.nan]]

Expand All @@ -24,7 +30,13 @@ def __dask_postcompute__(self):
return da.core.finalize, ()

def __dask_postpersist__(self):
    """Return the ``(rebuild, args)`` pair used by dask after ``persist``.

    The expression is lowered completely first so that the meta, chunks,
    keys and name reported here match the graph that is actually
    persisted; the persisted results are then wrapped back up in a
    ``FromGraph`` expression.
    """
    state = self.lower_completely()
    return FromGraph, (
        state._meta,
        state.chunks,
        # Flatten the (possibly nested) key structure into a plain list
        # so FromGraph can re-alias each key explicitly.
        list(flatten(state.__dask_keys__())),
        # Strip the token from the name; FromGraph re-tokenizes
        # deterministically from its own operands.
        key_split(state._name),
    )

def compute(self, **kwargs):
    """Simplify the expression first, then compute it eagerly."""
    simplified = self.simplify()
    return DaskMethodsMixin.compute(simplified, **kwargs)
Expand Down Expand Up @@ -425,7 +437,7 @@ class IO(Array):


class FromArray(IO):
_parameters = ["array", "chunks"]
_parameters = ["array", "chunks", "lock"]

@property
def chunks(self):
Expand All @@ -438,36 +450,65 @@ def _meta(self):
return self.array[tuple(slice(0, 0) for _ in range(self.array.ndim))]

def _layer(self):
    """Materialize this expression as a low-level task graph (dict).

    In-memory NumPy arrays are sliced eagerly (or embedded whole when
    there is a single block) so each task holds only its own chunk;
    anything else — or a locked array — goes through
    ``graph_from_arraylike`` so the getter (and optional lock) is used,
    e.g. for h5py datasets.
    """
    lock = self.operand("lock")
    if lock is True:
        # A plain ``True`` means "serialize access with a default lock".
        lock = SerializableLock()

    is_ndarray = type(self.array) in (np.ndarray, np.ma.core.MaskedArray)
    is_single_block = all(len(c) == 1 for c in self.chunks)
    # Always use the getter for h5py etc. Not using isinstance(x, np.ndarray)
    # because np.matrix is a subclass of np.ndarray.
    if is_ndarray and not is_single_block and not lock:
        # eagerly slice numpy arrays to prevent memory blowup
        # GH5367, GH5601
        slices = slices_from_chunks(self.chunks)
        keys = product([self._name], *(range(len(bds)) for bds in self.chunks))
        values = [self.array[slc] for slc in slices]
        dsk = dict(zip(keys, values))
    elif is_ndarray and is_single_block:
        # No slicing needed
        dsk = {(self._name,) + (0,) * self.array.ndim: self.array}
    else:
        dsk = da.core.graph_from_arraylike(
            self.array, chunks=self.chunks, shape=self.array.shape, name=self._name
        )
    return dict(dsk)  # this comes as a legacy HLG for now

def __str__(self):
return "FromArray(...)"


class FromGraph(Array):
    """An Array expression backed by an already-materialized task graph.

    Used to rebuild an expression after ``persist``: ``layer`` holds the
    persisted tasks, ``keys`` lists the output keys of that graph, and
    ``name_prefix`` seeds a fresh deterministic name.
    """

    _parameters = ["layer", "_meta", "chunks", "keys", "name_prefix"]

    @property
    def _meta(self):
        return self.operand("_meta")

    @property
    def chunks(self):
        return self.operand("chunks")

    @functools.cached_property
    def _name(self):
        # Deterministic: tokenized from the operands, not random, so two
        # identical FromGraph expressions share a name.
        return (
            self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands)
        )

    def _layer(self):
        dsk = dict(self.operand("layer"))
        # The name may not actually match the layers name therefore rewrite this
        # using an alias
        for k in self.operand("keys"):
            if not isinstance(k, tuple):
                raise TypeError(f"Expected tuple, got {type(k)}")
            orig = dsk[k]
            if not istask(orig):
                # Plain data: move it under the new name directly.
                del dsk[k]
                dsk[(self._name, *k[1:])] = orig
            else:
                # A real task: keep it and point the new key at the old one.
                dsk[(self._name, *k[1:])] = k
        return dsk


def from_array(x, chunks="auto", lock=None):
    """Create an Array expression from an array-like object.

    Parameters
    ----------
    x : array-like
        Object to wrap (NumPy array, h5py dataset, ...).
    chunks : str, int, or tuple, default "auto"
        Chunking to apply to ``x``.
    lock : bool or Lock, optional
        Lock serializing access to ``x``; ``True`` selects a default
        serializable lock. Useful for backends that are not thread-safe.
    """
    return FromArray(x, chunks, lock=lock)


from dask_expr.array.blockwise import Transpose, elemwise

0 comments on commit 88b9d6c

Please sign in to comment.