From 0bc9205ed949dcbff593c7e2338c209821174d14 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:26:14 +0200 Subject: [PATCH] Ensure 1 task group per from_delayed (#1084) --- dask_expr/_collection.py | 3 +-- dask_expr/_expr.py | 30 +++++++++++++++++++++++++++ dask_expr/io/_delayed.py | 33 +++++++++++++++--------------- dask_expr/io/tests/test_delayed.py | 1 + 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index 0383519b..bf1dec73 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -767,12 +767,11 @@ def isin(self, values): # Numpy 1.23 supports creating arrays of iterables, while lower # version 1.21.x and 1.22.x do not pass - from dask_expr.io._delayed import _DelayedExpr return new_collection( expr.Isin( self, - values=_DelayedExpr( + values=expr._DelayedExpr( delayed(values, name="delayed-" + _tokenize_deterministic(values)) ), ) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 318ec587..c0366432 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -2970,6 +2970,36 @@ def ndim(self): return 0 +class DelayedsExpr(Expr): + _parameters = [] + + def __init__(self, *delayed_objects): + self.operands = delayed_objects + + def __str__(self): + return f"{type(self).__name__}({str(self.operands[0])})" + + @property + def _name(self): + return "delayed-container-" + _tokenize_deterministic(*self.operands) + + def _layer(self) -> dict: + dask = {} + for i, obj in enumerate(self.operands): + dc = obj.__dask_optimize__(obj.dask, obj.key).to_dict().copy() + dc[(self._name, i)] = dc[obj.key] + dc.pop(obj.key) + dask.update(dc) + return dask + + def _divisions(self): + return (None,) * (len(self.operands) + 1) + + @property + def ndim(self): + return 0 + + @normalize_token.register(Expr) def normalize_expression(expr): return expr._name diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py index fe955a97..2ed80f11 100644 --- a/dask_expr/io/_delayed.py +++ b/dask_expr/io/_delayed.py @@ -8,7 +8,7 @@ from dask.dataframe.utils import check_meta from dask.delayed import Delayed, delayed -from dask_expr._expr import PartitionsFiltered, _DelayedExpr +from dask_expr._expr import DelayedsExpr, PartitionsFiltered from dask_expr._util import _tokenize_deterministic from dask_expr.io import BlockwiseIO @@ -17,7 +17,14 @@ class FromDelayed(PartitionsFiltered, BlockwiseIO): - _parameters = ["meta", "user_divisions", "verify_meta", "_partitions", "prefix"] + _parameters = [ + "delayed_container", + "meta", + "user_divisions", + "verify_meta", + "_partitions", + "prefix", + ] _defaults = { "meta": None, "_partitions": None, @@ -32,35 +39,27 @@ def _name(self): return super()._name return self.prefix + "-" + _tokenize_deterministic(*self.operands) - def dependencies(self): - return self.dfs - - @functools.cached_property - def dfs(self): - return self.operands[len(self._parameters) :] - @functools.cached_property def _meta(self): if self.operand("meta") is not None: return self.operand("meta") - return delayed(make_meta)(self.dfs[0]).compute() + return delayed(make_meta)(self.delayed_container.operands[0]).compute() def _divisions(self): if self.operand("user_divisions") is not None: return self.operand("user_divisions") else: - return (None,) * (len(self.dfs) + 1) + return self.delayed_container.divisions def _filtered_task(self, index: int): - key = self.dfs[index]._name if self.verify_meta: return ( functools.partial(check_meta, meta=self._meta, funcname="from_delayed"), - (key, 0), + (self.delayed_container._name, index), ) else: - return identity, (key, 0) + return identity, (self.delayed_container._name, index) def identity(x): @@ -122,10 +121,10 @@ def from_delayed( if not isinstance(item, Delayed): raise TypeError("Expected Delayed object, got %s" % type(item).__name__) - dfs = [_DelayedExpr(df) for df in dfs] - from dask_expr._collection import new_collection return new_collection( - FromDelayed(make_meta(meta), divisions, verify_meta, None, prefix, *dfs) + FromDelayed( + DelayedsExpr(*dfs), make_meta(meta), divisions, verify_meta, None, prefix + ) ) diff --git a/dask_expr/io/tests/test_delayed.py b/dask_expr/io/tests/test_delayed.py index 61f8397e..6e0f3230 100644 --- a/dask_expr/io/tests/test_delayed.py +++ b/dask_expr/io/tests/test_delayed.py @@ -26,6 +26,7 @@ def test_from_delayed(prefix): df = from_delayed(dfs, meta=pdf.head(0), divisions=None, prefix=prefix) assert_eq(df, pdf) + assert len({k[0] for k in df.optimize().dask}) == 2 if prefix: assert df._name.startswith(prefix)