From fae5c6eada640c1cb109af33a2e6cb10d5a5d568 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 15 Feb 2024 23:24:17 +0100 Subject: [PATCH 01/32] Implement branch_id to limit reuse --- dask_expr/_concat.py | 2 +- dask_expr/_core.py | 34 ++++++++++++++++++--- dask_expr/_cumulative.py | 2 +- dask_expr/_expr.py | 47 ++++++++++++++++++------------ dask_expr/_groupby.py | 5 ++-- dask_expr/_reductions.py | 32 +++++++++++++++++++- dask_expr/_str_accessor.py | 2 +- dask_expr/io/_delayed.py | 2 +- dask_expr/tests/test_collection.py | 11 +++++-- dask_expr/tests/test_groupby.py | 2 +- 10 files changed, 105 insertions(+), 34 deletions(-) diff --git a/dask_expr/_concat.py b/dask_expr/_concat.py index f99503121..fcdb4997d 100644 --- a/dask_expr/_concat.py +++ b/dask_expr/_concat.py @@ -45,7 +45,7 @@ def __str__(self): + ", " + ", ".join( str(param) + "=" + str(operand) - for param, operand in zip(self._parameters, self.operands) + for param, operand in zip(self._parameters[:-1], self.operands) if operand != self._defaults.get(param) ) ) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 3aac656c4..8b4a07406 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -5,6 +5,7 @@ import weakref from collections import defaultdict from collections.abc import Generator +from typing import NamedTuple import dask import pandas as pd @@ -15,6 +16,10 @@ from dask_expr._util import _BackendData, _tokenize_deterministic +class BranchId(NamedTuple): + branch_id: int | None + + def _unpack_collections(o): if isinstance(o, Expr): return o @@ -30,8 +35,13 @@ class Expr: _defaults = {} _instances = weakref.WeakValueDictionary() - def __new__(cls, *args, **kwargs): + def __new__(cls, *args, _branch_id=None, **kwargs): operands = list(args) + if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId): + _branch_id = operands.pop(-1) + elif _branch_id is None: + _branch_id = BranchId(None) + for parameter in cls._parameters[len(operands) :]: try: operands.append(kwargs.pop(parameter)) @@ -39,7 +49,8 @@ def __new__(cls, *args, **kwargs): operands.append(cls._defaults[parameter]) assert not kwargs, kwargs inst = object.__new__(cls) - inst.operands = [_unpack_collections(o) for o in operands] + inst.operands = [_unpack_collections(o) for o in operands] + [_branch_id] + inst._parameters = cls._parameters + ["_branch_id"] _name = inst._name if _name in Expr._instances: return Expr._instances[_name] @@ -47,6 +58,10 @@ def __new__(cls, *args, **kwargs): Expr._instances[_name] = inst return inst + @functools.cached_property + def argument_operands(self): + return self.operands[:-1] + def _tune_down(self): return None @@ -144,7 +159,7 @@ def _depth(self): def operand(self, key): # Access an operand unambiguously # (e.g. 
if the key is reserved by a method/property) - return self.operands[type(self)._parameters.index(key)] + return self.operands[self._parameters.index(key)] def dependencies(self): # Dependencies are `Expr` operands only @@ -267,6 +282,11 @@ def rewrite(self, kind: str): return expr + def _push_branch_id(self, parent): + if self._branch_id.branch_id != parent._branch_id.branch_id: + result = type(self)(*self.operands[:-1], parent._branch_id) + return parent.substitute(self, result) + def simplify_once(self, dependents: defaultdict, simplified: dict): """Simplify an expression @@ -303,7 +323,9 @@ def simplify_once(self, dependents: defaultdict, simplified: dict): # Allow children to simplify their parents for child in expr.dependencies(): - out = child._simplify_up(expr, dependents) + out = child._push_branch_id(expr) + if out is None: + out = child._simplify_up(expr, dependents) if out is None: out = expr @@ -419,6 +441,10 @@ def _name(self): def _meta(self): raise NotImplementedError() + @functools.cached_property + def _branch_id(self): + return self.operands[-1] + def __getattr__(self, key): try: return object.__getattribute__(self, key) diff --git a/dask_expr/_cumulative.py b/dask_expr/_cumulative.py index 8d8b638ba..1aedd8bc9 100644 --- a/dask_expr/_cumulative.py +++ b/dask_expr/_cumulative.py @@ -47,7 +47,7 @@ def operation(self): @functools.cached_property def _args(self) -> list: - return self.operands[:-1] + return self.argument_operands[:-1] class TakeLast(Blockwise): diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 6d800eab6..670735ba5 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -466,7 +466,7 @@ def _kwargs(self) -> dict: if self._keyword_only: return { p: self.operand(p) - for p in self._parameters + for p in self._parameters[:-1] if p in self._keyword_only and self.operand(p) is not no_default } return {} @@ -475,10 +475,12 @@ def _kwargs(self) -> dict: def _args(self) -> list: if self._keyword_only: args = [ - self.operand(p) for p in self._parameters if p not in self._keyword_only - ] + self.operands[len(self._parameters) :] + self.operand(p) + for p in self._parameters[:-1] + if p not in self._keyword_only + ] + self.argument_operands[len(self._parameters) - 1 :] return args - return self.operands + return self.argument_operands def _broadcast_dep(self, dep: Expr): # Checks if a dependency should be broadcasted to @@ -562,7 +564,7 @@ def _broadcast_dep(self, dep: Expr): @property def args(self): - return [self.frame] + self.operands[len(self._parameters) :] + return [self.frame] + self.argument_operands[len(self._parameters) - 1 :] @functools.cached_property def _meta(self): @@ -658,7 +660,7 @@ def __str__(self): @functools.cached_property def args(self): - return self.operands[len(self._parameters) :] + return self.argument_operands[len(self._parameters) - 1 :] @functools.cached_property def _dfs(self): @@ -725,7 +727,7 @@ def _meta(self): meta = self.operand("meta") args = [self.frame._meta] + [ arg._meta if isinstance(arg, Expr) else arg - for arg in self.operands[len(self._parameters) :] + for arg in self.argument_operands[len(self._parameters) - 1 :] ] return _get_meta_map_partitions( args, @@ -737,11 +739,11 @@ def _meta(self): ) def _divisions(self): - args = [self.frame] + self.operands[len(self._parameters) :] + args = [self.frame] + self.argument_operands[len(self._parameters) - 1 :] return calc_divisions_for_align(*args) def _lower(self): - args = [self.frame] + self.operands[len(self._parameters) :] + args = [self.frame] + 
self.operands[len(self._parameters) - 1 :] args = maybe_align_partitions(*args, divisions=self._divisions()) return MapOverlap( args[0], @@ -792,7 +794,7 @@ def args(self): return ( [self.frame] + [self.func, self.before, self.after] - + self.operands[len(self._parameters) :] + + self.argument_operands[len(self._parameters) - 1 :] ) @functools.cached_property @@ -800,7 +802,7 @@ def _meta(self): meta = self.operand("meta") args = [self.frame._meta] + [ arg._meta if isinstance(arg, Expr) else arg - for arg in self.operands[len(self._parameters) :] + for arg in self.argument_operands[len(self._parameters) - 1 :] ] return _get_meta_map_partitions( args, @@ -1094,7 +1096,11 @@ class Sample(Blockwise): @functools.cached_property def _meta(self): - args = [self.operands[0]._meta] + [self.operands[1][0]] + self.operands[2:] + args = ( + [self.operands[0]._meta] + + [self.operands[1][0]] + + self.argument_operands[2:] + ) return self.operation(*args) def _task(self, index: int): @@ -1696,11 +1702,11 @@ class Assign(Elemwise): @functools.cached_property def keys(self): - return self.operands[1::2] + return self.argument_operands[1::2] @functools.cached_property def vals(self): - return self.operands[2::2] + return self.argument_operands[2::2] @functools.cached_property def _meta(self): @@ -1725,7 +1731,7 @@ def _simplify_down(self): if self._check_for_previously_created_column(self.frame): # don't squash if we are using a column that was previously created return - return Assign(*self.frame.operands, *self.operands[1:]) + return Assign(*self.frame.argument_operands, *self.operands[1:]) def _check_for_previously_created_column(self, child): input_columns = [] @@ -1753,6 +1759,7 @@ def _simplify_up(self, parent, dependents): for k, v in zip(self.keys, self.vals): if k in columns: new_args.extend([k, v]) + new_args.append(self._branch_id) else: new_args = self.operands[1:] @@ -1779,12 +1786,12 @@ class CaseWhen(Elemwise): @functools.cached_property def caselist(self): - c = self.operands[1:] + c = self.argument_operands[1:] return [(c[i], c[i + 1]) for i in range(0, len(c), 2)] @functools.cached_property def _meta(self): - c = self.operands[1:] + c = self.argument_operands[1:] caselist = [ ( meta_nonempty(c[i]._meta) if isinstance(c[i], Expr) else c[i], @@ -3308,7 +3315,7 @@ def __str__(self): @functools.cached_property def args(self): - return self.operands[len(self._parameters) :] + return self.argument_operands[len(self._parameters) - 1 :] @functools.cached_property def _dfs(self): @@ -3323,7 +3330,9 @@ def _meta(self): def _lower(self): args = maybe_align_partitions(*self.args, divisions=self._divisions()) dfs = [x for x in args if isinstance(x, Expr) and x.ndim > 0] - return UFuncElemwise(dfs[0], self.func, self._meta, False, self.kwargs, *args) + return UFuncElemwise( + dfs[0], self.func, self._meta, False, self.kwargs, *args, self._branch_id + ) class Fused(Blockwise): diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index faed278b3..9827092ce 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -106,7 +106,7 @@ def split_by(self): @functools.cached_property def by(self): - return self.operands[len(self._parameters) :] + return self.argument_operands[len(self._parameters) - 1 :] @functools.cached_property def levels(self): @@ -738,9 +738,10 @@ def _lower(self): self.operand(param) if param not in ("chunk_kwargs", "aggregate_kwargs") else {} - for param in self._parameters + for param in self._parameters[:-1] ], *self.by, + self._branch_id, ) if is_dataframe_like(s._meta): 
c = c[s.columns] diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 3b562cfc3..3993512ff 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -26,6 +26,7 @@ from dask.utils import M, apply, funcname from dask_expr._concat import Concat +from dask_expr._core import BranchId from dask_expr._expr import ( Blockwise, Expr, @@ -97,7 +98,7 @@ class Aggregate(Chunk): @functools.cached_property def aggregate_args(self): - return self.operands[len(self._parameters) :] + return self.argument_operands[len(self._parameters) - 1 :] @staticmethod def _call_with_list_arg(func, *args, **kwargs): @@ -507,6 +508,35 @@ def _lower(self): ignore_index=getattr(self, "ignore_index", True), ) + def _push_branch_id(self, parent): + return + + def _simplify_down(self): + if self._branch_id.branch_id is not None: + return + + seen = set() + stack = self.dependencies() + counter, found_io = 1, False + + while stack: + node = stack.pop() + + if node._name in seen: + continue + seen.add(node._name) + + if isinstance(node, ApplyConcatApply): + counter += 1 + continue + deps = node.dependencies() + if not deps: + found_io = True + stack.extend(deps) + if not found_io: + return + return type(self)(*self.operands[:-1], BranchId(counter)) + class Unique(ApplyConcatApply): _parameters = ["frame", "split_every", "split_out", "shuffle_method"] diff --git a/dask_expr/_str_accessor.py b/dask_expr/_str_accessor.py index 97748c991..dbf99b727 100644 --- a/dask_expr/_str_accessor.py +++ b/dask_expr/_str_accessor.py @@ -128,7 +128,7 @@ class CatBlockwise(Blockwise): @property def _args(self) -> list: - return [self.frame] + self.operands[len(self._parameters) :] + return [self.frame] + self.argument_operands[len(self._parameters) - 1 :] @staticmethod def operation(ser, *args, **kwargs): diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py index 9036edf69..26ee48140 100644 --- a/dask_expr/io/_delayed.py +++ b/dask_expr/io/_delayed.py @@ -30,7 +30,7 @@ def dependencies(self): @functools.cached_property def dfs(self): - return self.operands[len(self._parameters) :] + return self.argument_operands[len(self._parameters) - 1 :] @functools.cached_property def _meta(self): diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 3404f213d..9d8948ef2 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -942,7 +942,7 @@ def test_repr(df): s = (df["x"] + 1).sum(skipna=False).expr assert '["x"]' in str(s) or "['x']" in str(s) assert "+ 1" in str(s) - assert "sum(skipna=False)" in str(s) + assert "sum(skipna=False" in str(s) @xfail_gpu("combine_first not supported by cudf") @@ -1924,7 +1924,9 @@ def test_assign_squash_together(df, pdf): pdf["a"] = 1 pdf["b"] = 2 assert_eq(df, pdf) - assert "Assign: a=1, b=2" in [line for line in result.expr._tree_repr_lines()] + assert "Assign: a=1, b=2, ranchId(branch_id=0=" in [ + line for line in result.expr._tree_repr_lines() + ] def test_are_co_aligned(pdf, df): @@ -2412,7 +2414,10 @@ def test_filter_optimize_condition(): def test_scalar_repr(df): result = repr(df.size) - assert result == "" + assert ( + result + == "" + ) def test_reset_index_filter_pushdown(df): diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py index 44e7236fc..c8e78afaf 100644 --- a/dask_expr/tests/test_groupby.py +++ b/dask_expr/tests/test_groupby.py @@ -34,7 +34,7 @@ def test_groupby_unsupported_by(pdf, df): @pytest.mark.parametrize("split_every", [None, 5]) @pytest.mark.parametrize( 
"api", - ["sum", "mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"], + ["mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"], ) @pytest.mark.parametrize( "numeric_only", From d598734b317b87bd83c3d56b2148743e00f67f0d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 00:23:35 +0100 Subject: [PATCH 02/32] Update --- dask_expr/_concat.py | 2 +- dask_expr/_core.py | 11 +++++------ dask_expr/_expr.py | 28 +++++++++++++--------------- dask_expr/_groupby.py | 4 ++-- dask_expr/_reductions.py | 2 +- dask_expr/_str_accessor.py | 2 +- dask_expr/io/_delayed.py | 2 +- dask_expr/tests/test_groupby.py | 15 ++++++++------- 8 files changed, 32 insertions(+), 34 deletions(-) diff --git a/dask_expr/_concat.py b/dask_expr/_concat.py index fcdb4997d..f99503121 100644 --- a/dask_expr/_concat.py +++ b/dask_expr/_concat.py @@ -45,7 +45,7 @@ def __str__(self): + ", " + ", ".join( str(param) + "=" + str(operand) - for param, operand in zip(self._parameters[:-1], self.operands) + for param, operand in zip(self._parameters, self.operands) if operand != self._defaults.get(param) ) ) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 8b4a07406..756b3098c 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -50,7 +50,6 @@ def __new__(cls, *args, _branch_id=None, **kwargs): assert not kwargs, kwargs inst = object.__new__(cls) inst.operands = [_unpack_collections(o) for o in operands] + [_branch_id] - inst._parameters = cls._parameters + ["_branch_id"] _name = inst._name if _name in Expr._instances: return Expr._instances[_name] @@ -62,6 +61,10 @@ def __new__(cls, *args, _branch_id=None, **kwargs): def argument_operands(self): return self.operands[:-1] + @functools.cached_property + def _branch_id(self): + return self.operands[-1] + def _tune_down(self): return None @@ -159,7 +162,7 @@ def _depth(self): def operand(self, key): # Access an operand unambiguously # (e.g. 
if the key is reserved by a method/property) - return self.operands[self._parameters.index(key)] + return self.operands[type(self)._parameters.index(key)] def dependencies(self): # Dependencies are `Expr` operands only @@ -441,10 +444,6 @@ def _name(self): def _meta(self): raise NotImplementedError() - @functools.cached_property - def _branch_id(self): - return self.operands[-1] - def __getattr__(self, key): try: return object.__getattribute__(self, key) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 670735ba5..2d7e4afb1 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -466,7 +466,7 @@ def _kwargs(self) -> dict: if self._keyword_only: return { p: self.operand(p) - for p in self._parameters[:-1] + for p in self._parameters if p in self._keyword_only and self.operand(p) is not no_default } return {} @@ -475,10 +475,8 @@ def _kwargs(self) -> dict: def _args(self) -> list: if self._keyword_only: args = [ - self.operand(p) - for p in self._parameters[:-1] - if p not in self._keyword_only - ] + self.argument_operands[len(self._parameters) - 1 :] + self.operand(p) for p in self._parameters if p not in self._keyword_only + ] + self.argument_operands[len(self._parameters) :] return args return self.argument_operands @@ -564,7 +562,7 @@ def _broadcast_dep(self, dep: Expr): @property def args(self): - return [self.frame] + self.argument_operands[len(self._parameters) - 1 :] + return [self.frame] + self.argument_operands[len(self._parameters) :] @functools.cached_property def _meta(self): @@ -660,7 +658,7 @@ def __str__(self): @functools.cached_property def args(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @functools.cached_property def _dfs(self): @@ -727,7 +725,7 @@ def _meta(self): meta = self.operand("meta") args = [self.frame._meta] + [ arg._meta if isinstance(arg, Expr) else arg - for arg in self.argument_operands[len(self._parameters) - 1 :] + for arg in self.argument_operands[len(self._parameters) :] ] return _get_meta_map_partitions( args, @@ -739,11 +737,11 @@ def _meta(self): ) def _divisions(self): - args = [self.frame] + self.argument_operands[len(self._parameters) - 1 :] + args = [self.frame] + self.argument_operands[len(self._parameters) :] return calc_divisions_for_align(*args) def _lower(self): - args = [self.frame] + self.operands[len(self._parameters) - 1 :] + args = [self.frame] + self.operands[len(self._parameters) :] args = maybe_align_partitions(*args, divisions=self._divisions()) return MapOverlap( args[0], @@ -794,7 +792,7 @@ def args(self): return ( [self.frame] + [self.func, self.before, self.after] - + self.argument_operands[len(self._parameters) - 1 :] + + self.argument_operands[len(self._parameters) :] ) @functools.cached_property @@ -802,7 +800,7 @@ def _meta(self): meta = self.operand("meta") args = [self.frame._meta] + [ arg._meta if isinstance(arg, Expr) else arg - for arg in self.argument_operands[len(self._parameters) - 1 :] + for arg in self.argument_operands[len(self._parameters) :] ] return _get_meta_map_partitions( args, @@ -1908,7 +1906,7 @@ def _simplify_down(self): and self._meta.ndim == self.frame._meta.ndim ): # TODO: we should get more precise around Expr.columns types - return self.frame + return type(self.frame)(*self.frame.argument_operands, self._branch_id) if isinstance(self.frame, Projection): # df[a][b] a = self.frame.operand("columns") @@ -1922,7 +1920,7 @@ def _simplify_down(self): else: assert b in a - return self.frame.frame[b] + return 
type(self)(self.frame.frame, b, *self.operands[2:]) class Index(Elemwise): @@ -3315,7 +3313,7 @@ def __str__(self): @functools.cached_property def args(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @functools.cached_property def _dfs(self): diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index 9827092ce..2dc3b9440 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -106,7 +106,7 @@ def split_by(self): @functools.cached_property def by(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @functools.cached_property def levels(self): @@ -738,7 +738,7 @@ def _lower(self): self.operand(param) if param not in ("chunk_kwargs", "aggregate_kwargs") else {} - for param in self._parameters[:-1] + for param in self._parameters ], *self.by, self._branch_id, diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 3993512ff..629d88c82 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -98,7 +98,7 @@ class Aggregate(Chunk): @functools.cached_property def aggregate_args(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @staticmethod def _call_with_list_arg(func, *args, **kwargs): diff --git a/dask_expr/_str_accessor.py b/dask_expr/_str_accessor.py index dbf99b727..cae2a5436 100644 --- a/dask_expr/_str_accessor.py +++ b/dask_expr/_str_accessor.py @@ -128,7 +128,7 @@ class CatBlockwise(Blockwise): @property def _args(self) -> list: - return [self.frame] + self.argument_operands[len(self._parameters) - 1 :] + return [self.frame] + self.argument_operands[len(self._parameters) :] @staticmethod def operation(ser, *args, **kwargs): diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py index 26ee48140..70881b9c0 100644 --- a/dask_expr/io/_delayed.py +++ b/dask_expr/io/_delayed.py @@ -30,7 +30,7 @@ def dependencies(self): @functools.cached_property def dfs(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @functools.cached_property def _meta(self): diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py index c8e78afaf..62b0e1b81 100644 --- a/dask_expr/tests/test_groupby.py +++ b/dask_expr/tests/test_groupby.py @@ -34,7 +34,7 @@ def test_groupby_unsupported_by(pdf, df): @pytest.mark.parametrize("split_every", [None, 5]) @pytest.mark.parametrize( "api", - ["mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"], + ["sum", "mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"], ) @pytest.mark.parametrize( "numeric_only", @@ -95,18 +95,19 @@ def test_groupby_numeric(pdf, df, api, numeric_only, split_every): def test_groupby_reduction_optimize(pdf, df): df = df.replace(1, 5) - agg = df.groupby(df.x).y.sum() - expected_query = df[["x", "y"]] - expected_query = expected_query.groupby(expected_query.x).y.sum() - assert agg.optimize()._name == expected_query.optimize()._name - expect = pdf.replace(1, 5).groupby(["x"]).y.sum() - assert_eq(agg, expect) + # agg = df.groupby(df.x).y.sum() + # expected_query = df[["x", "y"]] + # expected_query = expected_query.groupby(expected_query.x).y.sum() + # assert agg.optimize()._name == expected_query.optimize()._name + # expect = pdf.replace(1, 5).groupby(["x"]).y.sum() + # assert_eq(agg, expect) df2 = df[["y"]] agg = df2.groupby(df.x).y.sum() ops = [ op for op in 
agg.expr.optimize(fuse=False).walk() if isinstance(op, FromPandas) ] + agg.simplify().pprint() assert len(ops) == 1 assert ops[0].columns == ["x", "y"] From d88270d20f41db6d77f0f76a13b2229df074e65c Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 01:57:44 +0100 Subject: [PATCH 03/32] Fix delayed --- dask_expr/_expr.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 15f510fc8..e268820b7 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -53,6 +53,7 @@ from tlz import merge_sorted, partition, unique from dask_expr import _core as core +from dask_expr._core import BranchId from dask_expr._util import ( _calc_maybe_new_divisions, _convert_to_list, @@ -2719,9 +2720,11 @@ class _DelayedExpr(Expr): # TODO _parameters = ["obj"] - def __init__(self, obj): + def __init__(self, obj, _branch_id=None): self.obj = obj - self.operands = [obj] + if _branch_id is None: + _branch_id = BranchId(None) + self.operands = [obj, _branch_id] def __str__(self): return f"{type(self).__name__}({str(self.obj)})" From 948cd83189f21b617f19814cdc2d106f1fc70537 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 14:38:52 +0100 Subject: [PATCH 04/32] Update --- dask_expr/_core.py | 35 +++++++++++++++++++++++------- dask_expr/_expr.py | 26 +++++++++++++++++----- dask_expr/_reductions.py | 16 ++++++++++---- dask_expr/io/parquet.py | 2 +- dask_expr/tests/test_collection.py | 4 +--- dask_expr/tests/test_shuffle.py | 4 ++-- 6 files changed, 63 insertions(+), 24 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 756b3098c..040714186 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -40,7 +40,7 @@ def __new__(cls, *args, _branch_id=None, **kwargs): if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId): _branch_id = operands.pop(-1) elif _branch_id is None: - _branch_id = BranchId(None) + _branch_id = BranchId(0) for parameter in cls._parameters[len(operands) :]: try: @@ -125,6 +125,10 @@ def _tree_repr_lines(self, indent=0, recursive=True): op = "" elif is_arraylike(op): op = "" + elif isinstance(op, BranchId): + if op.branch_id == 0: + continue + op = f" branch_id={op.branch_id}" header = self._tree_repr_argument_construction(i, op, header) lines = [header] + lines @@ -285,10 +289,27 @@ def rewrite(self, kind: str): return expr - def _push_branch_id(self, parent): - if self._branch_id.branch_id != parent._branch_id.branch_id: - result = type(self)(*self.operands[:-1], parent._branch_id) - return parent.substitute(self, result) + def _reuse_up(self, parent): + return + + def _reuse_down(self): + if not self.dependencies(): + return + return self._bubble_branch_id_down() + + def _bubble_branch_id_down(self): + b_id = self._branch_id + if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()): + ops = [ + op._substitute_branch_id(b_id) if isinstance(op, Expr) else op + for op in self.argument_operands + ] + return type(self)(*ops) + + def _substitute_branch_id(self, branch_id): + if self._branch_id.branch_id != 0: + return self + return type(self)(*self.argument_operands, branch_id) def simplify_once(self, dependents: defaultdict, simplified: dict): """Simplify an expression @@ -326,9 +347,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict): # Allow children to simplify their parents for child in expr.dependencies(): - out = 
child._push_branch_id(expr) - if out is None: - out = child._simplify_up(expr, dependents) + out = child._simplify_up(expr, dependents) if out is None: out = expr diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index e268820b7..10475f198 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -1757,9 +1757,8 @@ def _simplify_up(self, parent, dependents): for k, v in zip(self.keys, self.vals): if k in columns: new_args.extend([k, v]) - new_args.append(self._branch_id) else: - new_args = self.operands[1:] + new_args = self.argument_operands[1:] columns = [col for col in self.frame.columns if col in cols] return type(parent)( @@ -2723,7 +2722,7 @@ class _DelayedExpr(Expr): def __init__(self, obj, _branch_id=None): self.obj = obj if _branch_id is None: - _branch_id = BranchId(None) + _branch_id = BranchId(0) self.operands = [obj, _branch_id] def __str__(self): @@ -2752,7 +2751,9 @@ def normalize_expression(expr): return expr._name -def optimize(expr: Expr, fuse: bool = True) -> Expr: +def optimize( + expr: Expr, fuse: bool = True, common_subplan_elimination: bool = True +) -> Expr: """High level query optimization This leverages three optimization passes: @@ -2766,15 +2767,28 @@ def optimize(expr: Expr, fuse: bool = True) -> Expr: Input expression to optimize fuse: whether or not to turn on blockwise fusion + common_subplan_elimination : bool + whether we want to reuse common subplans that are found in the graph and + are used in self-joins or similar which require all data be held in memory + at some point. Only set this to false if your dataset fits into memory. See Also -------- simplify optimize_blockwise_fusion """ + result = expr + while True: + if common_subplan_elimination: + out = result.rewrite("reuse") + else: + out = result + out = out.simplify() + if out._name == result._name or not common_subplan_elimination: + break + result = out - # Simplify - result = expr.simplify() + result = out # Manipulate Expression to make it more efficient result = result.rewrite(kind="tune") diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index c699e903a..6a1ac5cf3 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -508,11 +508,14 @@ def _lower(self): ignore_index=getattr(self, "ignore_index", True), ) - def _push_branch_id(self, parent): + def _reuse_up(self, parent): return - def _simplify_down(self): - if self._branch_id.branch_id is not None: + def _substitute_branch_id(self, branch_id): + return self + + def _reuse_down(self): + if self._branch_id.branch_id != 0: return seen = set() @@ -535,7 +538,12 @@ def _simplify_down(self): stack.extend(deps) if not found_io: return - return type(self)(*self.operands[:-1], BranchId(counter)) + b_id = BranchId(counter) + result = type(self)(*self.argument_operands, b_id) + out = result._bubble_branch_id_down() + if out is None: + return result + return type(out)(*out.argument_operands, b_id) class Unique(ApplyConcatApply): diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 8cd9e8c3d..472f58ec7 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -454,7 +454,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): _absorb_projections = True def _tree_repr_argument_construction(self, i, op, header): - if self._parameters[i] == "_dataset_info_cache": + if i < len(self._parameters) and self._parameters[i] == "_dataset_info_cache": # Don't print this, very ugly return header return super()._tree_repr_argument_construction(i, op, header) diff --git 
a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 9d8948ef2..02024420d 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -1924,9 +1924,7 @@ def test_assign_squash_together(df, pdf): pdf["a"] = 1 pdf["b"] = 2 assert_eq(df, pdf) - assert "Assign: a=1, b=2, ranchId(branch_id=0=" in [ - line for line in result.expr._tree_repr_lines() - ] + assert "Assign: a=1, b=2" in [line for line in result.expr._tree_repr_lines()] def test_are_co_aligned(pdf, df): diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index c7b977284..45338c2c5 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -137,7 +137,7 @@ def test_shuffle_column_projection(df): def test_shuffle_reductions(df): - assert df.shuffle("x").sum().simplify()._name == df.sum()._name + assert df.shuffle("x").sum().simplify()._name == df.sum().simplify()._name @pytest.mark.xfail(reason="Shuffle can't see the reduction through the Projection") @@ -716,7 +716,7 @@ def test_sort_values_avoid_overeager_filter_pushdown(meth): pdf1 = pd.DataFrame({"a": [4, 2, 3], "b": [1, 2, 3]}) df = from_pandas(pdf1, npartitions=2) df = getattr(df, meth)("a") - df = df[df.b > 2] + df.b.sum() + df = df[df.b > 2] + df[df.b > 1] result = df.simplify() assert isinstance(result.expr.left, Filter) assert isinstance(result.expr.left.frame, BaseSetIndexSortValues) From 045bbef7f3fa53dda837f87105f59d91e71a658d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 16:43:03 +0100 Subject: [PATCH 05/32] Update --- dask_expr/tests/test_collection.py | 5 +---- dask_expr/tests/test_groupby.py | 12 ++++++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 02024420d..e60debc12 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -2412,10 +2412,7 @@ def test_filter_optimize_condition(): def test_scalar_repr(df): result = repr(df.size) - assert ( - result - == "" - ) + assert result == "" def test_reset_index_filter_pushdown(df): diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py index 62b0e1b81..be0d21df6 100644 --- a/dask_expr/tests/test_groupby.py +++ b/dask_expr/tests/test_groupby.py @@ -95,12 +95,12 @@ def test_groupby_numeric(pdf, df, api, numeric_only, split_every): def test_groupby_reduction_optimize(pdf, df): df = df.replace(1, 5) - # agg = df.groupby(df.x).y.sum() - # expected_query = df[["x", "y"]] - # expected_query = expected_query.groupby(expected_query.x).y.sum() - # assert agg.optimize()._name == expected_query.optimize()._name - # expect = pdf.replace(1, 5).groupby(["x"]).y.sum() - # assert_eq(agg, expect) + agg = df.groupby(df.x).y.sum() + expected_query = df[["x", "y"]] + expected_query = expected_query.groupby(expected_query.x).y.sum() + assert agg.optimize()._name == expected_query.optimize()._name + expect = pdf.replace(1, 5).groupby(["x"]).y.sum() + assert_eq(agg, expect) df2 = df[["y"]] agg = df2.groupby(df.x).y.sum() From 93e0d28d32f423c8c99c0975b7552a33aaf449a5 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 16:45:44 +0100 Subject: [PATCH 06/32] Add cache --- dask_expr/_core.py | 8 ++++++-- dask_expr/_expr.py | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 040714186..ba53e9953 
100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -225,7 +225,7 @@ def _layer(self) -> dict: return {(self._name, i): self._task(i) for i in range(self.npartitions)} - def rewrite(self, kind: str): + def rewrite(self, kind: str, cache): """Rewrite an expression This leverages the ``._{kind}_down`` and ``._{kind}_up`` @@ -238,6 +238,9 @@ def rewrite(self, kind: str): changed: whether or not any change occured """ + if self._name in cache: + return cache[self._name] + expr = self down_name = f"_{kind}_down" up_name = f"_{kind}_up" @@ -274,7 +277,8 @@ def rewrite(self, kind: str): changed = False for operand in expr.operands: if isinstance(operand, Expr): - new = operand.rewrite(kind=kind) + new = operand.rewrite(kind=kind, cache=cache) + cache[operand._name] = new if new._name != operand._name: changed = True else: diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 10475f198..ce8e4a42d 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -2780,7 +2780,7 @@ def optimize( result = expr while True: if common_subplan_elimination: - out = result.rewrite("reuse") + out = result.rewrite("reuse", cache={}) else: out = result out = out.simplify() @@ -2791,13 +2791,13 @@ def optimize( result = out # Manipulate Expression to make it more efficient - result = result.rewrite(kind="tune") + result = result.rewrite(kind="tune", cache={}) # Lower result = result.lower_completely() # Cull - result = result.rewrite(kind="cull") + result = result.rewrite(kind="cull", cache={}) # Final graph-specific optimizations if fuse: From 7ddda990b0e67ee91addaaf5a9e80f9168637bf8 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 16:57:33 +0100 Subject: [PATCH 07/32] Enhance tests --- dask_expr/tests/test_collection.py | 42 +++++++++++++++++++----------- dask_expr/tests/test_shuffle.py | 10 +++++-- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index e60debc12..1fa8b6114 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -507,7 +507,7 @@ def test_diff(pdf, df, axis, periods): if axis in ("columns", 1): assert actual._name == actual.simplify()._name else: - assert actual.simplify()._name == expected.simplify()._name + assert actual.optimize()._name == expected.optimize()._name @pytest.mark.parametrize( @@ -1163,8 +1163,8 @@ def test_tail_repartition(df): def test_projection_stacking(df): result = df[["x", "y"]]["x"] - optimized = result.simplify() - expected = df["x"].simplify() + optimized = result.optimize() + expected = df["x"].optimize() assert optimized._name == expected._name @@ -1877,8 +1877,8 @@ def test_assign_simplify(pdf): df = from_pandas(pdf) df2 = from_pandas(pdf) df["new"] = df.x > 1 - result = df[["x", "new"]].simplify() - expected = df2[["x"]].assign(new=df2.x > 1).simplify() + result = df[["x", "new"]].optimize() + expected = df2[["x"]].assign(new=df2.x > 1).optimize() assert result._name == expected._name pdf["new"] = pdf.x > 1 @@ -1889,8 +1889,8 @@ def test_assign_simplify_new_column_not_needed(pdf): df = from_pandas(pdf) df2 = from_pandas(pdf) df["new"] = df.x > 1 - result = df[["x"]].simplify() - expected = df2[["x"]].simplify() + result = df[["x"]].optimize() + expected = df2[["x"]].optimize() assert result._name == expected._name pdf["new"] = pdf.x > 1 @@ -1901,8 +1901,8 @@ def test_assign_simplify_series(pdf): df = from_pandas(pdf) df2 = from_pandas(pdf) df["new"] = df.x > 1 - result = 
df.new.simplify() - expected = df2[[]].assign(new=df2.x > 1).new.simplify() + result = df.new.optimize() + expected = df2[[]].assign(new=df2.x > 1).new.optimize() assert result._name == expected._name @@ -1920,7 +1920,16 @@ def test_assign_squash_together(df, pdf): df["a"] = 1 df["b"] = 2 result = df.simplify() - assert len([x for x in list(result.expr.walk()) if isinstance(x, expr.Assign)]) == 1 + assert ( + len( + [ + x + for x in list(df.optimize(fuse=False).expr.walk()) + if isinstance(x, expr.Assign) + ] + ) + == 1 + ) pdf["a"] = 1 pdf["b"] = 2 assert_eq(df, pdf) @@ -1970,10 +1979,10 @@ def test_astype_categories(df): assert_eq(result.y._meta.cat.categories, pd.Index([UNKNOWN_CATEGORIES])) -def test_drop_simplify(df): +def test_drop_optimize(df): q = df.drop(columns=["x"])[["y"]] - result = q.simplify() - expected = df[["y"]].simplify() + result = q.optimize() + expected = df[["y"]].optimize() assert result._name == expected._name @@ -2061,6 +2070,7 @@ def test_filter_pushdown_unavailable(df): result = df[df.x > 5] + df.x.sum() result = result[["x"]] expected = df[["x"]][df.x > 5] + df.x.sum() + assert result.optimize()._name == expected.optimize()._name assert result.simplify()._name == expected.simplify()._name @@ -2073,6 +2083,7 @@ def test_filter_pushdown(df, pdf): df = df.rename_axis(index="hello") result = df[df.x > 5].simplify() assert result._name == expected._name + assert result.optimize()._name == expected.optimize()._name pdf["z"] = 1 df = from_pandas(pdf, npartitions=10) @@ -2081,6 +2092,7 @@ def test_filter_pushdown(df, pdf): df_opt = df[["x", "y"]] expected = df_opt[df_opt.x > 5].rename_axis(index="hello").simplify() assert result._name == expected._name + assert result.optimize()._name == expected.optimize()._name def test_shape(df, pdf): @@ -2430,13 +2442,13 @@ def test_reset_index_filter_pushdown(df): result = q[q > 5] expected = df["x"] expected = expected[expected > 5].reset_index(drop=True) - assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name q = df.x.reset_index() result = q[q.x > 5] expected = df["x"] expected = expected[expected > 5].reset_index() - assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name def test_astype_filter_pushdown(df, pdf): diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 45338c2c5..3c6d6c6f3 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -137,7 +137,7 @@ def test_shuffle_column_projection(df): def test_shuffle_reductions(df): - assert df.shuffle("x").sum().simplify()._name == df.sum().simplify()._name + assert df.shuffle("x").sum().optimize()._name == df.sum().optimize()._name @pytest.mark.xfail(reason="Shuffle can't see the reduction through the Projection") @@ -264,7 +264,7 @@ def test_set_index_repartition(df, pdf): assert_eq(result, pdf.set_index("x")) -def test_set_index_simplify(df, pdf): +def test_set_index_optimize(df, pdf): q = df.set_index("x")["y"].optimize(fuse=False) expected = df[["x", "y"]].set_index("x")["y"].optimize(fuse=False) assert q._name == expected._name @@ -697,18 +697,21 @@ def test_shuffle_filter_pushdown(pdf, meth): result = result[result.x > 5.0] expected = getattr(df[df.x > 5.0], meth)("x") assert result.simplify()._name == expected._name + assert result.optimize()._name == expected.optimize()._name result = getattr(df, meth)("x") result = result[result.x > 5.0][["x", "y"]] expected = df[["x", "y"]] 
expected = getattr(expected[expected.x > 5.0], meth)("x") assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name result = getattr(df, meth)("x")[["x", "y"]] result = result[result.x > 5.0] expected = df[["x", "y"]] expected = getattr(expected[expected.x > 5.0], meth)("x") assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name @pytest.mark.parametrize("meth", ["set_index", "sort_values"]) @@ -729,15 +732,18 @@ def test_set_index_filter_pushdown(): result = result[result.y == 1] expected = df[df.y == 1].set_index("x") assert result.simplify()._name == expected._name + assert result.optimize()._name == expected.optimize()._name result = df.set_index("x") result = result[result.y == 1][["y"]] expected = df[["x", "y"]] expected = expected[expected.y == 1].set_index("x") assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name result = df.set_index("x")[["y"]] result = result[result.y == 1] expected = df[["x", "y"]] expected = expected[expected.y == 1].set_index("x") assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name From 8c2d97726e0ddda6f24001f825bda14698966351 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 17:25:54 +0100 Subject: [PATCH 08/32] Add tests --- dask_expr/tests/test_reuse.py | 82 +++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 dask_expr/tests/test_reuse.py diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py new file mode 100644 index 000000000..b6c2378f8 --- /dev/null +++ b/dask_expr/tests/test_reuse.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import pytest + +from dask_expr import from_pandas +from dask_expr.io import IO +from dask_expr.tests._util import _backend_library, assert_eq + +# Set DataFrame backend for this module +pd = _backend_library() + + +@pytest.fixture +def pdf(): + pdf = pd.DataFrame({"x": range(100), "a": 1, "b": 1, "c": 1}) + pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions + yield pdf + + +@pytest.fixture +def df(pdf): + yield from_pandas(pdf, npartitions=10) + + +def _check_io_nodes(expr, expected): + assert len(list(expr.find_operations(IO))) == expected + + +def test_reuse_everything_scalar_and_series(df, pdf): + df["new"] = 1 + df["new2"] = df["x"] + 1 + df["new3"] = df.x[df.x > 1] + df.x[df.x > 2] + + pdf["new"] = 1 + pdf["new2"] = pdf["x"] + 1 + pdf["new3"] = pdf.x[pdf.x > 1] + pdf.x[pdf.x > 2] + assert_eq(df, pdf) + _check_io_nodes(df.optimize(fuse=False), 1) + + +def test_dont_reuse_reducer(df, pdf): + result = df.replace(1, 5) + result["new"] = result.x + result.y.sum() + expected = pdf.replace(1, 5) + expected["new"] = expected.x + expected.y.sum() + assert_eq(result, expected) + _check_io_nodes(result.optimize(fuse=False), 2) + + result = df + df.sum() + expected = pdf + pdf.sum() + assert_eq(result, expected) + _check_io_nodes(result.optimize(fuse=False), 2) + + result = df.replace(1, 5) + rhs_1 = result.x + result.y.sum() + rhs_2 = result.b + result.a.sum() + result["new"] = rhs_1 + result["new2"] = rhs_2 + expected = pdf.replace(1, 5) + expected["new"] = expected.x + expected.y.sum() + expected["new2"] = expected.b + expected.a.sum() + assert_eq(result, expected) + result.optimize(fuse=False).pprint() + 
_check_io_nodes(result.optimize(fuse=False), 2) + + result = df.replace(1, 5) + result["new"] = result.x + result.y.sum() + result["new2"] = result.b + result.a.sum() + expected = pdf.replace(1, 5) + expected["new"] = expected.x + expected.y.sum() + expected["new2"] = expected.b + expected.a.sum() + assert_eq(result, expected) + result.optimize(fuse=False).pprint() + _check_io_nodes(result.optimize(fuse=False), 3) + + result = df.replace(1, 5) + result["new"] = result.x + result.sum().dropna().prod() + expected = pdf.replace(1, 5) + expected["new"] = expected.x + expected.sum().dropna().prod() + assert_eq(result, expected) + result.optimize(fuse=False).pprint() + _check_io_nodes(result.optimize(fuse=False), 2) From 7184bcfbf6c0215bedf6ac5764302a429025a6c4 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 19:10:50 +0100 Subject: [PATCH 09/32] Update --- dask_expr/_expr.py | 8 +++----- dask_expr/tests/test_reuse.py | 3 --- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index ce8e4a42d..c9e566140 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -1905,7 +1905,7 @@ def _simplify_down(self): and self._meta.ndim == self.frame._meta.ndim ): # TODO: we should get more precise around Expr.columns types - return type(self.frame)(*self.frame.argument_operands, self._branch_id) + return self.frame if isinstance(self.frame, Projection): # df[a][b] a = self.frame.operand("columns") @@ -1919,7 +1919,7 @@ def _simplify_down(self): else: assert b in a - return type(self)(self.frame.frame, b, *self.operands[2:]) + return self.frame.frame[b] class Index(Elemwise): @@ -3344,9 +3344,7 @@ def _meta(self): def _lower(self): args = maybe_align_partitions(*self.args, divisions=self._divisions()) dfs = [x for x in args if isinstance(x, Expr) and x.ndim > 0] - return UFuncElemwise( - dfs[0], self.func, self._meta, False, self.kwargs, *args, self._branch_id - ) + return UFuncElemwise(dfs[0], self.func, self._meta, False, self.kwargs, *args) class Fused(Blockwise): diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py index b6c2378f8..d1da88bca 100644 --- a/dask_expr/tests/test_reuse.py +++ b/dask_expr/tests/test_reuse.py @@ -60,7 +60,6 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - result.optimize(fuse=False).pprint() _check_io_nodes(result.optimize(fuse=False), 2) result = df.replace(1, 5) @@ -70,7 +69,6 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - result.optimize(fuse=False).pprint() _check_io_nodes(result.optimize(fuse=False), 3) result = df.replace(1, 5) @@ -78,5 +76,4 @@ def test_dont_reuse_reducer(df, pdf): expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.sum().dropna().prod() assert_eq(result, expected) - result.optimize(fuse=False).pprint() _check_io_nodes(result.optimize(fuse=False), 2) From 5ac93949dde3ddf3836e2bec391af795a9f65acf Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 19:18:34 +0100 Subject: [PATCH 10/32] Update --- dask_expr/tests/test_reuse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py index d1da88bca..850b6a493 100644 --- a/dask_expr/tests/test_reuse.py +++ 
b/dask_expr/tests/test_reuse.py @@ -48,6 +48,7 @@ def test_dont_reuse_reducer(df, pdf): result = df + df.sum() expected = pdf + pdf.sum() + result.optimize().pprint() assert_eq(result, expected) _check_io_nodes(result.optimize(fuse=False), 2) From fb2aa9f0df223682674d0c21d35ef5df1de4b47b Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 19:22:21 +0100 Subject: [PATCH 11/32] Update --- dask_expr/_reductions.py | 2 +- dask_expr/tests/test_reuse.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 6a1ac5cf3..db65ba5f9 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -347,7 +347,7 @@ def _layer(self): @property def _meta(self): - return self.operand("_meta") + return make_meta(self.operand("_meta")) def _divisions(self): return (None, None) diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py index 850b6a493..d1da88bca 100644 --- a/dask_expr/tests/test_reuse.py +++ b/dask_expr/tests/test_reuse.py @@ -48,7 +48,6 @@ def test_dont_reuse_reducer(df, pdf): result = df + df.sum() expected = pdf + pdf.sum() - result.optimize().pprint() assert_eq(result, expected) _check_io_nodes(result.optimize(fuse=False), 2) From 061de6f1ef136e25047628024d6a8f2cd6887be4 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 19:26:02 +0100 Subject: [PATCH 12/32] Update --- dask_expr/_reductions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index db65ba5f9..6a1ac5cf3 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -347,7 +347,7 @@ def _layer(self): @property def _meta(self): - return make_meta(self.operand("_meta")) + return self.operand("_meta") def _divisions(self): return (None, None) From e4865905c1c040489cce93d76cc8f32ea2d18326 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 19:29:37 +0100 Subject: [PATCH 13/32] Update --- dask_expr/tests/test_reuse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py index d1da88bca..b95bace1a 100644 --- a/dask_expr/tests/test_reuse.py +++ b/dask_expr/tests/test_reuse.py @@ -48,7 +48,7 @@ def test_dont_reuse_reducer(df, pdf): result = df + df.sum() expected = pdf + pdf.sum() - assert_eq(result, expected) + assert_eq(result, expected, check_names=False) # pandas 2.2 bug _check_io_nodes(result.optimize(fuse=False), 2) result = df.replace(1, 5) From d28f906f10d316c869a1326e3a75d51809bb84c4 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 19 Feb 2024 11:57:20 +0100 Subject: [PATCH 14/32] Update _core.py --- dask_expr/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index ba53e9953..2d82aae79 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -17,7 +17,7 @@ class BranchId(NamedTuple): - branch_id: int | None + branch_id: int def _unpack_collections(o): From 366415a7345ee530f45ee89ef982429dbcc24992 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 19 Feb 2024 15:05:47 +0100 Subject: [PATCH 15/32] Update test_groupby.py --- dask_expr/tests/test_groupby.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dask_expr/tests/test_groupby.py 
b/dask_expr/tests/test_groupby.py index be0d21df6..44e7236fc 100644 --- a/dask_expr/tests/test_groupby.py +++ b/dask_expr/tests/test_groupby.py @@ -107,7 +107,6 @@ def test_groupby_reduction_optimize(pdf, df): ops = [ op for op in agg.expr.optimize(fuse=False).walk() if isinstance(op, FromPandas) ] - agg.simplify().pprint() assert len(ops) == 1 assert ops[0].columns == ["x", "y"] From ee523eaf730a797c7d4bc33df31d9e0c5ef41e83 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 19 Feb 2024 15:11:15 +0100 Subject: [PATCH 16/32] Update --- dask_expr/_expr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index c9e566140..7ef0b8737 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -2752,7 +2752,7 @@ def normalize_expression(expr): def optimize( - expr: Expr, fuse: bool = True, common_subplan_elimination: bool = True + expr: Expr, fuse: bool = True, common_subplan_elimination: bool = False ) -> Expr: """High level query optimization @@ -2767,7 +2767,7 @@ def optimize( Input expression to optimize fuse: whether or not to turn on blockwise fusion - common_subplan_elimination : bool + common_subplan_elimination : bool, default False whether we want to reuse common subplans that are found in the graph and are used in self-joins or similar which require all data be held in memory at some point. Only set this to false if your dataset fits into memory. @@ -2779,12 +2779,12 @@ def optimize( """ result = expr while True: - if common_subplan_elimination: + if not common_subplan_elimination: out = result.rewrite("reuse", cache={}) else: out = result out = out.simplify() - if out._name == result._name or not common_subplan_elimination: + if out._name == result._name or common_subplan_elimination: break result = out From 369c142c38360fa9b1178f5af1a277e6dea768d0 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 19 Feb 2024 15:11:59 +0100 Subject: [PATCH 17/32] Update --- dask_expr/_expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 7ef0b8737..b80e7c378 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -2770,7 +2770,7 @@ def optimize( common_subplan_elimination : bool, default False whether we want to reuse common subplans that are found in the graph and are used in self-joins or similar which require all data be held in memory - at some point. Only set this to false if your dataset fits into memory. + at some point. Only set this to true if your dataset fits into memory. 
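        A minimal usage sketch of this keyword, assuming the module-level
        optimize defined in dask_expr._expr and a small from_pandas frame
        (the frame, its columns, and the fuse=False calls are illustrative
        placeholders, not taken from this patch):

            import pandas as pd
            from dask_expr import from_pandas
            from dask_expr._expr import optimize

            pdf = pd.DataFrame({"x": range(100), "y": range(100)})
            df = from_pandas(pdf, npartitions=10)

            # a query whose reduction branch shares a subplan with the filter branch
            query = df[df.x > 5] + df.x.sum()

            # default: the "reuse" rewrite assigns branch ids, so the shared
            # subplan is recomputed on each branch instead of being held in memory
            expr_default = optimize(query.expr, fuse=False)

            # opt in to reusing the common subplan when the data fits into memory
            expr_reused = optimize(query.expr, fuse=False, common_subplan_elimination=True)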
See Also -------- From 7379a01da6cdc65cf4e658525e9281498da4ccf6 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:27:58 +0100 Subject: [PATCH 18/32] Remove argument_operands --- dask_expr/_core.py | 19 +++++++----------- dask_expr/_cumulative.py | 2 +- dask_expr/_expr.py | 40 +++++++++++++++++--------------------- dask_expr/_groupby.py | 2 +- dask_expr/_reductions.py | 8 ++++---- dask_expr/_str_accessor.py | 2 +- dask_expr/io/_delayed.py | 2 +- dask_expr/io/io.py | 20 ++++++++++++++----- dask_expr/io/parquet.py | 2 +- 9 files changed, 49 insertions(+), 48 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index bfd7c2e1c..4a43926d4 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -49,7 +49,8 @@ def __new__(cls, *args, _branch_id=None, **kwargs): operands.append(cls._defaults[parameter]) assert not kwargs, kwargs inst = object.__new__(cls) - inst.operands = [_unpack_collections(o) for o in operands] + [_branch_id] + inst.operands = [_unpack_collections(o) for o in operands] + inst._branch_id = _branch_id _name = inst._name if _name in Expr._instances: return Expr._instances[_name] @@ -57,14 +58,6 @@ def __new__(cls, *args, _branch_id=None, **kwargs): Expr._instances[_name] = inst return inst - @functools.cached_property - def argument_operands(self): - return self.operands[:-1] - - @functools.cached_property - def _branch_id(self): - return self.operands[-1] - def _tune_down(self): return None @@ -300,14 +293,14 @@ def _bubble_branch_id_down(self): if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()): ops = [ op._substitute_branch_id(b_id) if isinstance(op, Expr) else op - for op in self.argument_operands + for op in self.operands ] return type(self)(*ops) def _substitute_branch_id(self, branch_id): if self._branch_id.branch_id != 0: return self - return type(self)(*self.argument_operands, branch_id) + return type(self)(*self.operands, branch_id) def simplify_once(self, dependents: defaultdict, simplified: dict): """Simplify an expression @@ -454,7 +447,9 @@ def _lower(self): @functools.cached_property def _name(self): return ( - funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands) + funcname(type(self)).lower() + + "-" + + _tokenize_deterministic(*self.operands, self._branch_id) ) @property diff --git a/dask_expr/_cumulative.py b/dask_expr/_cumulative.py index 1aedd8bc9..8d8b638ba 100644 --- a/dask_expr/_cumulative.py +++ b/dask_expr/_cumulative.py @@ -47,7 +47,7 @@ def operation(self): @functools.cached_property def _args(self) -> list: - return self.argument_operands[:-1] + return self.operands[:-1] class TakeLast(Blockwise): diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index dd594a287..94c09899c 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -477,9 +477,9 @@ def _args(self) -> list: if self._keyword_only: args = [ self.operand(p) for p in self._parameters if p not in self._keyword_only - ] + self.argument_operands[len(self._parameters) :] + ] + self.operands[len(self._parameters) :] return args - return self.argument_operands + return self.operands def _broadcast_dep(self, dep: Expr): # Checks if a dependency should be broadcasted to @@ -503,7 +503,7 @@ def _name(self): head = funcname(self.operation) else: head = funcname(type(self)).lower() - return head + "-" + _tokenize_deterministic(*self.operands) + return head + "-" + _tokenize_deterministic(*self.operands, self._branch_id) def _blockwise_arg(self, arg, i): """Return a 
Blockwise-task argument""" @@ -563,7 +563,7 @@ def _broadcast_dep(self, dep: Expr): @property def args(self): - return [self.frame] + self.argument_operands[len(self._parameters) :] + return [self.frame] + self.operands[len(self._parameters) :] @functools.cached_property def _meta(self): @@ -659,7 +659,7 @@ def __str__(self): @functools.cached_property def args(self): - return self.argument_operands[len(self._parameters) :] + return self.operands[len(self._parameters) :] @functools.cached_property def _dfs(self): @@ -726,7 +726,7 @@ def _meta(self): meta = self.operand("meta") args = [self.frame._meta] + [ arg._meta if isinstance(arg, Expr) else arg - for arg in self.argument_operands[len(self._parameters) :] + for arg in self.operands[len(self._parameters) :] ] return _get_meta_map_partitions( args, @@ -738,7 +738,7 @@ def _meta(self): ) def _divisions(self): - args = [self.frame] + self.argument_operands[len(self._parameters) :] + args = [self.frame] + self.operands[len(self._parameters) :] return calc_divisions_for_align(*args) def _lower(self): @@ -793,7 +793,7 @@ def args(self): return ( [self.frame] + [self.func, self.before, self.after] - + self.argument_operands[len(self._parameters) :] + + self.operands[len(self._parameters) :] ) @functools.cached_property @@ -801,7 +801,7 @@ def _meta(self): meta = self.operand("meta") args = [self.frame._meta] + [ arg._meta if isinstance(arg, Expr) else arg - for arg in self.argument_operands[len(self._parameters) :] + for arg in self.operands[len(self._parameters) :] ] return _get_meta_map_partitions( args, @@ -1095,11 +1095,7 @@ class Sample(Blockwise): @functools.cached_property def _meta(self): - args = ( - [self.operands[0]._meta] - + [self.operands[1][0]] - + self.argument_operands[2:] - ) + args = [self.operands[0]._meta] + [self.operands[1][0]] + self.operands[2:] return self.operation(*args) def _task(self, index: int): @@ -1701,11 +1697,11 @@ class Assign(Elemwise): @functools.cached_property def keys(self): - return self.argument_operands[1::2] + return self.operands[1::2] @functools.cached_property def vals(self): - return self.argument_operands[2::2] + return self.operands[2::2] @functools.cached_property def _meta(self): @@ -1730,7 +1726,7 @@ def _simplify_down(self): if self._check_for_previously_created_column(self.frame): # don't squash if we are using a column that was previously created return - return Assign(*self.frame.argument_operands, *self.operands[1:]) + return Assign(*self.frame.operands, *self.operands[1:]) def _check_for_previously_created_column(self, child): input_columns = [] @@ -1758,7 +1754,7 @@ def _simplify_up(self, parent, dependents): if k in columns: new_args.extend([k, v]) else: - new_args = self.argument_operands[1:] + new_args = self.operands[1:] columns = [col for col in self.frame.columns if col in cols] return type(parent)( @@ -1783,12 +1779,12 @@ class CaseWhen(Elemwise): @functools.cached_property def caselist(self): - c = self.argument_operands[1:] + c = self.operands[1:] return [(c[i], c[i + 1]) for i in range(0, len(c), 2)] @functools.cached_property def _meta(self): - c = self.argument_operands[1:] + c = self.operands[1:] caselist = [ ( meta_nonempty(c[i]._meta) if isinstance(c[i], Expr) else c[i], @@ -3332,7 +3328,7 @@ def __str__(self): @functools.cached_property def args(self): - return self.argument_operands[len(self._parameters) :] + return self.operands[len(self._parameters) :] @functools.cached_property def _dfs(self): @@ -3424,7 +3420,7 @@ def __str__(self): @functools.cached_property 
def _name(self): - return f"{str(self)}-{_tokenize_deterministic(self.exprs)}" + return f"{str(self)}-{_tokenize_deterministic(self.exprs, self._branch_id)}" def _divisions(self): return self.exprs[0]._divisions() diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index c6ca9433e..4027b49ff 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -106,7 +106,7 @@ def split_by(self): @functools.cached_property def by(self): - return self.argument_operands[len(self._parameters) :] + return self.operands[len(self._parameters) :] @functools.cached_property def levels(self): diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 6a1ac5cf3..1eef1eadb 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -98,7 +98,7 @@ class Aggregate(Chunk): @functools.cached_property def aggregate_args(self): - return self.argument_operands[len(self._parameters) :] + return self.operands[len(self._parameters) :] @staticmethod def _call_with_list_arg(func, *args, **kwargs): @@ -301,7 +301,7 @@ def _name(self): name = funcname(self.combine.__self__).lower() + "-tree" else: name = funcname(self.combine) - return name + "-" + _tokenize_deterministic(*self.operands) + return name + "-" + _tokenize_deterministic(*self.operands, self._branch_id) def __dask_postcompute__(self): return toolz.first, () @@ -539,11 +539,11 @@ def _reuse_down(self): if not found_io: return b_id = BranchId(counter) - result = type(self)(*self.argument_operands, b_id) + result = type(self)(*self.operands, b_id) out = result._bubble_branch_id_down() if out is None: return result - return type(out)(*out.argument_operands, b_id) + return type(out)(*out.operands, b_id) class Unique(ApplyConcatApply): diff --git a/dask_expr/_str_accessor.py b/dask_expr/_str_accessor.py index cae2a5436..97748c991 100644 --- a/dask_expr/_str_accessor.py +++ b/dask_expr/_str_accessor.py @@ -128,7 +128,7 @@ class CatBlockwise(Blockwise): @property def _args(self) -> list: - return [self.frame] + self.argument_operands[len(self._parameters) :] + return [self.frame] + self.operands[len(self._parameters) :] @staticmethod def operation(ser, *args, **kwargs): diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py index 70881b9c0..9036edf69 100644 --- a/dask_expr/io/_delayed.py +++ b/dask_expr/io/_delayed.py @@ -30,7 +30,7 @@ def dependencies(self): @functools.cached_property def dfs(self): - return self.argument_operands[len(self._parameters) :] + return self.operands[len(self._parameters) :] @functools.cached_property def _meta(self): diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 1b6b34fe7..307132d2c 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -48,7 +48,9 @@ def _divisions(self): @functools.cached_property def _name(self): return ( - self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands) + self.operand("name_prefix") + + "-" + + _tokenize_deterministic(*self.operands, self._branch_id) ) def _layer(self): @@ -103,7 +105,7 @@ def _name(self): return ( funcname(type(self.operand("_expr"))).lower() + "-fused-" - + _tokenize_deterministic(*self.operands) + + _tokenize_deterministic(*self.operands, self._branch_id) ) @functools.cached_property @@ -173,10 +175,14 @@ def _name(self): return ( funcname(self.func).lower() + "-" - + _tokenize_deterministic(*self.operands) + + _tokenize_deterministic(*self.operands, self._branch_id) ) else: - return self.label + "-" + _tokenize_deterministic(*self.operands) + return ( + self.label + + "-" + + 
_tokenize_deterministic(*self.operands, self._branch_id) + ) @functools.cached_property def _meta(self): @@ -448,7 +454,11 @@ class FromPandasDivisions(FromPandas): @functools.cached_property def _name(self): - return "from_pd_divs" + "-" + _tokenize_deterministic(*self.operands) + return ( + "from_pd_divs" + + "-" + + _tokenize_deterministic(*self.operands, self._branch_id) + ) @property def _divisions_and_locations(self): diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 472f58ec7..7724cc53a 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -501,7 +501,7 @@ def _name(self): return ( funcname(type(self)).lower() + "-" - + _tokenize_deterministic(self.checksum, *self.operands) + + _tokenize_deterministic(self.checksum, *self.operands, self._branch_id) ) @property From 68e048c49b3d59272c2b5f7bd50115e91186b1e5 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 20 Feb 2024 18:00:11 +0100 Subject: [PATCH 19/32] Update --- dask_expr/_core.py | 8 +++++--- dask_expr/_expr.py | 3 ++- dask_expr/_reductions.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 4a43926d4..2d5150cf8 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -273,7 +273,7 @@ def rewrite(self, kind: str, cache): new_operands.append(new) if changed: - expr = type(expr)(*new_operands) + expr = type(expr)(*new_operands, _branch_id=expr._branch_id) continue else: break @@ -290,6 +290,8 @@ def _reuse_down(self): def _bubble_branch_id_down(self): b_id = self._branch_id + if b_id.branch_id <= 0: + return if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()): ops = [ op._substitute_branch_id(b_id) if isinstance(op, Expr) else op @@ -366,7 +368,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict): new_operands.append(new) if changed: - expr = type(expr)(*new_operands) + expr = type(expr)(*new_operands, _branch_id=expr._branch_id) break @@ -411,7 +413,7 @@ def lower_once(self): new_operands.append(new) if changed: - out = type(out)(*new_operands) + out = type(out)(*new_operands, _branch_id=out._branch_id) return out diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 94c09899c..c01e5267f 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -2722,7 +2722,8 @@ def __init__(self, obj, _branch_id=None): self.obj = obj if _branch_id is None: _branch_id = BranchId(0) - self.operands = [obj, _branch_id] + self._branch_id = _branch_id + self.operands = [obj] def __str__(self): return f"{type(self).__name__}({str(self.obj)})" diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 1eef1eadb..ebd05e07b 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -543,7 +543,7 @@ def _reuse_down(self): out = result._bubble_branch_id_down() if out is None: return result - return type(out)(*out.operands, b_id) + return type(out)(*out.operands, _branch_id=b_id) class Unique(ApplyConcatApply): From f79155ac84d0c2b90cd08e1a4f288d02015dde92 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 20 Feb 2024 18:56:53 +0100 Subject: [PATCH 20/32] Update --- dask_expr/_core.py | 7 +++++-- dask_expr/io/io.py | 5 +++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 2d5150cf8..7983c756a 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -117,7 +117,10 @@ def _tree_repr_lines(self, indent=0, recursive=True): 
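The hunks above rebuild expressions with _branch_id=expr._branch_id during rewrite/simplify/lower and fold the branch id into the deterministic token behind _name. A minimal stand-alone sketch of that idea, using toy classes rather than the real Expr machinery:

    from __future__ import annotations

    import hashlib
    from typing import NamedTuple


    class ToyBranchId(NamedTuple):
        branch_id: int | None


    def toy_token(*parts) -> str:
        # Stand-in for _tokenize_deterministic: a stable hash of the operands.
        return hashlib.sha1(repr(parts).encode()).hexdigest()[:8]


    class ToyExpr:
        def __init__(self, *operands, _branch_id: ToyBranchId | None = None):
            self.operands = list(operands)
            self._branch_id = _branch_id or ToyBranchId(0)

        @property
        def _name(self) -> str:
            # Tokenizing the branch id together with the operands gives two
            # otherwise identical subexpressions on different branches
            # different graph keys, so they are no longer deduplicated.
            return f"toyexpr-{toy_token(*self.operands, self._branch_id)}"

        def rebuild(self, *operands) -> ToyExpr:
            # A rewrite that rebuilds a node must pass the branch id along,
            # mirroring `type(expr)(*new_operands, _branch_id=expr._branch_id)`.
            return type(self)(*operands, _branch_id=self._branch_id)


    a = ToyExpr("read-x", _branch_id=ToyBranchId(0))
    b = ToyExpr("read-x", _branch_id=ToyBranchId(1))
    assert a._name != b._name                        # same operands, split branches
    assert a.rebuild("read-x")._branch_id == a._branch_id
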
continue op = f" branch_id={op.branch_id}" header = self._tree_repr_argument_construction(i, op, header) - + if self._branch_id.branch_id != 0: + header = self._tree_repr_argument_construction( + i + 1, f" branch_id={self._branch_id.branch_id}", header + ) lines = [header] + lines lines = [" " * indent + line for line in lines] @@ -604,7 +607,7 @@ def substitute_parameters(self, substitutions: dict) -> Expr: else: new_operands.append(operand) if changed: - return type(self)(*new_operands) + return type(self)(*new_operands, _branch_id=self._branch_id) return self def _node_label_args(self): diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 307132d2c..84d5df59a 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -105,7 +105,7 @@ def _name(self): return ( funcname(type(self.operand("_expr"))).lower() + "-fused-" - + _tokenize_deterministic(*self.operands, self._branch_id) + + _tokenize_deterministic(*self.operands, self._expr._branch_id) ) @functools.cached_property @@ -416,7 +416,8 @@ def _simplify_up(self, parent, dependents): return Literal(sum(_lengths)) if isinstance(parent, Projection): - return super()._simplify_up(parent, dependents) + x = super()._simplify_up(parent, dependents) + return x def _divisions(self): return self._divisions_and_locations[0] From 4801a9322aa3653f8f4e911cfac244f03a38b60f Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 20 Feb 2024 20:37:01 +0100 Subject: [PATCH 21/32] Implement shuffles as consumer --- dask_expr/_core.py | 6 +++++- dask_expr/_expr.py | 2 +- dask_expr/_reductions.py | 14 ++++++++------ dask_expr/_shuffle.py | 18 ++++++++++++++++-- 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index ba53e9953..6ee03150b 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -463,6 +463,10 @@ def _name(self): funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands) ) + @functools.cached_property + def _dep_name(self): + return self._name + @property def _meta(self): raise NotImplementedError() @@ -774,5 +778,5 @@ def collect_dependents(expr) -> defaultdict: for dep in node.dependencies(): stack.append(dep) - dependents[dep._name].append(weakref.ref(node)) + dependents[dep._dep_name].append(weakref.ref(node)) return dependents diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index c9e566140..a2b0bcc8b 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -3472,7 +3472,7 @@ def determine_column_projection(expr, parent, dependents, additional_columns=Non column_union = [] else: column_union = parent.columns.copy() - parents = [x() for x in dependents[expr._name] if x() is not None] + parents = [x() for x in dependents[expr._dep_name] if x() is not None] seen = set() for p in parents: diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 6a1ac5cf3..f6d1bc7cd 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -508,9 +508,6 @@ def _lower(self): ignore_index=getattr(self, "ignore_index", True), ) - def _reuse_up(self, parent): - return - def _substitute_branch_id(self, branch_id): return self @@ -518,6 +515,9 @@ def _reuse_down(self): if self._branch_id.branch_id != 0: return + from dask_expr._shuffle import Shuffle + from dask_expr.io import IO + seen = set() stack = self.dependencies() counter, found_io = 1, False @@ -532,10 +532,12 @@ def _reuse_down(self): if isinstance(node, ApplyConcatApply): counter += 1 continue - deps = node.dependencies() - if not deps: + + if 
isinstance(node, (IO, Shuffle)): found_io = True - stack.extend(deps) + continue + stack.extend(node.dependencies()) + if not found_io: return b_id = BranchId(counter) diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py index 3e70c96fc..98334421e 100644 --- a/dask_expr/_shuffle.py +++ b/dask_expr/_shuffle.py @@ -22,6 +22,7 @@ from dask.utils import ( M, digit, + funcname, get_default_shuffle_method, insert, is_index_like, @@ -63,7 +64,7 @@ ValueCounts, ) from dask_expr._repartition import Repartition, RepartitionToFewer -from dask_expr._util import LRU, _convert_to_list +from dask_expr._util import LRU, _convert_to_list, _tokenize_deterministic class ShuffleBase(Expr): @@ -157,6 +158,17 @@ def _meta(self): def _divisions(self): return (None,) * (self.npartitions_out + 1) + def _reuse_down(self): + return + + @functools.cached_property + def _dep_name(self): + return ( + funcname(type(self)).lower() + + "-" + + _tokenize_deterministic(*self.argument_operands) + ) + class Shuffle(ShuffleBase): """Abstract shuffle class @@ -196,6 +208,7 @@ def _lower(self): self.npartitions_out, self.ignore_index, self.options, + self._branch_id, ] if method == "p2p": return P2PShuffle(frame, *ops) @@ -290,6 +303,7 @@ def _lower(self): ignore_index, self.method, options, + self._branch_id, ) # Drop "_partitions" column and return @@ -549,7 +563,7 @@ def _layer(self): ) dsk = {} - token = self._name.split("-")[-1] + token = self._dep_name.split("-")[-1] _barrier_key = barrier_key(ShuffleId(token)) name = "shuffle-transfer-" + token transfer_keys = list() From 9fcc24650ba6ce59d0782347f86602ed92fd2190 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 20 Feb 2024 20:39:55 +0100 Subject: [PATCH 22/32] Tighten test --- dask_expr/tests/test_reuse.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py index b95bace1a..6d6f35689 100644 --- a/dask_expr/tests/test_reuse.py +++ b/dask_expr/tests/test_reuse.py @@ -23,7 +23,10 @@ def df(pdf): def _check_io_nodes(expr, expected): - assert len(list(expr.find_operations(IO))) == expected + expr = expr.optimize(fuse=False) + io_nodes = list(expr.find_operations(IO)) + assert len(io_nodes) == expected + assert len({node._branch_id.branch_id for node in io_nodes}) == expected def test_reuse_everything_scalar_and_series(df, pdf): @@ -35,7 +38,7 @@ def test_reuse_everything_scalar_and_series(df, pdf): pdf["new2"] = pdf["x"] + 1 pdf["new3"] = pdf.x[pdf.x > 1] + pdf.x[pdf.x > 2] assert_eq(df, pdf) - _check_io_nodes(df.optimize(fuse=False), 1) + _check_io_nodes(df, 1) def test_dont_reuse_reducer(df, pdf): @@ -44,12 +47,12 @@ def test_dont_reuse_reducer(df, pdf): expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.y.sum() assert_eq(result, expected) - _check_io_nodes(result.optimize(fuse=False), 2) + _check_io_nodes(result, 2) result = df + df.sum() expected = pdf + pdf.sum() assert_eq(result, expected, check_names=False) # pandas 2.2 bug - _check_io_nodes(result.optimize(fuse=False), 2) + _check_io_nodes(result, 2) result = df.replace(1, 5) rhs_1 = result.x + result.y.sum() @@ -60,7 +63,7 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result.optimize(fuse=False), 2) + _check_io_nodes(result, 2) result = df.replace(1, 5) result["new"] = result.x + result.y.sum() @@ -69,11 +72,11 @@ 
def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result.optimize(fuse=False), 3) + _check_io_nodes(result, 3) result = df.replace(1, 5) result["new"] = result.x + result.sum().dropna().prod() expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.sum().dropna().prod() assert_eq(result, expected) - _check_io_nodes(result.optimize(fuse=False), 2) + _check_io_nodes(result, 2) From 391d8f647b5bcca0b6329e46adc4607bc4d1c6c2 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:22:53 +0100 Subject: [PATCH 23/32] Update --- dask_expr/_core.py | 4 ---- dask_expr/io/parquet.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 7983c756a..d0ff38f88 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -112,10 +112,6 @@ def _tree_repr_lines(self, indent=0, recursive=True): op = "" elif is_arraylike(op): op = "" - elif isinstance(op, BranchId): - if op.branch_id == 0: - continue - op = f" branch_id={op.branch_id}" header = self._tree_repr_argument_construction(i, op, header) if self._branch_id.branch_id != 0: header = self._tree_repr_argument_construction( diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 7724cc53a..a1905331d 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -454,7 +454,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): _absorb_projections = True def _tree_repr_argument_construction(self, i, op, header): - if i < len(self._parameters) and self._parameters[i] == "_dataset_info_cache": + if self._parameters[i] == "_dataset_info_cache": # Don't print this, very ugly return header return super()._tree_repr_argument_construction(i, op, header) From 4326a25cb5384a00b1ec36d8b30a4aae9ea9aa2d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:24:08 +0100 Subject: [PATCH 24/32] Update --- dask_expr/_groupby.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index 4027b49ff..d19787d85 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -741,7 +741,6 @@ def _lower(self): for param in self._parameters ], *self.by, - self._branch_id, ) if is_dataframe_like(s._meta): c = c[s.columns] From c9e0384901b5d04fefc2bb3a0758467f3b602f8d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:31:33 +0100 Subject: [PATCH 25/32] Update --- dask_expr/io/io.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 84d5df59a..e06682e6b 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -416,8 +416,7 @@ def _simplify_up(self, parent, dependents): return Literal(sum(_lengths)) if isinstance(parent, Projection): - x = super()._simplify_up(parent, dependents) - return x + return super()._simplify_up(parent, dependents) def _divisions(self): return self._divisions_and_locations[0] From d70ba0ff004982f3a25a96c41e10d675cd38d326 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 24 Feb 2024 15:17:11 +0100 Subject: [PATCH 26/32] Implement shuffle methods as consumers for branch_id --- dask_expr/_collection.py | 16 ++- dask_expr/_core.py | 13 +- dask_expr/_expr.py | 44 ++++--- dask_expr/_groupby.py | 14 ++- 
dask_expr/_merge.py | 42 +++++-- dask_expr/_merge_asof.py | 1 + dask_expr/_reductions.py | 37 ++++-- dask_expr/_shuffle.py | 27 ++-- dask_expr/tests/_util.py | 11 ++ dask_expr/tests/test_distributed.py | 188 +++++++++++++++++++++++++++- dask_expr/tests/test_reuse.py | 83 +++++++++--- 11 files changed, 412 insertions(+), 64 deletions(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index de4dab3b9..bda71f421 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -757,6 +757,7 @@ def shuffle( shuffle_method, options, index_shuffle=on_index, + _branch_id=expr.BranchId(0), ) ) @@ -2764,6 +2765,7 @@ def drop_duplicates( split_every=split_every, shuffle_method=shuffle_method, keep=keep, + _branch_id=expr.BranchId(0), ) ) @@ -3876,7 +3878,15 @@ def unique(self, split_every=None, split_out=True, shuffle_method=None): uniques : Series """ shuffle_method = _get_shuffle_preferring_order(shuffle_method) - return new_collection(Unique(self, split_every, split_out, shuffle_method)) + return new_collection( + Unique( + self, + split_every, + split_out, + shuffle_method, + _branch_id=expr.BranchId(0), + ) + ) @derived_from(pd.Series) def nunique(self, dropna=True, split_every=False, split_out=True): @@ -3908,6 +3918,7 @@ def drop_duplicates( split_every=split_every, shuffle_method=shuffle_method, keep=keep, + _branch_id=expr.BranchId(0), ) ) @@ -4780,6 +4791,7 @@ def merge( shuffle_method=shuffle_method, _npartitions=npartitions, broadcast=broadcast, + _branch_id=expr.BranchId(0), ) ) @@ -4866,7 +4878,7 @@ def merge_asof( from dask_expr._merge_asof import MergeAsof - return new_collection(MergeAsof(left, right, **kwargs)) + return new_collection(MergeAsof(left, right, **kwargs, _branch_id=expr.BranchId(0))) def from_map( diff --git a/dask_expr/_core.py b/dask_expr/_core.py index f757386d2..248c186cc 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -47,8 +47,10 @@ class Expr: _parameters = [] _defaults = {} _instances = weakref.WeakValueDictionary() + _branch_id_required = False def __new__(cls, *args, _branch_id=None, **kwargs): + cls._check_branch_id_given(args, _branch_id) operands = list(args) if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId): _branch_id = operands.pop(-1) @@ -71,6 +73,15 @@ def __new__(cls, *args, _branch_id=None, **kwargs): Expr._instances[_name] = inst return inst + @classmethod + def _check_branch_id_given(cls, args, _branch_id): + if not cls._branch_id_required: + return + operands = list(args) + if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId): + _branch_id = operands.pop(-1) + assert _branch_id is not None, "BranchId not found" + def _tune_down(self): return None @@ -601,7 +612,7 @@ def _substitute(self, old, new, _seen): new_exprs.append(operand) if update: # Only recreate if something changed - return type(self)(*new_exprs) + return type(self)(*new_exprs, _branch_id=self._branch_id) else: _seen.add(self._name) return self diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index d318017f8..f975dfd5e 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -422,7 +422,9 @@ def dtypes(self): def _filter_simplification(self, parent, predicate=None): if predicate is None: predicate = parent.predicate.substitute(self, self.frame) - return type(self)(self.frame[predicate], *self.operands[1:]) + return type(self)( + self.frame[predicate], *self.operands[1:], _branch_id=self._branch_id + ) class Literal(Expr): @@ -1822,7 +1824,7 @@ class Filter(Blockwise): def 
_simplify_up(self, parent, dependents): if isinstance(self.predicate, Or): result = rewrite_filters(self.predicate) - if result._name != self.predicate._name: + if result._dep_name != self.predicate._dep_name: return type(parent)( type(self)(self.frame, result), *parent.operands[1:] ) @@ -2088,7 +2090,7 @@ def _simplify_up(self, parent, dependents): ): parents = [ p().columns - for p in dependents[self._name] + for p in dependents[self._dep_name] if p() is not None and not isinstance(p(), Filter) ] predicate = None @@ -2119,7 +2121,7 @@ def _simplify_up(self, parent, dependents): return if all( isinstance(d(), Projection) and d().operand("columns") == col - for d in dependents[self._name] + for d in dependents[self._dep_name] ): return type(self)(self.frame, True, self.name) return @@ -3216,7 +3218,13 @@ def _lower(self): from dask_expr._shuffle import RearrangeByColumn args = [ - RearrangeByColumn(df, None, npartitions, index_shuffle=True) + RearrangeByColumn( + df, + None, + npartitions, + index_shuffle=True, + _branch_id=self._branch_id, + ) if isinstance(df, Expr) else df for df in self.operands @@ -3538,9 +3546,9 @@ def determine_column_projection(expr, parent, dependents, additional_columns=Non seen = set() for p in parents: - if p._name in seen: + if p._dep_name in seen: continue - seen.add(p._name) + seen.add(p._dep_name) column_union.extend(p._projection_columns) @@ -3597,8 +3605,8 @@ def plain_column_projection(expr, parent, dependents, additional_columns=None): def is_filter_pushdown_available(expr, parent, dependents, allow_reduction=True): - parents = [x() for x in dependents[expr._name] if x() is not None] - filters = {e._name for e in parents if isinstance(e, Filter)} + parents = [x() for x in dependents[expr._dep_name] if x() is not None] + filters = {e._dep_name for e in parents if isinstance(e, Filter)} if len(filters) != 1: # Don't push down if not exactly one Filter return False @@ -3606,7 +3614,7 @@ def is_filter_pushdown_available(expr, parent, dependents, allow_reduction=True) return True # We have to see if the non-filter ops are all exclusively part of the predicates - others = {e._name for e in parents if not isinstance(e, Filter)} + others = {e._dep_name for e in parents if not isinstance(e, Filter)} return _check_dependents_are_predicates( expr, others, parent, dependents, allow_reduction ) @@ -3641,7 +3649,7 @@ def _get_predicate_components(predicate, components, type_=Or): def _convert_mapping(components): - return dict(zip([e._name for e in components], components)) + return dict(zip([e._dep_name for e in components], components)) def _replace_common_or_components(expr, or_components): @@ -3692,19 +3700,21 @@ def _check_dependents_are_predicates( # Walk down the predicate side from the filter to see if we can arrive at # other_names without hitting an expression that has other dependents that # are not part of the predicate, see test_filter_pushdown_unavailable - allowed_expressions = {parent._name} + allowed_expressions = {parent._dep_name} stack = parent.dependencies() seen = set() while stack: e = stack.pop() - if expr._name == e._name: + if expr._dep_name == e._dep_name: continue - if e._name in seen: + if e._dep_name in seen: continue - seen.add(e._name) + seen.add(e._dep_name) - e_dependents = {x()._name for x in dependents[e._name] if x() is not None} + e_dependents = { + x()._dep_name for x in dependents[e._dep_name] if x() is not None + } if not allow_reduction: if isinstance(e, (ApplyConcatApply, TreeReduce, ShuffleReduce)): @@ -3715,7 +3725,7 @@ 
def _check_dependents_are_predicates( continue return False - allowed_expressions.add(e._name) + allowed_expressions.add(e._dep_name) stack.extend(e.dependencies()) return other_names.issubset(allowed_expressions) diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index 11c61818c..3b26b7b02 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -52,6 +52,7 @@ from dask.utils import M, apply, derived_from, is_index_like from dask_expr._collection import FrameBase, Index, Series, new_collection +from dask_expr._core import BranchId from dask_expr._expr import ( Assign, Blockwise, @@ -867,6 +868,7 @@ class GroupByApply(Expr, GroupByBase): "group_keys": True, "shuffle_method": None, } + _branch_id_required = True @functools.cached_property def grp_func(self): @@ -878,6 +880,9 @@ def _meta(self): return make_meta(self.operand("meta"), parent_meta=self.frame._meta) return _meta_apply_transform(self, self.grp_func) + def _reuse_down(self): + return + def _divisions(self): if self.need_to_shuffle: return (None,) * (self.frame.npartitions + 1) @@ -925,6 +930,7 @@ def get_map_columns(df): [map_columns.get(c, c) for c in cols], df.npartitions, method=self.shuffle_method, + _branch_id=self._branch_id, ) if unmap_columns: @@ -951,6 +957,7 @@ def get_map_columns(df): map_columns.get(self.by[0], self.by[0]), self.npartitions, method=self.shuffle_method, + _branch_id=self._branch_id, ) if unmap_columns: @@ -1252,7 +1259,9 @@ def groupby_projection(expr, parent, dependents): if columns == expr.frame.columns: return return type(parent)( - type(expr)(expr.frame[columns], *expr.operands[1:]), + type(expr)( + expr.frame[columns], *expr.operands[1:], _branch_id=expr._branch_id + ), *parent.operands[1:], ) return @@ -1945,6 +1954,7 @@ def apply(self, func, *args, meta=no_default, shuffle_method=None, **kwargs): kwargs, shuffle_method, *self.by, + BranchId(0), ) ) @@ -1964,6 +1974,7 @@ def _transform_like_op( kwargs, shuffle_method, *self.by, + BranchId(0), ) ) @@ -2060,6 +2071,7 @@ def median( shuffle_method, split_every, *self.by, + BranchId(0), ) ) if split_out is not True: diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index 3e47e1049..521a1175d 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -10,7 +10,7 @@ merge_chunk, ) from dask.dataframe.shuffle import partitioning_index -from dask.utils import apply, get_default_shuffle_method +from dask.utils import apply, funcname, get_default_shuffle_method from toolz import merge_sorted, unique from dask_expr._expr import ( # noqa: F401 @@ -82,6 +82,13 @@ class Merge(Expr): "_npartitions": None, "broadcast": None, } + _branch_id_required = True + + @functools.cached_property + def _dep_name(self): + return ( + funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands) + ) @property def _filter_passthrough(self): @@ -133,19 +140,19 @@ def _get_original_predicate_columns(self, predicate): seen = set() while stack: e = stack.pop() - if self._name == e._name: + if self._dep_name == e._dep_name: continue - if e._name in seen: + if e._dep_name in seen: continue - seen.add(e._name) + seen.add(e._dep_name) if isinstance(e, _DelayedExpr): continue dependencies = e.dependencies() stack.extend(dependencies) - if any(d._name == self._name for d in dependencies): + if any(d._dep_name == self._dep_name for d in dependencies): predicate_columns.update(e.columns) return predicate_columns @@ -309,6 +316,9 @@ def merge_indexed_right(self): self.right_index or _contains_index_name(self.right, self.right_on) ) and 
self.right.known_divisions + def _reuse_down(self): + return + def _lower(self): # Lower from an abstract expression left = self.left @@ -366,12 +376,14 @@ def _lower(self): left, shuffle_left_on, npartitions_out=left.npartitions, + _branch_id=self._branch_id, ) else: right = RearrangeByColumn( right, shuffle_right_on, npartitions_out=right.npartitions, + _branch_id=self._branch_id, ) return BroadcastJoin( @@ -404,6 +416,7 @@ def _lower(self): shuffle_left_on=shuffle_left_on, shuffle_right_on=shuffle_right_on, _npartitions=self.operand("_npartitions"), + _branch_id=self._branch_id, ) if shuffle_left_on: @@ -414,6 +427,7 @@ def _lower(self): npartitions_out=self._npartitions, method=shuffle_method, index_shuffle=left_index, + _branch_id=self._branch_id, ) if shuffle_right_on: @@ -424,6 +438,7 @@ def _lower(self): npartitions_out=self._npartitions, method=shuffle_method, index_shuffle=right_index, + _branch_id=self._branch_id, ) # Blockwise merge @@ -466,7 +481,9 @@ def _simplify_up(self, parent, dependents): if new_right is self.right and new_left is self.left: # don't drop the filter return - return type(self)(new_left, new_right, *self.operands[2:]) + return type(self)( + new_left, new_right, *self.operands[2:], _branch_id=self._branch_id + ) if isinstance(parent, (Projection, Index)): # Reorder the column projection to # occur before the Merge @@ -518,7 +535,10 @@ def _simplify_up(self, parent, dependents): right.columns ): result = type(self)( - left[project_left], right[project_right], *self.operands[2:] + left[project_left], + right[project_right], + *self.operands[2:], + _branch_id=self._branch_id, ) if parent_columns is None: return type(parent)(result) @@ -569,7 +589,7 @@ def _layer(self) -> dict: # Include self._name to ensure that shuffle IDs are unique for individual # merge operations. Reusing shuffles between merges is dangerous because of # required coordination and complexity introduced through dynamic clusters. - self._name, + self._dep_name, self.left._name, self.shuffle_left_on, self.left_index, @@ -578,7 +598,7 @@ def _layer(self) -> dict: # Include self._name to ensure that shuffle IDs are unique for individual # merge operations. Reusing shuffles between merges is dangerous because of # required coordination and complexity introduced through dynamic clusters. 
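The _dep_name introduced for Merge above (and used below for the shuffle IDs) is the branch-agnostic counterpart of _name: it leaves the branch id out of the token, so two copies of the same merge that differ only in their branch id still collapse onto one dependents/shuffle key, while their _names stay distinct. A small stand-alone sketch of that split, with a toy class rather than the real expression types:

    import hashlib


    def token(*parts) -> str:
        # Stand-in for _tokenize_deterministic.
        return hashlib.sha1(repr(parts).encode()).hexdigest()[:8]


    class ToyConsumer:
        """Toy model of a 'consumer' node such as Merge or Shuffle."""

        def __init__(self, *operands, branch_id=0):
            self.operands = operands
            self.branch_id = branch_id

        @property
        def _name(self):
            # Unique per branch: work *after* the consumer is split up.
            return f"merge-{token(*self.operands, self.branch_id)}"

        @property
        def _dep_name(self):
            # Branch-agnostic: work *feeding* the consumer is shared.
            return f"merge-{token(*self.operands)}"


    m0 = ToyConsumer("left", "right", branch_id=0)
    m1 = ToyConsumer("left", "right", branch_id=1)
    assert m0._name != m1._name
    assert m0._dep_name == m1._dep_name
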
- self._name, + self._dep_name, self.right._name, self.shuffle_right_on, self.right_index, @@ -672,6 +692,7 @@ class BroadcastJoin(Merge, PartitionsFiltered): "indicator": False, "_partitions": None, } + _branch_id_required = False def _divisions(self): if self.broadcast_side == "left": @@ -806,6 +827,7 @@ class BlockwiseMerge(Merge, Blockwise): """ is_broadcast_join = False + _branch_id_required = False def _divisions(self): if self.left.npartitions == self.right.npartitions: @@ -863,6 +885,7 @@ def _lower(self): how=self.how, left_index=True, right_index=True, + _branch_id=self._branch_id, ) return self._recursive_join(self.frames) @@ -878,6 +901,7 @@ def _recursive_join(self, frames): how="outer", left_index=True, right_index=True, + _branch_id=self._branch_id, ) midx = len(frames) // 2 diff --git a/dask_expr/_merge_asof.py b/dask_expr/_merge_asof.py index 7efcc13a3..2820ef267 100644 --- a/dask_expr/_merge_asof.py +++ b/dask_expr/_merge_asof.py @@ -40,6 +40,7 @@ class MergeAsof(Merge): "allow_exact_matches": True, "direction": "backward", } + _branch_id_required = False @functools.cached_property def _kwargs(self): diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index e94bae0cb..1f2c02fdc 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -131,6 +131,8 @@ class ShuffleReduce(Expr): ApplyConcatApply """ + _branch_id_required = True + _parameters = [ "frame", "kind", @@ -225,6 +227,7 @@ def _lower(self): ignore_index=ignore_index, index_shuffle=not split_by_index and self.shuffle_by_index, method=self.shuffle_method, + _branch_id=self._branch_id, ) # Unmap column names if necessary @@ -506,39 +509,53 @@ def _lower(self): shuffle_by_index=getattr(self, "shuffle_by_index", None), shuffle_method=getattr(self, "shuffle_method", None), ignore_index=getattr(self, "ignore_index", True), + _branch_id=self._branch_id, ) def _substitute_branch_id(self, branch_id): + if self.should_shuffle: + # We are lowering into a Shuffle, so we are a consumer ourselves and + # we have to consume the branch_id of our parents + return super()._substitute_branch_id(branch_id) return self def _reuse_down(self): if self._branch_id.branch_id != 0: return - from dask_expr._shuffle import Shuffle + if self.should_shuffle: + # We are lowering into a Shuffle, so we are a consumer ourselves + return + + from dask_expr._groupby import GroupByApply + from dask_expr._merge import Merge + from dask_expr._shuffle import ShuffleBase from dask_expr.io import IO seen = set() stack = self.dependencies() - counter, found_io = 1, False + counter, found_consumer = 1, False while stack: node = stack.pop() - if node._name in seen: + if node._dep_name in seen: continue - seen.add(node._name) + seen.add(node._dep_name) if isinstance(node, ApplyConcatApply): - counter += 1 + if node.should_shuffle: + found_consumer = True + else: + counter += 1 continue - if isinstance(node, (IO, Shuffle)): - found_io = True + if isinstance(node, (IO, ShuffleBase, GroupByApply, Merge)): + found_consumer = True continue stack.extend(node.dependencies()) - if not found_io: + if not found_consumer: return b_id = BranchId(counter) result = type(self)(*self.operands, b_id) @@ -631,7 +648,9 @@ def _simplify_up(self, parent, dependents): columns = [col for col in self.frame.columns if col in columns] return type(parent)( - type(self)(self.frame[columns], *self.operands[1:]), + type(self)( + self.frame[columns], *self.operands[1:], _branch_id=self._branch_id + ), *parent.operands[1:], ) diff --git a/dask_expr/_shuffle.py 
b/dask_expr/_shuffle.py index c53e2abfd..3a21b60dd 100644 --- a/dask_expr/_shuffle.py +++ b/dask_expr/_shuffle.py @@ -85,6 +85,7 @@ class ShuffleBase(Expr): } _is_length_preserving = True _filter_passthrough = True + _branch_id_required = True def __str__(self): return f"Shuffle({self._name[-7:]})" @@ -113,9 +114,11 @@ def _simplify_up(self, parent, dependents): if (col in partitioning_index or col in projection) ] if set(new_projection) < set(target.columns): - return type(self)(target[new_projection], *self.operands[1:])[ - parent.operand("columns") - ] + return type(self)( + target[new_projection], + *self.operands[1:], + _branch_id=self._branch_id, + )[parent.operand("columns")] if isinstance( parent, @@ -140,7 +143,8 @@ def _simplify_up(self, parent, dependents): MemoryUsage, ), ): - return type(parent)(self.frame, *parent.operands[1:]) + branch_id = None if not parent.should_shuffle else parent._branch_id + return type(parent)(self.frame, *parent.operands[1:], _branch_id=branch_id) def _layer(self): raise NotImplementedError( @@ -159,14 +163,13 @@ def _divisions(self): return (None,) * (self.npartitions_out + 1) def _reuse_down(self): + # TODO: What to do with task based shuffle? return @functools.cached_property def _dep_name(self): return ( - funcname(type(self)).lower() - + "-" - + _tokenize_deterministic(*self.argument_operands) + funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands) ) @@ -306,7 +309,7 @@ def _lower(self): ignore_index, self.method, options, - self._branch_id, + _branch_id=self._branch_id, ) if frame.ndim == 1: # Reduce back to series @@ -517,6 +520,11 @@ def _layer(self): class DiskShuffle(SimpleShuffle): """Disk-based shuffle implementation""" + @functools.cached_property + def _name(self): + # This is only used locally anyway, so don't bother with pipeline breakers + return self._dep_name + @staticmethod def _shuffle_group(df, col, _filter, p): with ensure_cleanup_on_exception(p): @@ -569,6 +577,7 @@ def _layer(self): ) dsk = {} + # Ensure that shuffles with different branch_ids have the same barrier token = self._dep_name.split("-")[-1] _barrier_key = barrier_key(ShuffleId(token)) name = "shuffle-transfer-" + token @@ -1040,6 +1049,7 @@ def _lower(self): ignore_index=self.ignore_index, method=self.shuffle_method, options=self.options, + _branch_id=self._branch_id, ) shuffled = Projection(shuffled, self.frame.columns) return SortValuesBlockwise( @@ -1127,6 +1137,7 @@ def _lower(self): ignore_index=True, method=self.shuffle_method, options=self.options, + _branch_id=self._branch_id, ) shuffled = Projection( shuffled, [c for c in assigned.columns if c != "_partitions"] diff --git a/dask_expr/tests/_util.py b/dask_expr/tests/_util.py index 1f24bfda1..b04a93af0 100644 --- a/dask_expr/tests/_util.py +++ b/dask_expr/tests/_util.py @@ -5,6 +5,8 @@ from dask import config from dask.dataframe.utils import assert_eq as dd_assert_eq +from dask_expr.io import IO + def _backend_name() -> str: return config.get("dataframe.backend", "pandas") @@ -39,3 +41,12 @@ def assert_eq(a, b, *args, serialize_graph=True, **kwargs): # Use `dask.dataframe.assert_eq` return dd_assert_eq(a, b, *args, **kwargs) + + +def _check_consumer_node(expr, expected, consumer_node=IO, branch_id_counter=None): + if branch_id_counter is None: + branch_id_counter = expected + expr = expr.optimize(fuse=False) + io_nodes = list(expr.find_operations(consumer_node)) + assert len(io_nodes) == expected + assert len({node._branch_id.branch_id for node in io_nodes}) == branch_id_counter 
diff --git a/dask_expr/tests/test_distributed.py b/dask_expr/tests/test_distributed.py index 2cf9aa975..0ab6b00df 100644 --- a/dask_expr/tests/test_distributed.py +++ b/dask_expr/tests/test_distributed.py @@ -3,8 +3,9 @@ import pytest from dask_expr import from_pandas, map_partitions, merge -from dask_expr._merge import BroadcastJoin -from dask_expr.tests._util import _backend_library +from dask_expr._merge import BroadcastJoin, HashJoinP2P +from dask_expr._shuffle import P2PShuffle +from dask_expr.tests._util import _backend_library, _check_consumer_node distributed = pytest.importorskip("distributed") @@ -354,3 +355,186 @@ def test_func(n): result = await result expected = pd.DataFrame({"a": [4951, 4952, 4953, 4954]}) pd.testing.assert_frame_equal(result, expected) + + +@gen_cluster(client=True) +async def test_p2p_shuffle_reuse(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + q = df.shuffle("a") + q = q.fillna(100) + q = q.a + q.a.sum() + # Only one IO node since shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 2, consumer_node=P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.fillna(100) + expected = expected.a + expected.a.sum() + pd.testing.assert_series_equal(x.sort_index(), expected) + + # Check that we have 1 shuffle barrier but 20 p2pshuffle tasks for the output + dsk = q.optimize(fuse=False).dask + keys = list(dsk.keys()) + assert ( + len( + list( + key for key in keys if isinstance(key, str) and "shuffle-barrier" in key + ) + ) + == 1 + ) + assert ( + len( + list( + key for key in keys if isinstance(key, tuple) and "p2pshuffle" in key[0] + ) + ) + == 20 + ) + + +@gen_cluster(client=True) +async def test_groupby_apply_reuse(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + q = df.groupby("a").apply(lambda x: x) + q = q.fillna(100) + q = q.a + q.a.sum() + # Only one IO node since shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 2, consumer_node=P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.groupby("a").apply(lambda x: x) + expected = expected.fillna(100) + expected = expected.a + expected.a.sum() + pd.testing.assert_series_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_groupby_sum_reuse_split_out(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + q = df.groupby("a").sum(split_out=True) + q = df + q.b.sum() + # Only one IO node since groupby-shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 1, consumer_node=P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.groupby("a").sum() + expected = pdf + expected.b.sum() + pd.testing.assert_frame_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_groupby_sum_no_reuse(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + # no split_out, so we can't reuse the groupby operation + q = df.groupby("a").sum() + q = df + q.b.sum() + # 2 IO Nodes, one for the groupby branch and one for the main branch + _check_consumer_node(q, 2) + x = c.compute(q) + x = await x + + expected = pdf.groupby("a").sum() + expected = pdf + expected.b.sum() + pd.testing.assert_frame_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_drop_duplicates_reuse(c, s, a, b): + pdf = 
pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + # no split_out, so we can't reuse the groupby operation + q = df.drop_duplicates(subset="a") + q = df + q.b.sum() + # Only one IO node since drop duplicates-shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 1, P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.drop_duplicates(subset="a") + expected = pdf + expected.b.sum() + pd.testing.assert_frame_equal(x.sort_index(), expected) + + q = df.drop_duplicates(subset="a", split_out=1) + q = df + q.b.sum() + # 2 IO nodes since reducer can't consume + _check_consumer_node(q, 2) + x = c.compute(q) + x = await x + + expected = pdf.drop_duplicates(subset="a") + expected = pdf + expected.b.sum() + pd.testing.assert_frame_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_groupby_ffill_reuse(c, s, a, b): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6] * 10, "b": 2, "e": 2}) + df = from_pandas(pdf, npartitions=10) + q = df.groupby("a").ffill() + q = q.fillna(100) + q = q.b + q.b.sum() + # Only one IO node since shuffle consumes + _check_consumer_node(q, 1) + _check_consumer_node(q, 2, consumer_node=P2PShuffle) + x = c.compute(q) + x = await x + + expected = pdf.groupby("a").ffill() + expected = expected.fillna(100) + expected = expected.b + expected.b.sum() + pd.testing.assert_series_equal(x.sort_index(), expected) + + +@gen_cluster(client=True) +async def test_merge_reuse(c, s, a, b): + pdf1 = pd.DataFrame({"a": [1, 2, 3, 4, 1, 2, 3, 4], "b": 1, "c": 1}) + pdf2 = pd.DataFrame({"a": [1, 2, 3, 4, 1, 2, 3, 4], "e": 1, "f": 1}) + + df1 = from_pandas(pdf1, npartitions=3) + df2 = from_pandas(pdf2, npartitions=3) + q = df1.merge(df2) + q = q.fillna(100) + q = q.b + q.b.sum() + _check_consumer_node(q, 2, HashJoinP2P) + # One on either side + _check_consumer_node(q, 2, branch_id_counter=1) + x = c.compute(q) + x = await x + expected = pdf1.merge(pdf2) + expected = expected.fillna(100) + expected = expected.b + expected.b.sum() + pd.testing.assert_series_equal(x.reset_index(drop=True), expected) + + # Check that we have 2 shuffle barriers (one for either side) for both merges but 6 + # hashjoinp2p tasks for the output + dsk = q.optimize(fuse=False).dask + keys = list(dsk.keys()) + assert ( + len( + list( + key for key in keys if isinstance(key, str) and "shuffle-barrier" in key + ) + ) + == 2 + ) + assert ( + len( + list( + key + for key in keys + if isinstance(key, tuple) and "hashjoinp2p" in key[0] + ) + ) + == 6 + ) diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py index 6d6f35689..c92ccea15 100644 --- a/dask_expr/tests/test_reuse.py +++ b/dask_expr/tests/test_reuse.py @@ -3,8 +3,9 @@ import pytest from dask_expr import from_pandas -from dask_expr.io import IO -from dask_expr.tests._util import _backend_library, assert_eq +from dask_expr._merge import BlockwiseMerge +from dask_expr._shuffle import DiskShuffle +from dask_expr.tests._util import _backend_library, _check_consumer_node, assert_eq # Set DataFrame backend for this module pd = _backend_library() @@ -22,13 +23,6 @@ def df(pdf): yield from_pandas(pdf, npartitions=10) -def _check_io_nodes(expr, expected): - expr = expr.optimize(fuse=False) - io_nodes = list(expr.find_operations(IO)) - assert len(io_nodes) == expected - assert len({node._branch_id.branch_id for node in io_nodes}) == expected - - def test_reuse_everything_scalar_and_series(df, pdf): df["new"] = 1 df["new2"] = df["x"] + 1 @@ -38,7 +32,7 @@ def 
test_reuse_everything_scalar_and_series(df, pdf): pdf["new2"] = pdf["x"] + 1 pdf["new3"] = pdf.x[pdf.x > 1] + pdf.x[pdf.x > 2] assert_eq(df, pdf) - _check_io_nodes(df, 1) + _check_consumer_node(df, 1) def test_dont_reuse_reducer(df, pdf): @@ -47,12 +41,12 @@ def test_dont_reuse_reducer(df, pdf): expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.y.sum() assert_eq(result, expected) - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) result = df + df.sum() expected = pdf + pdf.sum() assert_eq(result, expected, check_names=False) # pandas 2.2 bug - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) result = df.replace(1, 5) rhs_1 = result.x + result.y.sum() @@ -63,7 +57,7 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) result = df.replace(1, 5) result["new"] = result.x + result.y.sum() @@ -72,11 +66,70 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result, 3) + _check_consumer_node(result, 3) result = df.replace(1, 5) result["new"] = result.x + result.sum().dropna().prod() expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.sum().dropna().prod() assert_eq(result, expected) - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) + + +def test_disk_shuffle(df, pdf): + q = df.shuffle("a") + q = q.fillna(100) + q = q.a + q.a.sum() + q.optimize(fuse=False).pprint() + # Disk shuffle is not utilizing pipeline breakers + _check_consumer_node(q, 1, consumer_node=DiskShuffle) + _check_consumer_node(q, 1) + expected = pdf.fillna(100) + expected = expected.a + expected.a.sum() + assert_eq(q, expected) + + +def test_groupb_apply_disk_shuffle_reuse(df, pdf): + q = df.groupby("a").apply(lambda x: x) + q = q.fillna(100) + q = q.a + q.a.sum() + # Disk shuffle is not utilizing pipeline breakers + _check_consumer_node(q, 1, consumer_node=DiskShuffle) + _check_consumer_node(q, 1) + expected = pdf.groupby("a").apply(lambda x: x) + expected = expected.fillna(100) + expected = expected.a + expected.a.sum() + assert_eq(q, expected) + + +def test_groupb_ffill_disk_shuffle_reuse(df, pdf): + q = df.groupby("a").ffill() + q = q.fillna(100) + q = q.b + q.b.sum() + # Disk shuffle is not utilizing pipeline breakers + _check_consumer_node(q, 1, consumer_node=DiskShuffle) + _check_consumer_node(q, 1) + expected = pdf.groupby("a").ffill() + expected = expected.fillna(100) + expected = expected.b + expected.b.sum() + assert_eq(q, expected) + + +def test_merge_reuse(): + pdf1 = pd.DataFrame({"a": [1, 2, 3, 4, 1, 2, 3, 4], "b": 1, "c": 1}) + pdf2 = pd.DataFrame({"a": [1, 2, 3, 4, 1, 2, 3, 4], "e": 1, "f": 1}) + + df1 = from_pandas(pdf1, npartitions=3) + df2 = from_pandas(pdf2, npartitions=3) + q = df1.merge(df2) + q = q.fillna(100) + q = q.b + q.b.sum() + _check_consumer_node(q, 1, BlockwiseMerge) + # One on either side + _check_consumer_node(q, 2, DiskShuffle, branch_id_counter=1) + _check_consumer_node(q, 2, branch_id_counter=1) + + expected = pdf1.merge(pdf2) + expected = expected.fillna(100) + expected = expected.b + expected.b.sum() + assert_eq(q, expected, check_index=False) From cc120eebeb63056b902e3ade9bfc6493a0b0128e Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 24 Feb 2024 15:21:23 +0100 
Subject: [PATCH 27/32] Update --- dask_expr/tests/_util.py | 11 +++++++++++ dask_expr/tests/test_reuse.py | 22 +++++++--------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/dask_expr/tests/_util.py b/dask_expr/tests/_util.py index 1f24bfda1..b04a93af0 100644 --- a/dask_expr/tests/_util.py +++ b/dask_expr/tests/_util.py @@ -5,6 +5,8 @@ from dask import config from dask.dataframe.utils import assert_eq as dd_assert_eq +from dask_expr.io import IO + def _backend_name() -> str: return config.get("dataframe.backend", "pandas") @@ -39,3 +41,12 @@ def assert_eq(a, b, *args, serialize_graph=True, **kwargs): # Use `dask.dataframe.assert_eq` return dd_assert_eq(a, b, *args, **kwargs) + + +def _check_consumer_node(expr, expected, consumer_node=IO, branch_id_counter=None): + if branch_id_counter is None: + branch_id_counter = expected + expr = expr.optimize(fuse=False) + io_nodes = list(expr.find_operations(consumer_node)) + assert len(io_nodes) == expected + assert len({node._branch_id.branch_id for node in io_nodes}) == branch_id_counter diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py index 6d6f35689..c9d481a36 100644 --- a/dask_expr/tests/test_reuse.py +++ b/dask_expr/tests/test_reuse.py @@ -3,8 +3,7 @@ import pytest from dask_expr import from_pandas -from dask_expr.io import IO -from dask_expr.tests._util import _backend_library, assert_eq +from dask_expr.tests._util import _backend_library, _check_consumer_node, assert_eq # Set DataFrame backend for this module pd = _backend_library() @@ -22,13 +21,6 @@ def df(pdf): yield from_pandas(pdf, npartitions=10) -def _check_io_nodes(expr, expected): - expr = expr.optimize(fuse=False) - io_nodes = list(expr.find_operations(IO)) - assert len(io_nodes) == expected - assert len({node._branch_id.branch_id for node in io_nodes}) == expected - - def test_reuse_everything_scalar_and_series(df, pdf): df["new"] = 1 df["new2"] = df["x"] + 1 @@ -38,7 +30,7 @@ def test_reuse_everything_scalar_and_series(df, pdf): pdf["new2"] = pdf["x"] + 1 pdf["new3"] = pdf.x[pdf.x > 1] + pdf.x[pdf.x > 2] assert_eq(df, pdf) - _check_io_nodes(df, 1) + _check_consumer_node(df, 1) def test_dont_reuse_reducer(df, pdf): @@ -47,12 +39,12 @@ def test_dont_reuse_reducer(df, pdf): expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.y.sum() assert_eq(result, expected) - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) result = df + df.sum() expected = pdf + pdf.sum() assert_eq(result, expected, check_names=False) # pandas 2.2 bug - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) result = df.replace(1, 5) rhs_1 = result.x + result.y.sum() @@ -63,7 +55,7 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) result = df.replace(1, 5) result["new"] = result.x + result.y.sum() @@ -72,11 +64,11 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result, 3) + _check_consumer_node(result, 3) result = df.replace(1, 5) result["new"] = result.x + result.sum().dropna().prod() expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.sum().dropna().prod() assert_eq(result, expected) - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) From 2bbeb2e7e87ac028963cc231590b3c860c02abd3 Mon Sep 17 
From 2bbeb2e7e87ac028963cc231590b3c860c02abd3 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Sat, 24 Feb 2024 15:25:14 +0100
Subject: [PATCH 28/32] Remove unnecessary changes

---
 dask_expr/_collection.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py
index bda71f421..2d4152e3e 100644
--- a/dask_expr/_collection.py
+++ b/dask_expr/_collection.py
@@ -2765,7 +2765,6 @@ def drop_duplicates(
                 split_every=split_every,
                 shuffle_method=shuffle_method,
                 keep=keep,
-                _branch_id=expr.BranchId(0),
             )
         )
 
@@ -3878,15 +3877,7 @@ def unique(self, split_every=None, split_out=True, shuffle_method=None):
         uniques : Series
         """
         shuffle_method = _get_shuffle_preferring_order(shuffle_method)
-        return new_collection(
-            Unique(
-                self,
-                split_every,
-                split_out,
-                shuffle_method,
-                _branch_id=expr.BranchId(0),
-            )
-        )
+        return new_collection(Unique(self, split_every, split_out, shuffle_method))
 
     @derived_from(pd.Series)
     def nunique(self, dropna=True, split_every=False, split_out=True):
@@ -3918,7 +3909,6 @@ def drop_duplicates(
                 split_every=split_every,
                 shuffle_method=shuffle_method,
                 keep=keep,
-                _branch_id=expr.BranchId(0),
             )
         )

From 451dca019f1daeae2b0eb0c5fe44ff9403f70f0e Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Sat, 24 Feb 2024 15:48:23 +0100
Subject: [PATCH 29/32] Simplify variable

---
 dask_expr/_core.py       | 14 +++++++++++++-
 dask_expr/_groupby.py    |  1 +
 dask_expr/_merge.py      |  9 ++-------
 dask_expr/_reductions.py |  5 +++++
 dask_expr/_shuffle.py    | 10 ++--------
 5 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index 248c186cc..7080caed1 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -48,6 +48,7 @@ class Expr:
     _defaults = {}
     _instances = weakref.WeakValueDictionary()
     _branch_id_required = False
+    _reuse_consumer = False
 
     def __new__(cls, *args, _branch_id=None, **kwargs):
         cls._check_branch_id_given(args, _branch_id)
@@ -486,7 +487,18 @@ def _name(self):
 
     @functools.cached_property
     def _dep_name(self):
-        return self._name
+        # The name identifies every expression uniquely. The dependents name
+        # is used during optimization to capture the dependents of any given
+        # expression. A reuse consumer has the same dependents independently
+        # of the branch_id parameter, since we want to reuse everything that
+        # comes before us and split everything that is processed after us
+        # into separate branches. So we have to exclude the branch_id from
+        # tokenization for those nodes.
+        if not self._reuse_consumer:
+            return self._name
+        return (
+            funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands)
+        )
 
     @property
     def _meta(self):
diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py
index 3b26b7b02..5661b6035 100644
--- a/dask_expr/_groupby.py
+++ b/dask_expr/_groupby.py
@@ -869,6 +869,7 @@ class GroupByApply(Expr, GroupByBase):
         "shuffle_method": None,
     }
     _branch_id_required = True
+    _reuse_consumer = True
 
     @functools.cached_property
     def grp_func(self):
diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 521a1175d..39ec8d405 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -10,7 +10,7 @@
     merge_chunk,
 )
 from dask.dataframe.shuffle import partitioning_index
-from dask.utils import apply, funcname, get_default_shuffle_method
+from dask.utils import apply, get_default_shuffle_method
 from toolz import merge_sorted, unique
 
 from dask_expr._expr import (  # noqa: F401
@@ -83,12 +83,7 @@ class Merge(Expr):
         "broadcast": None,
     }
     _branch_id_required = True
-
-    @functools.cached_property
-    def _dep_name(self):
-        return (
-            funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands)
-        )
+    _reuse_consumer = True
 
     @property
     def _filter_passthrough(self):
diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index 1f2c02fdc..db02d00e6 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -132,6 +132,7 @@ class ShuffleReduce(Expr):
     """
 
     _branch_id_required = True
+    _reuse_consumer = True
 
     _parameters = [
         "frame",
@@ -415,6 +416,10 @@ def split_out(self):
         else:
             return 1
 
+    @functools.cached_property
+    def _reuse_consumer(self):
+        return self.should_shuffle
+
     def _layer(self):
         # This is an abstract expression
         raise NotImplementedError()
diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py
index 3a21b60dd..07cfd7e17 100644
--- a/dask_expr/_shuffle.py
+++ b/dask_expr/_shuffle.py
@@ -22,7 +22,6 @@
 from dask.utils import (
     M,
     digit,
-    funcname,
     get_default_shuffle_method,
     insert,
     is_index_like,
@@ -64,7 +63,7 @@
     ValueCounts,
 )
 from dask_expr._repartition import Repartition, RepartitionToFewer
-from dask_expr._util import LRU, _convert_to_list, _tokenize_deterministic
+from dask_expr._util import LRU, _convert_to_list
 
 
 class ShuffleBase(Expr):
@@ -86,6 +85,7 @@ class ShuffleBase(Expr):
     _is_length_preserving = True
     _filter_passthrough = True
     _branch_id_required = True
+    _reuse_consumer = True
 
     def __str__(self):
         return f"Shuffle({self._name[-7:]})"
@@ -166,12 +166,6 @@ def _reuse_down(self):
         # TODO: What to do with task based shuffle?
         return
 
-    @functools.cached_property
-    def _dep_name(self):
-        return (
-            funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands)
-        )
-
 
 class Shuffle(ShuffleBase):
     """Abstract shuffle class
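The `_dep_name` change above is the heart of the bookkeeping: `_name` stays unique per branch, while reuse consumers drop the trailing branch_id operand from the dependents-name so that every branch registers its dependents under a single key. A self-contained toy sketch of the idea (these classes and `dask.base.tokenize` are stand-ins, not the dask_expr types or `_tokenize_deterministic`):

    from dask.base import tokenize


    class ToyNode:
        _reuse_consumer = False

        def __init__(self, *operands, branch_id=0):
            # Mirror the convention of keeping the branch id as the last operand.
            self.operands = list(operands) + [branch_id]

        @property
        def _name(self):
            # Unique per branch: the two branches stay distinct expressions.
            return type(self).__name__.lower() + "-" + tokenize(*self.operands)

        @property
        def _dep_name(self):
            if not self._reuse_consumer:
                return self._name
            # Consumers ignore the trailing branch id, so dependents of all
            # branches are collected under one key.
            return type(self).__name__.lower() + "-" + tokenize(*self.operands[:-1])


    class ToyShuffle(ToyNode):
        _reuse_consumer = True


    a = ToyShuffle("frame", branch_id=0)
    b = ToyShuffle("frame", branch_id=1)
    assert a._name != b._name
    assert a._dep_name == b._dep_name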
From 150f99ccc53af4f86438fe59d690c3c4a6d84260 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Sat, 24 Feb 2024 15:52:07 +0100
Subject: [PATCH 30/32] Make reuse step easier

---
 dask_expr/_reductions.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index db02d00e6..6d67e3ff3 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -532,9 +532,6 @@ def _reuse_down(self):
             # We are lowering into a Shuffle, so we are a consumer ourselves
             return
 
-        from dask_expr._groupby import GroupByApply
-        from dask_expr._merge import Merge
-        from dask_expr._shuffle import ShuffleBase
         from dask_expr.io import IO
 
         seen = set()
@@ -548,16 +545,14 @@ def _reuse_down(self):
                 continue
             seen.add(node._dep_name)
 
-            if isinstance(node, ApplyConcatApply):
-                if node.should_shuffle:
-                    found_consumer = True
-                else:
-                    counter += 1
+            if isinstance(node, IO) or node._reuse_consumer:
+                found_consumer = True
                 continue
 
-            if isinstance(node, (IO, ShuffleBase, GroupByApply, Merge)):
-                found_consumer = True
+            if isinstance(node, ApplyConcatApply):
+                counter += 1
                 continue
+
             stack.extend(node.dependencies())
 
         if not found_consumer:

From 12432da3719fe7b3c2e264616966bc49072ecd5b Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Sat, 24 Feb 2024 15:54:39 +0100
Subject: [PATCH 31/32] Make reuse step easier

---
 dask_expr/_reductions.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index ebd05e07b..286deafa3 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -518,9 +518,11 @@ def _reuse_down(self):
         if self._branch_id.branch_id != 0:
             return
 
+        from dask_expr.io import IO
+
         seen = set()
         stack = self.dependencies()
-        counter, found_io = 1, False
+        counter, found_consumer = 1, False
         while stack:
             node = stack.pop()
 
@@ -529,14 +531,17 @@ def _reuse_down(self):
                 continue
             seen.add(node._name)
 
+            if isinstance(node, IO):
+                found_consumer = True
+                continue
+
             if isinstance(node, ApplyConcatApply):
                 counter += 1
                 continue
-            deps = node.dependencies()
-            if not deps:
-                found_io = True
-            stack.extend(deps)
-        if not found_io:
+
+            stack.extend(node.dependencies())
+
+        if not found_consumer:
             return
         b_id = BranchId(counter)
         result = type(self)(*self.operands, b_id)
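The two patches above reduce the reuse walk to one rule: starting from the reducer, descend through its dependencies, stop at the first consumer, and count the ApplyConcatApply nodes passed on the way; that count then seeds the new BranchId. A stripped-down sketch of the traversal over generic nodes (the predicates and the `dependencies()` protocol are assumptions for illustration, not the dask_expr classes):

    def count_branches(root, is_consumer, is_aca):
        """Walk dependencies depth-first; stop at consumers, count ACA nodes."""
        seen, stack = set(), list(root.dependencies())
        counter, found_consumer = 1, False
        while stack:
            node = stack.pop()
            if id(node) in seen:
                continue
            seen.add(id(node))
            if is_consumer(node):
                # Everything below this point is shared, so stop descending.
                found_consumer = True
                continue
            if is_aca(node):
                # Another reducer between us and the consumer: one more branch.
                counter += 1
                continue
            stack.extend(node.dependencies())
        return counter if found_consumer else None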
From 8bc45b92397d07dc0f88acc7ac8e463af0d23c5e Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Sat, 24 Feb 2024 15:55:52 +0100
Subject: [PATCH 32/32] Update

---
 dask_expr/_reductions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index 19f481b8c..b41e624ab 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -518,7 +518,7 @@ def _lower(self):
         )
 
     def _substitute_branch_id(self, branch_id):
-        if self.should_shuffle:
+        if self._reuse_consumer:
             # We are lowering into a Shuffle, so we are a consumer ourselves and
             # we have to consume the branch_id of our parents
             return super()._substitute_branch_id(branch_id)
@@ -528,7 +528,7 @@ def _reuse_down(self):
         if self._branch_id.branch_id != 0:
             return
 
-        if self.should_shuffle:
+        if self._reuse_consumer:
             # We are lowering into a Shuffle, so we are a consumer ourselves
             return
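Taken together, the series makes branch splitting opt-in: only expressions flagged as reuse consumers (IO, shuffles, merges, shuffled groupby-applies, and shuffle-based reductions) terminate the walk, and everything upstream of them keeps its branch_id and remains shared. A rough way to inspect the effect on a collection, reusing the optimize/find_operations pattern from the test helper above and assuming the pandas backend (the concrete counts depend on the query and are indicative only):

    import pandas as pd

    from dask_expr import from_pandas
    from dask_expr.io import IO

    df = from_pandas(pd.DataFrame({"x": range(100), "y": range(100)}), npartitions=10)
    result = df.replace(1, 5)
    result["new"] = result.x + result.y.sum()

    optimized = result.optimize(fuse=False)
    io_nodes = list(optimized.find_operations(IO))
    # Expect two IO nodes here, each carrying a distinct branch_id.
    print(len(io_nodes), {node._branch_id.branch_id for node in io_nodes})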