Skip to content

Commit

Permalink
Ensure 1 task group per from_delayed (#1084)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Jun 21, 2024
1 parent 2593829 commit 0bc9205
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 19 deletions.
3 changes: 1 addition & 2 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,12 +767,11 @@ def isin(self, values):
# Numpy 1.23 supports creating arrays of iterables, while lower
# version 1.21.x and 1.22.x do not
pass
from dask_expr.io._delayed import _DelayedExpr

return new_collection(
expr.Isin(
self,
values=_DelayedExpr(
values=expr._DelayedExpr(
delayed(values, name="delayed-" + _tokenize_deterministic(values))
),
)
Expand Down
30 changes: 30 additions & 0 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2970,6 +2970,36 @@ def ndim(self):
return 0


class DelayedsExpr(Expr):
_parameters = []

def __init__(self, *delayed_objects):
self.operands = delayed_objects

def __str__(self):
return f"{type(self).__name__}({str(self.operands[0])})"

@property
def _name(self):
return "delayed-container-" + _tokenize_deterministic(*self.operands)

def _layer(self) -> dict:
dask = {}
for i, obj in enumerate(self.operands):
dc = obj.__dask_optimize__(obj.dask, obj.key).to_dict().copy()
dc[(self._name, i)] = dc[obj.key]
dc.pop(obj.key)
dask.update(dc)
return dask

def _divisions(self):
return (None,) * (len(self.operands) + 1)

@property
def ndim(self):
return 0


@normalize_token.register(Expr)
def normalize_expression(expr):
return expr._name
Expand Down
33 changes: 16 additions & 17 deletions dask_expr/io/_delayed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dask.dataframe.utils import check_meta
from dask.delayed import Delayed, delayed

from dask_expr._expr import PartitionsFiltered, _DelayedExpr
from dask_expr._expr import DelayedsExpr, PartitionsFiltered
from dask_expr._util import _tokenize_deterministic
from dask_expr.io import BlockwiseIO

Expand All @@ -17,7 +17,14 @@


class FromDelayed(PartitionsFiltered, BlockwiseIO):
_parameters = ["meta", "user_divisions", "verify_meta", "_partitions", "prefix"]
_parameters = [
"delayed_container",
"meta",
"user_divisions",
"verify_meta",
"_partitions",
"prefix",
]
_defaults = {
"meta": None,
"_partitions": None,
Expand All @@ -32,35 +39,27 @@ def _name(self):
return super()._name
return self.prefix + "-" + _tokenize_deterministic(*self.operands)

def dependencies(self):
return self.dfs

@functools.cached_property
def dfs(self):
return self.operands[len(self._parameters) :]

@functools.cached_property
def _meta(self):
if self.operand("meta") is not None:
return self.operand("meta")

return delayed(make_meta)(self.dfs[0]).compute()
return delayed(make_meta)(self.delayed_container.operands[0]).compute()

def _divisions(self):
if self.operand("user_divisions") is not None:
return self.operand("user_divisions")
else:
return (None,) * (len(self.dfs) + 1)
return self.delayed_container.divisions

def _filtered_task(self, index: int):
key = self.dfs[index]._name
if self.verify_meta:
return (
functools.partial(check_meta, meta=self._meta, funcname="from_delayed"),
(key, 0),
(self.delayed_container._name, index),
)
else:
return identity, (key, 0)
return identity, (self.delayed_container._name, index)


def identity(x):
Expand Down Expand Up @@ -122,10 +121,10 @@ def from_delayed(
if not isinstance(item, Delayed):
raise TypeError("Expected Delayed object, got %s" % type(item).__name__)

dfs = [_DelayedExpr(df) for df in dfs]

from dask_expr._collection import new_collection

return new_collection(
FromDelayed(make_meta(meta), divisions, verify_meta, None, prefix, *dfs)
FromDelayed(
DelayedsExpr(*dfs), make_meta(meta), divisions, verify_meta, None, prefix
)
)
1 change: 1 addition & 0 deletions dask_expr/io/tests/test_delayed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_from_delayed(prefix):

df = from_delayed(dfs, meta=pdf.head(0), divisions=None, prefix=prefix)
assert_eq(df, pdf)
assert len({k[0] for k in df.optimize().dask}) == 2
if prefix:
assert df._name.startswith(prefix)

Expand Down

0 comments on commit 0bc9205

Please sign in to comment.