From fae5c6eada640c1cb109af33a2e6cb10d5a5d568 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Thu, 15 Feb 2024 23:24:17 +0100
Subject: [PATCH 01/26] Implement branch_id to limit reuse

---
 dask_expr/_concat.py               |  2 +-
 dask_expr/_core.py                 | 34 ++++++++++++++++++---
 dask_expr/_cumulative.py           |  2 +-
 dask_expr/_expr.py                 | 47 ++++++++++++++++++------------
 dask_expr/_groupby.py              |  5 ++--
 dask_expr/_reductions.py           | 32 +++++++++++++++++++-
 dask_expr/_str_accessor.py         |  2 +-
 dask_expr/io/_delayed.py           |  2 +-
 dask_expr/tests/test_collection.py | 11 +++++--
 dask_expr/tests/test_groupby.py    |  2 +-
 10 files changed, 105 insertions(+), 34 deletions(-)

diff --git a/dask_expr/_concat.py b/dask_expr/_concat.py
index f99503121..fcdb4997d 100644
--- a/dask_expr/_concat.py
+++ b/dask_expr/_concat.py
@@ -45,7 +45,7 @@ def __str__(self):
             + ", "
             + ", ".join(
                 str(param) + "=" + str(operand)
-                for param, operand in zip(self._parameters, self.operands)
+                for param, operand in zip(self._parameters[:-1], self.operands)
                 if operand != self._defaults.get(param)
             )
         )
diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index 3aac656c4..8b4a07406 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -5,6 +5,7 @@
 import weakref
 from collections import defaultdict
 from collections.abc import Generator
+from typing import NamedTuple
 
 import dask
 import pandas as pd
@@ -15,6 +16,10 @@
 from dask_expr._util import _BackendData, _tokenize_deterministic
 
 
+class BranchId(NamedTuple):
+    branch_id: int | None
+
+
 def _unpack_collections(o):
     if isinstance(o, Expr):
         return o
@@ -30,8 +35,13 @@ class Expr:
     _defaults = {}
     _instances = weakref.WeakValueDictionary()
 
-    def __new__(cls, *args, **kwargs):
+    def __new__(cls, *args, _branch_id=None, **kwargs):
         operands = list(args)
+        if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId):
+            _branch_id = operands.pop(-1)
+        elif _branch_id is None:
+            _branch_id = BranchId(None)
+
         for parameter in cls._parameters[len(operands) :]:
             try:
                 operands.append(kwargs.pop(parameter))
@@ -39,7 +49,8 @@ def __new__(cls, *args, **kwargs):
                 operands.append(cls._defaults[parameter])
         assert not kwargs, kwargs
         inst = object.__new__(cls)
-        inst.operands = [_unpack_collections(o) for o in operands]
+        inst.operands = [_unpack_collections(o) for o in operands] + [_branch_id]
+        inst._parameters = cls._parameters + ["_branch_id"]
         _name = inst._name
         if _name in Expr._instances:
             return Expr._instances[_name]
@@ -47,6 +58,10 @@ def __new__(cls, *args, **kwargs):
         Expr._instances[_name] = inst
         return inst
 
+    @functools.cached_property
+    def argument_operands(self):
+        return self.operands[:-1]
+
     def _tune_down(self):
         return None
 
@@ -144,7 +159,7 @@ def _depth(self):
     def operand(self, key):
         # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
-        return self.operands[type(self)._parameters.index(key)]
+        return self.operands[self._parameters.index(key)]
 
     def dependencies(self):
         # Dependencies are `Expr` operands only
@@ -267,6 +282,11 @@ def rewrite(self, kind: str):
 
         return expr
 
+    def _push_branch_id(self, parent):
+        if self._branch_id.branch_id != parent._branch_id.branch_id:
+            result = type(self)(*self.operands[:-1], parent._branch_id)
+            return parent.substitute(self, result)
+
     def simplify_once(self, dependents: defaultdict, simplified: dict):
         """Simplify an expression
 
@@ -303,7 +323,9 @@ def simplify_once(self, dependents: defaultdict, simplified: dict):
 
         # Allow children to simplify their parents
         for child in expr.dependencies():
-            out = child._simplify_up(expr, dependents)
+            out = child._push_branch_id(expr)
+            if out is None:
+                out = child._simplify_up(expr, dependents)
             if out is None:
                 out = expr
 
@@ -419,6 +441,10 @@ def _name(self):
     def _meta(self):
         raise NotImplementedError()
 
+    @functools.cached_property
+    def _branch_id(self):
+        return self.operands[-1]
+
     def __getattr__(self, key):
         try:
             return object.__getattribute__(self, key)
diff --git a/dask_expr/_cumulative.py b/dask_expr/_cumulative.py
index 8d8b638ba..1aedd8bc9 100644
--- a/dask_expr/_cumulative.py
+++ b/dask_expr/_cumulative.py
@@ -47,7 +47,7 @@ def operation(self):
 
     @functools.cached_property
     def _args(self) -> list:
-        return self.operands[:-1]
+        return self.argument_operands[:-1]
 
 
 class TakeLast(Blockwise):
diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 6d800eab6..670735ba5 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -466,7 +466,7 @@ def _kwargs(self) -> dict:
         if self._keyword_only:
             return {
                 p: self.operand(p)
-                for p in self._parameters
+                for p in self._parameters[:-1]
                 if p in self._keyword_only and self.operand(p) is not no_default
             }
         return {}
@@ -475,10 +475,12 @@ def _args(self) -> list:
         if self._keyword_only:
             args = [
-                self.operand(p) for p in self._parameters if p not in self._keyword_only
-            ] + self.operands[len(self._parameters) :]
+                self.operand(p)
+                for p in self._parameters[:-1]
+                if p not in self._keyword_only
+            ] + self.argument_operands[len(self._parameters) - 1 :]
             return args
-        return self.operands
+        return self.argument_operands
 
     def _broadcast_dep(self, dep: Expr):
         # Checks if a dependency should be broadcasted to
@@ -562,7 +564,7 @@ def _broadcast_dep(self, dep: Expr):
 
     @property
     def args(self):
-        return [self.frame] + self.operands[len(self._parameters) :]
+        return [self.frame] + self.argument_operands[len(self._parameters) - 1 :]
 
     @functools.cached_property
     def _meta(self):
@@ -658,7 +660,7 @@ def __str__(self):
 
     @functools.cached_property
     def args(self):
-        return self.operands[len(self._parameters) :]
+        return self.argument_operands[len(self._parameters) - 1 :]
 
     @functools.cached_property
     def _dfs(self):
@@ -725,7 +727,7 @@ def _meta(self):
         meta = self.operand("meta")
         args = [self.frame._meta] + [
             arg._meta if isinstance(arg, Expr) else arg
-            for arg in self.operands[len(self._parameters) :]
+            for arg in self.argument_operands[len(self._parameters) - 1 :]
         ]
         return _get_meta_map_partitions(
             args,
@@ -737,11 +739,11 @@ def _meta(self):
         )
 
     def _divisions(self):
-        args = [self.frame] + self.operands[len(self._parameters) :]
+        args = [self.frame] + self.argument_operands[len(self._parameters) - 1 :]
         return calc_divisions_for_align(*args)
 
     def _lower(self):
-        args = [self.frame] + self.operands[len(self._parameters) :]
+        args = [self.frame] + self.operands[len(self._parameters) - 1 :]
         args = maybe_align_partitions(*args, divisions=self._divisions())
         return MapOverlap(
             args[0],
@@ -792,7 +794,7 @@ def args(self):
         return (
             [self.frame]
             + [self.func, self.before, self.after]
-            + self.operands[len(self._parameters) :]
+            + self.argument_operands[len(self._parameters) - 1 :]
         )
 
     @functools.cached_property
@@ -800,7 +802,7 @@ def _meta(self):
         meta = self.operand("meta")
         args = [self.frame._meta] + [
             arg._meta if isinstance(arg, Expr) else arg
-            for arg in self.operands[len(self._parameters) :]
+            for arg in self.argument_operands[len(self._parameters) - 1 :]
         ]
         return _get_meta_map_partitions(
             args,
@@ -1094,7 +1096,11 @@ class Sample(Blockwise):
 
     @functools.cached_property
     def _meta(self):
-        args = [self.operands[0]._meta] + [self.operands[1][0]] + self.operands[2:]
+        args = (
+            [self.operands[0]._meta]
+            + [self.operands[1][0]]
+            + self.argument_operands[2:]
+        )
         return self.operation(*args)
 
     def _task(self, index: int):
@@ -1696,11 +1702,11 @@ class Assign(Elemwise):
 
     @functools.cached_property
     def keys(self):
-        return self.operands[1::2]
+        return self.argument_operands[1::2]
 
     @functools.cached_property
     def vals(self):
-        return self.operands[2::2]
+        return self.argument_operands[2::2]
 
     @functools.cached_property
     def _meta(self):
@@ -1725,7 +1731,7 @@ def _simplify_down(self):
         if self._check_for_previously_created_column(self.frame):
             # don't squash if we are using a column that was previously created
             return
-        return Assign(*self.frame.operands, *self.operands[1:])
+        return Assign(*self.frame.argument_operands, *self.operands[1:])
 
     def _check_for_previously_created_column(self, child):
         input_columns = []
@@ -1753,6 +1759,7 @@ def _simplify_up(self, parent, dependents):
             for k, v in zip(self.keys, self.vals):
                 if k in columns:
                     new_args.extend([k, v])
+            new_args.append(self._branch_id)
         else:
             new_args = self.operands[1:]
 
@@ -1779,12 +1786,12 @@ class CaseWhen(Elemwise):
 
     @functools.cached_property
     def caselist(self):
-        c = self.operands[1:]
+        c = self.argument_operands[1:]
         return [(c[i], c[i + 1]) for i in range(0, len(c), 2)]
 
     @functools.cached_property
     def _meta(self):
-        c = self.operands[1:]
+        c = self.argument_operands[1:]
         caselist = [
             (
                 meta_nonempty(c[i]._meta) if isinstance(c[i], Expr) else c[i],
@@ -3308,7 +3315,7 @@ def __str__(self):
 
     @functools.cached_property
     def args(self):
-        return self.operands[len(self._parameters) :]
+        return self.argument_operands[len(self._parameters) - 1 :]
 
     @functools.cached_property
     def _dfs(self):
@@ -3323,7 +3330,9 @@ def _meta(self):
     def _lower(self):
         args = maybe_align_partitions(*self.args, divisions=self._divisions())
         dfs = [x for x in args if isinstance(x, Expr) and x.ndim > 0]
-        return UFuncElemwise(dfs[0], self.func, self._meta, False, self.kwargs, *args)
+        return UFuncElemwise(
+            dfs[0], self.func, self._meta, False, self.kwargs, *args, self._branch_id
+        )
 
 
 class Fused(Blockwise):
diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py
index faed278b3..9827092ce 100644
--- a/dask_expr/_groupby.py
+++ b/dask_expr/_groupby.py
@@ -106,7 +106,7 @@ def split_by(self):
 
     @functools.cached_property
     def by(self):
-        return self.operands[len(self._parameters) :]
+        return self.argument_operands[len(self._parameters) - 1 :]
 
     @functools.cached_property
     def levels(self):
@@ -738,9 +738,10 @@ def _lower(self):
                 self.operand(param)
                 if param not in ("chunk_kwargs", "aggregate_kwargs")
                 else {}
-                for param in self._parameters
+                for param in self._parameters[:-1]
             ],
             *self.by,
+            self._branch_id,
         )
 
        if is_dataframe_like(s._meta):
             c = c[s.columns]
diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index 3b562cfc3..3993512ff 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -26,6 +26,7 @@ from dask.utils import M, apply, funcname
 
 from dask_expr._concat import Concat
+from dask_expr._core import BranchId
 from dask_expr._expr import (
     Blockwise,
     Expr,
@@ -97,7 +98,7 @@ class Aggregate(Chunk):
 
     @functools.cached_property
     def aggregate_args(self):
-        return self.operands[len(self._parameters) :]
+        return self.argument_operands[len(self._parameters) - 1 :]
 
     @staticmethod
     def _call_with_list_arg(func, *args, **kwargs):
@@ -507,6 +508,35 @@ def _lower(self):
             ignore_index=getattr(self, "ignore_index", True),
         )
 
+    def _push_branch_id(self, parent):
+        return
+
+    def _simplify_down(self):
+        if self._branch_id.branch_id is not None:
+            return
+
+        seen = set()
+        stack = self.dependencies()
+        counter, found_io = 1, False
+
+        while stack:
+            node = stack.pop()
+
+            if node._name in seen:
+                continue
+            seen.add(node._name)
+
+            if isinstance(node, ApplyConcatApply):
+                counter += 1
+                continue
+            deps = node.dependencies()
+            if not deps:
+                found_io = True
+            stack.extend(deps)
+        if not found_io:
+            return
+        return type(self)(*self.operands[:-1], BranchId(counter))
+
 
 class Unique(ApplyConcatApply):
     _parameters = ["frame", "split_every", "split_out", "shuffle_method"]
diff --git a/dask_expr/_str_accessor.py b/dask_expr/_str_accessor.py
index 97748c991..dbf99b727 100644
--- a/dask_expr/_str_accessor.py
+++ b/dask_expr/_str_accessor.py
@@ -128,7 +128,7 @@ class CatBlockwise(Blockwise):
 
     @property
     def _args(self) -> list:
-        return [self.frame] + self.operands[len(self._parameters) :]
+        return [self.frame] + self.argument_operands[len(self._parameters) - 1 :]
 
     @staticmethod
     def operation(ser, *args, **kwargs):
diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py
index 9036edf69..26ee48140 100644
--- a/dask_expr/io/_delayed.py
+++ b/dask_expr/io/_delayed.py
@@ -30,7 +30,7 @@ def dependencies(self):
 
     @functools.cached_property
     def dfs(self):
-        return self.operands[len(self._parameters) :]
+        return self.argument_operands[len(self._parameters) - 1 :]
 
     @functools.cached_property
     def _meta(self):
diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py
index 3404f213d..9d8948ef2 100644
--- a/dask_expr/tests/test_collection.py
+++ b/dask_expr/tests/test_collection.py
@@ -942,7 +942,7 @@ def test_repr(df):
     s = (df["x"] + 1).sum(skipna=False).expr
     assert '["x"]' in str(s) or "['x']" in str(s)
     assert "+ 1" in str(s)
-    assert "sum(skipna=False)" in str(s)
+    assert "sum(skipna=False" in str(s)
 
 
 @xfail_gpu("combine_first not supported by cudf")
@@ -1924,7 +1924,9 @@ def test_assign_squash_together(df, pdf):
     pdf["a"] = 1
     pdf["b"] = 2
     assert_eq(df, pdf)
-    assert "Assign: a=1, b=2" in [line for line in result.expr._tree_repr_lines()]
+    assert "Assign: a=1, b=2, BranchId(branch_id=0)=" in [
+        line for line in result.expr._tree_repr_lines()
+    ]
 
 
 def test_are_co_aligned(pdf, df):
@@ -2412,7 +2414,10 @@ def test_filter_optimize_condition():
 
 def test_scalar_repr(df):
     result = repr(df.size)
-    assert result == "<dask_expr.expr.Scalar: expr=df.size, dtype=int64>"
+    assert (
+        result
+        == ""
+    )
 
 
 def test_reset_index_filter_pushdown(df):
diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py
index 44e7236fc..c8e78afaf 100644
--- a/dask_expr/tests/test_groupby.py
+++ b/dask_expr/tests/test_groupby.py
@@ -34,7 +34,7 @@ def test_groupby_unsupported_by(pdf, df):
 @pytest.mark.parametrize("split_every", [None, 5])
 @pytest.mark.parametrize(
"api", - ["sum", "mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"], + ["mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"], ) @pytest.mark.parametrize( "numeric_only", From d598734b317b87bd83c3d56b2148743e00f67f0d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 00:23:35 +0100 Subject: [PATCH 02/26] Update --- dask_expr/_concat.py | 2 +- dask_expr/_core.py | 11 +++++------ dask_expr/_expr.py | 28 +++++++++++++--------------- dask_expr/_groupby.py | 4 ++-- dask_expr/_reductions.py | 2 +- dask_expr/_str_accessor.py | 2 +- dask_expr/io/_delayed.py | 2 +- dask_expr/tests/test_groupby.py | 15 ++++++++------- 8 files changed, 32 insertions(+), 34 deletions(-) diff --git a/dask_expr/_concat.py b/dask_expr/_concat.py index fcdb4997d..f99503121 100644 --- a/dask_expr/_concat.py +++ b/dask_expr/_concat.py @@ -45,7 +45,7 @@ def __str__(self): + ", " + ", ".join( str(param) + "=" + str(operand) - for param, operand in zip(self._parameters[:-1], self.operands) + for param, operand in zip(self._parameters, self.operands) if operand != self._defaults.get(param) ) ) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 8b4a07406..756b3098c 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -50,7 +50,6 @@ def __new__(cls, *args, _branch_id=None, **kwargs): assert not kwargs, kwargs inst = object.__new__(cls) inst.operands = [_unpack_collections(o) for o in operands] + [_branch_id] - inst._parameters = cls._parameters + ["_branch_id"] _name = inst._name if _name in Expr._instances: return Expr._instances[_name] @@ -62,6 +61,10 @@ def __new__(cls, *args, _branch_id=None, **kwargs): def argument_operands(self): return self.operands[:-1] + @functools.cached_property + def _branch_id(self): + return self.operands[-1] + def _tune_down(self): return None @@ -159,7 +162,7 @@ def _depth(self): def operand(self, key): # Access an operand unambiguously # (e.g. 
if the key is reserved by a method/property) - return self.operands[self._parameters.index(key)] + return self.operands[type(self)._parameters.index(key)] def dependencies(self): # Dependencies are `Expr` operands only @@ -441,10 +444,6 @@ def _name(self): def _meta(self): raise NotImplementedError() - @functools.cached_property - def _branch_id(self): - return self.operands[-1] - def __getattr__(self, key): try: return object.__getattribute__(self, key) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 670735ba5..2d7e4afb1 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -466,7 +466,7 @@ def _kwargs(self) -> dict: if self._keyword_only: return { p: self.operand(p) - for p in self._parameters[:-1] + for p in self._parameters if p in self._keyword_only and self.operand(p) is not no_default } return {} @@ -475,10 +475,8 @@ def _kwargs(self) -> dict: def _args(self) -> list: if self._keyword_only: args = [ - self.operand(p) - for p in self._parameters[:-1] - if p not in self._keyword_only - ] + self.argument_operands[len(self._parameters) - 1 :] + self.operand(p) for p in self._parameters if p not in self._keyword_only + ] + self.argument_operands[len(self._parameters) :] return args return self.argument_operands @@ -564,7 +562,7 @@ def _broadcast_dep(self, dep: Expr): @property def args(self): - return [self.frame] + self.argument_operands[len(self._parameters) - 1 :] + return [self.frame] + self.argument_operands[len(self._parameters) :] @functools.cached_property def _meta(self): @@ -660,7 +658,7 @@ def __str__(self): @functools.cached_property def args(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @functools.cached_property def _dfs(self): @@ -727,7 +725,7 @@ def _meta(self): meta = self.operand("meta") args = [self.frame._meta] + [ arg._meta if isinstance(arg, Expr) else arg - for arg in self.argument_operands[len(self._parameters) - 1 :] + for arg in self.argument_operands[len(self._parameters) :] ] return _get_meta_map_partitions( args, @@ -739,11 +737,11 @@ def _meta(self): ) def _divisions(self): - args = [self.frame] + self.argument_operands[len(self._parameters) - 1 :] + args = [self.frame] + self.argument_operands[len(self._parameters) :] return calc_divisions_for_align(*args) def _lower(self): - args = [self.frame] + self.operands[len(self._parameters) - 1 :] + args = [self.frame] + self.operands[len(self._parameters) :] args = maybe_align_partitions(*args, divisions=self._divisions()) return MapOverlap( args[0], @@ -794,7 +792,7 @@ def args(self): return ( [self.frame] + [self.func, self.before, self.after] - + self.argument_operands[len(self._parameters) - 1 :] + + self.argument_operands[len(self._parameters) :] ) @functools.cached_property @@ -802,7 +800,7 @@ def _meta(self): meta = self.operand("meta") args = [self.frame._meta] + [ arg._meta if isinstance(arg, Expr) else arg - for arg in self.argument_operands[len(self._parameters) - 1 :] + for arg in self.argument_operands[len(self._parameters) :] ] return _get_meta_map_partitions( args, @@ -1908,7 +1906,7 @@ def _simplify_down(self): and self._meta.ndim == self.frame._meta.ndim ): # TODO: we should get more precise around Expr.columns types - return self.frame + return type(self.frame)(*self.frame.argument_operands, self._branch_id) if isinstance(self.frame, Projection): # df[a][b] a = self.frame.operand("columns") @@ -1922,7 +1920,7 @@ def _simplify_down(self): else: assert b in a - return self.frame.frame[b] + return 
type(self)(self.frame.frame, b, *self.operands[2:]) class Index(Elemwise): @@ -3315,7 +3313,7 @@ def __str__(self): @functools.cached_property def args(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @functools.cached_property def _dfs(self): diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index 9827092ce..2dc3b9440 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -106,7 +106,7 @@ def split_by(self): @functools.cached_property def by(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @functools.cached_property def levels(self): @@ -738,7 +738,7 @@ def _lower(self): self.operand(param) if param not in ("chunk_kwargs", "aggregate_kwargs") else {} - for param in self._parameters[:-1] + for param in self._parameters ], *self.by, self._branch_id, diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 3993512ff..629d88c82 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -98,7 +98,7 @@ class Aggregate(Chunk): @functools.cached_property def aggregate_args(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @staticmethod def _call_with_list_arg(func, *args, **kwargs): diff --git a/dask_expr/_str_accessor.py b/dask_expr/_str_accessor.py index dbf99b727..cae2a5436 100644 --- a/dask_expr/_str_accessor.py +++ b/dask_expr/_str_accessor.py @@ -128,7 +128,7 @@ class CatBlockwise(Blockwise): @property def _args(self) -> list: - return [self.frame] + self.argument_operands[len(self._parameters) - 1 :] + return [self.frame] + self.argument_operands[len(self._parameters) :] @staticmethod def operation(ser, *args, **kwargs): diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py index 26ee48140..70881b9c0 100644 --- a/dask_expr/io/_delayed.py +++ b/dask_expr/io/_delayed.py @@ -30,7 +30,7 @@ def dependencies(self): @functools.cached_property def dfs(self): - return self.argument_operands[len(self._parameters) - 1 :] + return self.argument_operands[len(self._parameters) :] @functools.cached_property def _meta(self): diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py index c8e78afaf..62b0e1b81 100644 --- a/dask_expr/tests/test_groupby.py +++ b/dask_expr/tests/test_groupby.py @@ -34,7 +34,7 @@ def test_groupby_unsupported_by(pdf, df): @pytest.mark.parametrize("split_every", [None, 5]) @pytest.mark.parametrize( "api", - ["mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"], + ["sum", "mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"], ) @pytest.mark.parametrize( "numeric_only", @@ -95,18 +95,19 @@ def test_groupby_numeric(pdf, df, api, numeric_only, split_every): def test_groupby_reduction_optimize(pdf, df): df = df.replace(1, 5) - agg = df.groupby(df.x).y.sum() - expected_query = df[["x", "y"]] - expected_query = expected_query.groupby(expected_query.x).y.sum() - assert agg.optimize()._name == expected_query.optimize()._name - expect = pdf.replace(1, 5).groupby(["x"]).y.sum() - assert_eq(agg, expect) + # agg = df.groupby(df.x).y.sum() + # expected_query = df[["x", "y"]] + # expected_query = expected_query.groupby(expected_query.x).y.sum() + # assert agg.optimize()._name == expected_query.optimize()._name + # expect = pdf.replace(1, 5).groupby(["x"]).y.sum() + # assert_eq(agg, expect) df2 = df[["y"]] agg = df2.groupby(df.x).y.sum() ops = [ op for op in 
agg.expr.optimize(fuse=False).walk() if isinstance(op, FromPandas) ] + agg.simplify().pprint() assert len(ops) == 1 assert ops[0].columns == ["x", "y"] From d88270d20f41db6d77f0f76a13b2229df074e65c Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 01:57:44 +0100 Subject: [PATCH 03/26] Fix delayed --- dask_expr/_expr.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 15f510fc8..e268820b7 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -53,6 +53,7 @@ from tlz import merge_sorted, partition, unique from dask_expr import _core as core +from dask_expr._core import BranchId from dask_expr._util import ( _calc_maybe_new_divisions, _convert_to_list, @@ -2719,9 +2720,11 @@ class _DelayedExpr(Expr): # TODO _parameters = ["obj"] - def __init__(self, obj): + def __init__(self, obj, _branch_id=None): self.obj = obj - self.operands = [obj] + if _branch_id is None: + _branch_id = BranchId(None) + self.operands = [obj, _branch_id] def __str__(self): return f"{type(self).__name__}({str(self.obj)})" From 948cd83189f21b617f19814cdc2d106f1fc70537 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 14:38:52 +0100 Subject: [PATCH 04/26] Update --- dask_expr/_core.py | 35 +++++++++++++++++++++++------- dask_expr/_expr.py | 26 +++++++++++++++++----- dask_expr/_reductions.py | 16 ++++++++++---- dask_expr/io/parquet.py | 2 +- dask_expr/tests/test_collection.py | 4 +--- dask_expr/tests/test_shuffle.py | 4 ++-- 6 files changed, 63 insertions(+), 24 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 756b3098c..040714186 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -40,7 +40,7 @@ def __new__(cls, *args, _branch_id=None, **kwargs): if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId): _branch_id = operands.pop(-1) elif _branch_id is None: - _branch_id = BranchId(None) + _branch_id = BranchId(0) for parameter in cls._parameters[len(operands) :]: try: @@ -125,6 +125,10 @@ def _tree_repr_lines(self, indent=0, recursive=True): op = "" elif is_arraylike(op): op = "" + elif isinstance(op, BranchId): + if op.branch_id == 0: + continue + op = f" branch_id={op.branch_id}" header = self._tree_repr_argument_construction(i, op, header) lines = [header] + lines @@ -285,10 +289,27 @@ def rewrite(self, kind: str): return expr - def _push_branch_id(self, parent): - if self._branch_id.branch_id != parent._branch_id.branch_id: - result = type(self)(*self.operands[:-1], parent._branch_id) - return parent.substitute(self, result) + def _reuse_up(self, parent): + return + + def _reuse_down(self): + if not self.dependencies(): + return + return self._bubble_branch_id_down() + + def _bubble_branch_id_down(self): + b_id = self._branch_id + if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()): + ops = [ + op._substitute_branch_id(b_id) if isinstance(op, Expr) else op + for op in self.argument_operands + ] + return type(self)(*ops) + + def _substitute_branch_id(self, branch_id): + if self._branch_id.branch_id != 0: + return self + return type(self)(*self.argument_operands, branch_id) def simplify_once(self, dependents: defaultdict, simplified: dict): """Simplify an expression @@ -326,9 +347,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict): # Allow children to simplify their parents for child in expr.dependencies(): - out = 
child._push_branch_id(expr) - if out is None: - out = child._simplify_up(expr, dependents) + out = child._simplify_up(expr, dependents) if out is None: out = expr diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index e268820b7..10475f198 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -1757,9 +1757,8 @@ def _simplify_up(self, parent, dependents): for k, v in zip(self.keys, self.vals): if k in columns: new_args.extend([k, v]) - new_args.append(self._branch_id) else: - new_args = self.operands[1:] + new_args = self.argument_operands[1:] columns = [col for col in self.frame.columns if col in cols] return type(parent)( @@ -2723,7 +2722,7 @@ class _DelayedExpr(Expr): def __init__(self, obj, _branch_id=None): self.obj = obj if _branch_id is None: - _branch_id = BranchId(None) + _branch_id = BranchId(0) self.operands = [obj, _branch_id] def __str__(self): @@ -2752,7 +2751,9 @@ def normalize_expression(expr): return expr._name -def optimize(expr: Expr, fuse: bool = True) -> Expr: +def optimize( + expr: Expr, fuse: bool = True, common_subplan_elimination: bool = True +) -> Expr: """High level query optimization This leverages three optimization passes: @@ -2766,15 +2767,28 @@ def optimize(expr: Expr, fuse: bool = True) -> Expr: Input expression to optimize fuse: whether or not to turn on blockwise fusion + common_subplan_elimination : bool + whether we want to reuse common subplans that are found in the graph and + are used in self-joins or similar which require all data be held in memory + at some point. Only set this to false if your dataset fits into memory. See Also -------- simplify optimize_blockwise_fusion """ + result = expr + while True: + if common_subplan_elimination: + out = result.rewrite("reuse") + else: + out = result + out = out.simplify() + if out._name == result._name or not common_subplan_elimination: + break + result = out - # Simplify - result = expr.simplify() + result = out # Manipulate Expression to make it more efficient result = result.rewrite(kind="tune") diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index c699e903a..6a1ac5cf3 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -508,11 +508,14 @@ def _lower(self): ignore_index=getattr(self, "ignore_index", True), ) - def _push_branch_id(self, parent): + def _reuse_up(self, parent): return - def _simplify_down(self): - if self._branch_id.branch_id is not None: + def _substitute_branch_id(self, branch_id): + return self + + def _reuse_down(self): + if self._branch_id.branch_id != 0: return seen = set() @@ -535,7 +538,12 @@ def _simplify_down(self): stack.extend(deps) if not found_io: return - return type(self)(*self.operands[:-1], BranchId(counter)) + b_id = BranchId(counter) + result = type(self)(*self.argument_operands, b_id) + out = result._bubble_branch_id_down() + if out is None: + return result + return type(out)(*out.argument_operands, b_id) class Unique(ApplyConcatApply): diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 8cd9e8c3d..472f58ec7 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -454,7 +454,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): _absorb_projections = True def _tree_repr_argument_construction(self, i, op, header): - if self._parameters[i] == "_dataset_info_cache": + if i < len(self._parameters) and self._parameters[i] == "_dataset_info_cache": # Don't print this, very ugly return header return super()._tree_repr_argument_construction(i, op, header) diff --git 
a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 9d8948ef2..02024420d 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -1924,9 +1924,7 @@ def test_assign_squash_together(df, pdf): pdf["a"] = 1 pdf["b"] = 2 assert_eq(df, pdf) - assert "Assign: a=1, b=2, ranchId(branch_id=0=" in [ - line for line in result.expr._tree_repr_lines() - ] + assert "Assign: a=1, b=2" in [line for line in result.expr._tree_repr_lines()] def test_are_co_aligned(pdf, df): diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index c7b977284..45338c2c5 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -137,7 +137,7 @@ def test_shuffle_column_projection(df): def test_shuffle_reductions(df): - assert df.shuffle("x").sum().simplify()._name == df.sum()._name + assert df.shuffle("x").sum().simplify()._name == df.sum().simplify()._name @pytest.mark.xfail(reason="Shuffle can't see the reduction through the Projection") @@ -716,7 +716,7 @@ def test_sort_values_avoid_overeager_filter_pushdown(meth): pdf1 = pd.DataFrame({"a": [4, 2, 3], "b": [1, 2, 3]}) df = from_pandas(pdf1, npartitions=2) df = getattr(df, meth)("a") - df = df[df.b > 2] + df.b.sum() + df = df[df.b > 2] + df[df.b > 1] result = df.simplify() assert isinstance(result.expr.left, Filter) assert isinstance(result.expr.left.frame, BaseSetIndexSortValues) From 045bbef7f3fa53dda837f87105f59d91e71a658d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 16:43:03 +0100 Subject: [PATCH 05/26] Update --- dask_expr/tests/test_collection.py | 5 +---- dask_expr/tests/test_groupby.py | 12 ++++++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 02024420d..e60debc12 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -2412,10 +2412,7 @@ def test_filter_optimize_condition(): def test_scalar_repr(df): result = repr(df.size) - assert ( - result - == "" - ) + assert result == "" def test_reset_index_filter_pushdown(df): diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py index 62b0e1b81..be0d21df6 100644 --- a/dask_expr/tests/test_groupby.py +++ b/dask_expr/tests/test_groupby.py @@ -95,12 +95,12 @@ def test_groupby_numeric(pdf, df, api, numeric_only, split_every): def test_groupby_reduction_optimize(pdf, df): df = df.replace(1, 5) - # agg = df.groupby(df.x).y.sum() - # expected_query = df[["x", "y"]] - # expected_query = expected_query.groupby(expected_query.x).y.sum() - # assert agg.optimize()._name == expected_query.optimize()._name - # expect = pdf.replace(1, 5).groupby(["x"]).y.sum() - # assert_eq(agg, expect) + agg = df.groupby(df.x).y.sum() + expected_query = df[["x", "y"]] + expected_query = expected_query.groupby(expected_query.x).y.sum() + assert agg.optimize()._name == expected_query.optimize()._name + expect = pdf.replace(1, 5).groupby(["x"]).y.sum() + assert_eq(agg, expect) df2 = df[["y"]] agg = df2.groupby(df.x).y.sum() From 93e0d28d32f423c8c99c0975b7552a33aaf449a5 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 16:45:44 +0100 Subject: [PATCH 06/26] Add cache --- dask_expr/_core.py | 8 ++++++-- dask_expr/_expr.py | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 040714186..ba53e9953 
100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -225,7 +225,7 @@ def _layer(self) -> dict: return {(self._name, i): self._task(i) for i in range(self.npartitions)} - def rewrite(self, kind: str): + def rewrite(self, kind: str, cache): """Rewrite an expression This leverages the ``._{kind}_down`` and ``._{kind}_up`` @@ -238,6 +238,9 @@ def rewrite(self, kind: str): changed: whether or not any change occured """ + if self._name in cache: + return cache[self._name] + expr = self down_name = f"_{kind}_down" up_name = f"_{kind}_up" @@ -274,7 +277,8 @@ def rewrite(self, kind: str): changed = False for operand in expr.operands: if isinstance(operand, Expr): - new = operand.rewrite(kind=kind) + new = operand.rewrite(kind=kind, cache=cache) + cache[operand._name] = new if new._name != operand._name: changed = True else: diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 10475f198..ce8e4a42d 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -2780,7 +2780,7 @@ def optimize( result = expr while True: if common_subplan_elimination: - out = result.rewrite("reuse") + out = result.rewrite("reuse", cache={}) else: out = result out = out.simplify() @@ -2791,13 +2791,13 @@ def optimize( result = out # Manipulate Expression to make it more efficient - result = result.rewrite(kind="tune") + result = result.rewrite(kind="tune", cache={}) # Lower result = result.lower_completely() # Cull - result = result.rewrite(kind="cull") + result = result.rewrite(kind="cull", cache={}) # Final graph-specific optimizations if fuse: From 7ddda990b0e67ee91addaaf5a9e80f9168637bf8 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Feb 2024 16:57:33 +0100 Subject: [PATCH 07/26] Enhance tests --- dask_expr/tests/test_collection.py | 42 +++++++++++++++++++----------- dask_expr/tests/test_shuffle.py | 10 +++++-- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index e60debc12..1fa8b6114 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -507,7 +507,7 @@ def test_diff(pdf, df, axis, periods): if axis in ("columns", 1): assert actual._name == actual.simplify()._name else: - assert actual.simplify()._name == expected.simplify()._name + assert actual.optimize()._name == expected.optimize()._name @pytest.mark.parametrize( @@ -1163,8 +1163,8 @@ def test_tail_repartition(df): def test_projection_stacking(df): result = df[["x", "y"]]["x"] - optimized = result.simplify() - expected = df["x"].simplify() + optimized = result.optimize() + expected = df["x"].optimize() assert optimized._name == expected._name @@ -1877,8 +1877,8 @@ def test_assign_simplify(pdf): df = from_pandas(pdf) df2 = from_pandas(pdf) df["new"] = df.x > 1 - result = df[["x", "new"]].simplify() - expected = df2[["x"]].assign(new=df2.x > 1).simplify() + result = df[["x", "new"]].optimize() + expected = df2[["x"]].assign(new=df2.x > 1).optimize() assert result._name == expected._name pdf["new"] = pdf.x > 1 @@ -1889,8 +1889,8 @@ def test_assign_simplify_new_column_not_needed(pdf): df = from_pandas(pdf) df2 = from_pandas(pdf) df["new"] = df.x > 1 - result = df[["x"]].simplify() - expected = df2[["x"]].simplify() + result = df[["x"]].optimize() + expected = df2[["x"]].optimize() assert result._name == expected._name pdf["new"] = pdf.x > 1 @@ -1901,8 +1901,8 @@ def test_assign_simplify_series(pdf): df = from_pandas(pdf) df2 = from_pandas(pdf) df["new"] = df.x > 1 - result = 
df.new.simplify() - expected = df2[[]].assign(new=df2.x > 1).new.simplify() + result = df.new.optimize() + expected = df2[[]].assign(new=df2.x > 1).new.optimize() assert result._name == expected._name @@ -1920,7 +1920,16 @@ def test_assign_squash_together(df, pdf): df["a"] = 1 df["b"] = 2 result = df.simplify() - assert len([x for x in list(result.expr.walk()) if isinstance(x, expr.Assign)]) == 1 + assert ( + len( + [ + x + for x in list(df.optimize(fuse=False).expr.walk()) + if isinstance(x, expr.Assign) + ] + ) + == 1 + ) pdf["a"] = 1 pdf["b"] = 2 assert_eq(df, pdf) @@ -1970,10 +1979,10 @@ def test_astype_categories(df): assert_eq(result.y._meta.cat.categories, pd.Index([UNKNOWN_CATEGORIES])) -def test_drop_simplify(df): +def test_drop_optimize(df): q = df.drop(columns=["x"])[["y"]] - result = q.simplify() - expected = df[["y"]].simplify() + result = q.optimize() + expected = df[["y"]].optimize() assert result._name == expected._name @@ -2061,6 +2070,7 @@ def test_filter_pushdown_unavailable(df): result = df[df.x > 5] + df.x.sum() result = result[["x"]] expected = df[["x"]][df.x > 5] + df.x.sum() + assert result.optimize()._name == expected.optimize()._name assert result.simplify()._name == expected.simplify()._name @@ -2073,6 +2083,7 @@ def test_filter_pushdown(df, pdf): df = df.rename_axis(index="hello") result = df[df.x > 5].simplify() assert result._name == expected._name + assert result.optimize()._name == expected.optimize()._name pdf["z"] = 1 df = from_pandas(pdf, npartitions=10) @@ -2081,6 +2092,7 @@ def test_filter_pushdown(df, pdf): df_opt = df[["x", "y"]] expected = df_opt[df_opt.x > 5].rename_axis(index="hello").simplify() assert result._name == expected._name + assert result.optimize()._name == expected.optimize()._name def test_shape(df, pdf): @@ -2430,13 +2442,13 @@ def test_reset_index_filter_pushdown(df): result = q[q > 5] expected = df["x"] expected = expected[expected > 5].reset_index(drop=True) - assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name q = df.x.reset_index() result = q[q.x > 5] expected = df["x"] expected = expected[expected > 5].reset_index() - assert result.simplify()._name == expected.simplify()._name + assert result.optimize()._name == expected.optimize()._name def test_astype_filter_pushdown(df, pdf): diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 45338c2c5..3c6d6c6f3 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -137,7 +137,7 @@ def test_shuffle_column_projection(df): def test_shuffle_reductions(df): - assert df.shuffle("x").sum().simplify()._name == df.sum().simplify()._name + assert df.shuffle("x").sum().optimize()._name == df.sum().optimize()._name @pytest.mark.xfail(reason="Shuffle can't see the reduction through the Projection") @@ -264,7 +264,7 @@ def test_set_index_repartition(df, pdf): assert_eq(result, pdf.set_index("x")) -def test_set_index_simplify(df, pdf): +def test_set_index_optimize(df, pdf): q = df.set_index("x")["y"].optimize(fuse=False) expected = df[["x", "y"]].set_index("x")["y"].optimize(fuse=False) assert q._name == expected._name @@ -697,18 +697,21 @@ def test_shuffle_filter_pushdown(pdf, meth): result = result[result.x > 5.0] expected = getattr(df[df.x > 5.0], meth)("x") assert result.simplify()._name == expected._name + assert result.optimize()._name == expected.optimize()._name result = getattr(df, meth)("x") result = result[result.x > 5.0][["x", "y"]] expected = df[["x", "y"]] 
     expected = getattr(expected[expected.x > 5.0], meth)("x")
     assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name
 
     result = getattr(df, meth)("x")[["x", "y"]]
     result = result[result.x > 5.0]
     expected = df[["x", "y"]]
     expected = getattr(expected[expected.x > 5.0], meth)("x")
     assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name
 
 
 @pytest.mark.parametrize("meth", ["set_index", "sort_values"])
@@ -729,15 +732,18 @@ def test_set_index_filter_pushdown():
     result = result[result.y == 1]
     expected = df[df.y == 1].set_index("x")
     assert result.simplify()._name == expected._name
+    assert result.optimize()._name == expected.optimize()._name
 
     result = df.set_index("x")
     result = result[result.y == 1][["y"]]
     expected = df[["x", "y"]]
     expected = expected[expected.y == 1].set_index("x")
     assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name
 
     result = df.set_index("x")[["y"]]
     result = result[result.y == 1]
     expected = df[["x", "y"]]
     expected = expected[expected.y == 1].set_index("x")
     assert result.simplify()._name == expected.simplify()._name
+    assert result.optimize()._name == expected.optimize()._name

From 8c2d97726e0ddda6f24001f825bda14698966351 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 16 Feb 2024 17:25:54 +0100
Subject: [PATCH 08/26] Add tests

---
 dask_expr/tests/test_reuse.py | 82 +++++++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)
 create mode 100644 dask_expr/tests/test_reuse.py

diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py
new file mode 100644
index 000000000..b6c2378f8
--- /dev/null
+++ b/dask_expr/tests/test_reuse.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+import pytest
+
+from dask_expr import from_pandas
+from dask_expr.io import IO
+from dask_expr.tests._util import _backend_library, assert_eq
+
+# Set DataFrame backend for this module
+pd = _backend_library()
+
+
+@pytest.fixture
+def pdf():
+    pdf = pd.DataFrame({"x": range(100), "a": 1, "b": 1, "c": 1})
+    pdf["y"] = pdf.x // 7  # Not unique; duplicates span different partitions
+    yield pdf
+
+
+@pytest.fixture
+def df(pdf):
+    yield from_pandas(pdf, npartitions=10)
+
+
+def _check_io_nodes(expr, expected):
+    assert len(list(expr.find_operations(IO))) == expected
+
+
+def test_reuse_everything_scalar_and_series(df, pdf):
+    df["new"] = 1
+    df["new2"] = df["x"] + 1
+    df["new3"] = df.x[df.x > 1] + df.x[df.x > 2]
+
+    pdf["new"] = 1
+    pdf["new2"] = pdf["x"] + 1
+    pdf["new3"] = pdf.x[pdf.x > 1] + pdf.x[pdf.x > 2]
+    assert_eq(df, pdf)
+    _check_io_nodes(df.optimize(fuse=False), 1)
+
+
+def test_dont_reuse_reducer(df, pdf):
+    result = df.replace(1, 5)
+    result["new"] = result.x + result.y.sum()
+    expected = pdf.replace(1, 5)
+    expected["new"] = expected.x + expected.y.sum()
+    assert_eq(result, expected)
+    _check_io_nodes(result.optimize(fuse=False), 2)
+
+    result = df + df.sum()
+    expected = pdf + pdf.sum()
+    assert_eq(result, expected)
+    _check_io_nodes(result.optimize(fuse=False), 2)
+
+    result = df.replace(1, 5)
+    rhs_1 = result.x + result.y.sum()
+    rhs_2 = result.b + result.a.sum()
+    result["new"] = rhs_1
+    result["new2"] = rhs_2
+    expected = pdf.replace(1, 5)
+    expected["new"] = expected.x + expected.y.sum()
+    expected["new2"] = expected.b + expected.a.sum()
+    assert_eq(result, expected)
+    result.optimize(fuse=False).pprint()
+    _check_io_nodes(result.optimize(fuse=False), 2)
+
+    result = df.replace(1, 5)
+    result["new"] = result.x + result.y.sum()
+    result["new2"] = result.b + result.a.sum()
+    expected = pdf.replace(1, 5)
+    expected["new"] = expected.x + expected.y.sum()
+    expected["new2"] = expected.b + expected.a.sum()
+    assert_eq(result, expected)
+    result.optimize(fuse=False).pprint()
+    _check_io_nodes(result.optimize(fuse=False), 3)
+
+    result = df.replace(1, 5)
+    result["new"] = result.x + result.sum().dropna().prod()
+    expected = pdf.replace(1, 5)
+    expected["new"] = expected.x + expected.sum().dropna().prod()
+    assert_eq(result, expected)
+    result.optimize(fuse=False).pprint()
+    _check_io_nodes(result.optimize(fuse=False), 2)

From 7184bcf70b67a4632b4c59dd7b2f2ea9eaa3873e Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 16 Feb 2024 19:10:50 +0100
Subject: [PATCH 09/26] Update

---
 dask_expr/_expr.py            | 8 +++-----
 dask_expr/tests/test_reuse.py | 3 ---
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index ce8e4a42d..c9e566140 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -1905,7 +1905,7 @@ def _simplify_down(self):
             and self._meta.ndim == self.frame._meta.ndim
         ):
             # TODO: we should get more precise around Expr.columns types
-            return type(self.frame)(*self.frame.argument_operands, self._branch_id)
+            return self.frame
         if isinstance(self.frame, Projection):
             # df[a][b]
             a = self.frame.operand("columns")
@@ -1919,7 +1919,7 @@ def _simplify_down(self):
             else:
                 assert b in a
 
-            return type(self)(self.frame.frame, b, *self.operands[2:])
+            return self.frame.frame[b]
 
 
 class Index(Elemwise):
@@ -3344,9 +3344,7 @@ def _meta(self):
     def _lower(self):
         args = maybe_align_partitions(*self.args, divisions=self._divisions())
         dfs = [x for x in args if isinstance(x, Expr) and x.ndim > 0]
-        return UFuncElemwise(
-            dfs[0], self.func, self._meta, False, self.kwargs, *args, self._branch_id
-        )
+        return UFuncElemwise(dfs[0], self.func, self._meta, False, self.kwargs, *args)
 
 
 class Fused(Blockwise):
diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py
index b6c2378f8..d1da88bca 100644
--- a/dask_expr/tests/test_reuse.py
+++ b/dask_expr/tests/test_reuse.py
@@ -60,7 +60,6 @@ def test_dont_reuse_reducer(df, pdf):
     expected["new"] = expected.x + expected.y.sum()
     expected["new2"] = expected.b + expected.a.sum()
     assert_eq(result, expected)
-    result.optimize(fuse=False).pprint()
     _check_io_nodes(result.optimize(fuse=False), 3)
 
     result = df.replace(1, 5)
@@ -70,5 +69,4 @@ def test_dont_reuse_reducer(df, pdf):
     expected = pdf.replace(1, 5)
     expected["new"] = expected.x + expected.sum().dropna().prod()
     assert_eq(result, expected)
-    result.optimize(fuse=False).pprint()
     _check_io_nodes(result.optimize(fuse=False), 2)

From 5ac93949dde3ddf3836e2bec391af795a9f65acf Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 16 Feb 2024 19:18:34 +0100
Subject: [PATCH 10/26] Update

---
 dask_expr/tests/test_reuse.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py
index d1da88bca..850b6a493 100644
--- a/dask_expr/tests/test_reuse.py
+++ b/dask_expr/tests/test_reuse.py
@@ -48,6 +48,7 @@ def test_dont_reuse_reducer(df, pdf):
 
     result = df + df.sum()
     expected = pdf + pdf.sum()
+    result.optimize().pprint()
     assert_eq(result, expected)
     _check_io_nodes(result.optimize(fuse=False), 2)

From fb2aa9f0df223682674d0c21d35ef5df1de4b47b Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 16 Feb 2024 19:22:21 +0100
Subject: [PATCH 11/26] Update

---
 dask_expr/_reductions.py      | 2 +-
 dask_expr/tests/test_reuse.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index 6a1ac5cf3..db65ba5f9 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -347,7 +347,7 @@ def _layer(self):
 
     @property
     def _meta(self):
-        return self.operand("_meta")
+        return make_meta(self.operand("_meta"))
 
     def _divisions(self):
         return (None, None)
diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py
index 850b6a493..d1da88bca 100644
--- a/dask_expr/tests/test_reuse.py
+++ b/dask_expr/tests/test_reuse.py
@@ -48,7 +48,6 @@ def test_dont_reuse_reducer(df, pdf):
 
     result = df + df.sum()
     expected = pdf + pdf.sum()
-    result.optimize().pprint()
     assert_eq(result, expected)
     _check_io_nodes(result.optimize(fuse=False), 2)

From 061de6f1ef136e25047628024d6a8f2cd6887be4 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 16 Feb 2024 19:26:02 +0100
Subject: [PATCH 12/26] Update

---
 dask_expr/_reductions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index db65ba5f9..6a1ac5cf3 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -347,7 +347,7 @@ def _layer(self):
 
     @property
     def _meta(self):
-        return make_meta(self.operand("_meta"))
+        return self.operand("_meta")
 
     def _divisions(self):
         return (None, None)

From e4865905c1c040489cce93d76cc8f32ea2d18326 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 16 Feb 2024 19:29:37 +0100
Subject: [PATCH 13/26] Update

---
 dask_expr/tests/test_reuse.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py
index d1da88bca..b95bace1a 100644
--- a/dask_expr/tests/test_reuse.py
+++ b/dask_expr/tests/test_reuse.py
@@ -48,7 +48,7 @@ def test_dont_reuse_reducer(df, pdf):
 
     result = df + df.sum()
     expected = pdf + pdf.sum()
-    assert_eq(result, expected)
+    assert_eq(result, expected, check_names=False)  # pandas 2.2 bug
     _check_io_nodes(result.optimize(fuse=False), 2)

From d28f906f10d316c869a1326e3a75d51809bb84c4 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Mon, 19 Feb 2024 11:57:20 +0100
Subject: [PATCH 14/26] Update _core.py

---
 dask_expr/_core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index ba53e9953..2d82aae79 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -17,7 +17,7 @@
 
 class BranchId(NamedTuple):
-    branch_id: int | None
+    branch_id: int
 
 
 def _unpack_collections(o):

From 366415a7345ee530f45ee89ef982429dbcc24992 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Mon, 19 Feb 2024 15:05:47 +0100
Subject: [PATCH 15/26] Update test_groupby.py

---
 dask_expr/tests/test_groupby.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py
index be0d21df6..44e7236fc 100644
--- a/dask_expr/tests/test_groupby.py
+++ b/dask_expr/tests/test_groupby.py
@@ -107,7 +107,6 @@ def test_groupby_reduction_optimize(pdf, df):
     ops = [
         op for op in agg.expr.optimize(fuse=False).walk() if isinstance(op, FromPandas)
     ]
-    agg.simplify().pprint()
     assert len(ops) == 1
     assert ops[0].columns == ["x", "y"]

From ee523eaf730a797c7d4bc33df31d9e0c5ef41e83 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Mon, 19 Feb 2024 15:11:15 +0100
Subject: [PATCH 16/26] Update

---
 dask_expr/_expr.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index c9e566140..7ef0b8737 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -2752,7 +2752,7 @@ def normalize_expression(expr):
 
 def optimize(
-    expr: Expr, fuse: bool = True, common_subplan_elimination: bool = True
+    expr: Expr, fuse: bool = True, common_subplan_elimination: bool = False
 ) -> Expr:
     """High level query optimization
 
@@ -2767,7 +2767,7 @@ def optimize(
         whether or not to turn on blockwise fusion
-    common_subplan_elimination : bool
+    common_subplan_elimination : bool, default False
         whether we want to reuse common subplans that are found in the graph and
         are used in self-joins or similar which require all data be held in memory
         at some point. Only set this to false if your dataset fits into memory.
@@ -2779,12 +2779,12 @@ def optimize(
     result = expr
     while True:
-        if common_subplan_elimination:
+        if not common_subplan_elimination:
             out = result.rewrite("reuse", cache={})
         else:
             out = result
         out = out.simplify()
-        if out._name == result._name or not common_subplan_elimination:
+        if out._name == result._name or common_subplan_elimination:
             break
         result = out

From 369c142c38360fa9b1178f5af1a277e6dea768d0 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Mon, 19 Feb 2024 15:11:59 +0100
Subject: [PATCH 17/26] Update

---
 dask_expr/_expr.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 7ef0b8737..b80e7c378 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -2770,7 +2770,7 @@ def optimize(
     common_subplan_elimination : bool, default False
         whether we want to reuse common subplans that are found in the graph and
         are used in self-joins or similar which require all data be held in memory
-        at some point. Only set this to false if your dataset fits into memory.
+        at some point. Only set this to true if your dataset fits into memory.
 
     See Also
     --------

From 7379a01da6cdc65cf4e658525e9281498da4ccf6 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Tue, 20 Feb 2024 16:27:58 +0100
Subject: [PATCH 18/26] Remove argument_operands

---
 dask_expr/_core.py         | 19 +++++++-----------
 dask_expr/_cumulative.py   |  2 +-
 dask_expr/_expr.py         | 40 +++++++++++++++++---------------------
 dask_expr/_groupby.py      |  2 +-
 dask_expr/_reductions.py   |  8 ++++----
 dask_expr/_str_accessor.py |  2 +-
 dask_expr/io/_delayed.py   |  2 +-
 dask_expr/io/io.py         | 20 ++++++++++++++-----
 dask_expr/io/parquet.py    |  2 +-
 9 files changed, 49 insertions(+), 48 deletions(-)

diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index bfd7c2e1c..4a43926d4 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -49,7 +49,8 @@ def __new__(cls, *args, _branch_id=None, **kwargs):
                 operands.append(cls._defaults[parameter])
         assert not kwargs, kwargs
         inst = object.__new__(cls)
-        inst.operands = [_unpack_collections(o) for o in operands] + [_branch_id]
+        inst.operands = [_unpack_collections(o) for o in operands]
+        inst._branch_id = _branch_id
         _name = inst._name
         if _name in Expr._instances:
             return Expr._instances[_name]
@@ -57,14 +58,6 @@ def __new__(cls, *args, _branch_id=None, **kwargs):
         Expr._instances[_name] = inst
         return inst
 
-    @functools.cached_property
-    def argument_operands(self):
-        return self.operands[:-1]
-
-    @functools.cached_property
-    def _branch_id(self):
-        return self.operands[-1]
-
     def _tune_down(self):
         return None
 
@@ -300,14 +293,14 @@ def _bubble_branch_id_down(self):
        b_id = self._branch_id
        if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()):
            ops = [
                op._substitute_branch_id(b_id) if isinstance(op, Expr) else op
-                for op in self.argument_operands
+                for op in self.operands
            ]
            return type(self)(*ops)

    def _substitute_branch_id(self, branch_id):
        if self._branch_id.branch_id != 0:
            return self
-        return type(self)(*self.argument_operands, branch_id)
+        return type(self)(*self.operands, branch_id)

    def simplify_once(self, dependents: defaultdict, simplified: dict):
@@ -454,7 +447,9 @@ def _lower(self):
     @functools.cached_property
     def _name(self):
         return (
-            funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands)
+            funcname(type(self)).lower()
+            + "-"
+            + _tokenize_deterministic(*self.operands, self._branch_id)
         )
 
     @property
diff --git a/dask_expr/_cumulative.py b/dask_expr/_cumulative.py
index 1aedd8bc9..8d8b638ba 100644
--- a/dask_expr/_cumulative.py
+++ b/dask_expr/_cumulative.py
@@ -47,7 +47,7 @@ def operation(self):
 
     @functools.cached_property
     def _args(self) -> list:
-        return self.argument_operands[:-1]
+        return self.operands[:-1]
 
 
 class TakeLast(Blockwise):
diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index dd594a287..94c09899c 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -477,9 +477,9 @@ def _args(self) -> list:
         if self._keyword_only:
             args = [
                 self.operand(p) for p in self._parameters if p not in self._keyword_only
-            ] + self.argument_operands[len(self._parameters) :]
+            ] + self.operands[len(self._parameters) :]
             return args
-        return self.argument_operands
+        return self.operands
 
     def _broadcast_dep(self, dep: Expr):
         # Checks if a dependency should be broadcasted to
@@ -503,7 +503,7 @@ def _name(self):
             head = funcname(self.operation)
         else:
             head = funcname(type(self)).lower()
-        return head + "-" + _tokenize_deterministic(*self.operands)
+        return head + "-" + _tokenize_deterministic(*self.operands, self._branch_id)
 
     def _blockwise_arg(self, arg, i):
        """Return a Blockwise-task argument"""
@@ -563,7 +563,7 @@ def _broadcast_dep(self, dep: Expr):
 
     @property
     def args(self):
-        return [self.frame] + self.argument_operands[len(self._parameters) :]
+        return [self.frame] + self.operands[len(self._parameters) :]
 
     @functools.cached_property
     def _meta(self):
@@ -659,7 +659,7 @@ def __str__(self):
 
     @functools.cached_property
     def args(self):
-        return self.argument_operands[len(self._parameters) :]
+        return self.operands[len(self._parameters) :]
 
     @functools.cached_property
     def _dfs(self):
@@ -726,7 +726,7 @@ def _meta(self):
         meta = self.operand("meta")
         args = [self.frame._meta] + [
             arg._meta if isinstance(arg, Expr) else arg
-            for arg in self.argument_operands[len(self._parameters) :]
+            for arg in self.operands[len(self._parameters) :]
         ]
         return _get_meta_map_partitions(
             args,
@@ -738,7 +738,7 @@ def _meta(self):
         )
 
     def _divisions(self):
-        args = [self.frame] + self.argument_operands[len(self._parameters) :]
+        args = [self.frame] + self.operands[len(self._parameters) :]
         return calc_divisions_for_align(*args)
 
     def _lower(self):
@@ -793,7 +793,7 @@ def args(self):
         return (
             [self.frame]
             + [self.func, self.before, self.after]
-            + self.argument_operands[len(self._parameters) :]
+            + self.operands[len(self._parameters) :]
         )
 
     @functools.cached_property
@@ -801,7 +801,7 @@ def _meta(self):
         meta = self.operand("meta")
         args = [self.frame._meta] + [
             arg._meta if isinstance(arg, Expr) else arg
-            for arg in self.argument_operands[len(self._parameters) :]
+            for arg in self.operands[len(self._parameters) :]
         ]
         return _get_meta_map_partitions(
             args,
@@ -1095,11 +1095,7 @@ class Sample(Blockwise):
 
     @functools.cached_property
     def _meta(self):
-        args = (
-            [self.operands[0]._meta]
-            + [self.operands[1][0]]
-            + self.argument_operands[2:]
-        )
+        args = [self.operands[0]._meta] + [self.operands[1][0]] + self.operands[2:]
         return self.operation(*args)
 
     def _task(self, index: int):
@@ -1701,11 +1697,11 @@ class Assign(Elemwise):
 
     @functools.cached_property
     def keys(self):
-        return self.argument_operands[1::2]
+        return self.operands[1::2]
 
     @functools.cached_property
     def vals(self):
-        return self.argument_operands[2::2]
+        return self.operands[2::2]
 
     @functools.cached_property
     def _meta(self):
@@ -1730,7 +1726,7 @@ def _simplify_down(self):
         if self._check_for_previously_created_column(self.frame):
             # don't squash if we are using a column that was previously created
             return
-        return Assign(*self.frame.argument_operands, *self.operands[1:])
+        return Assign(*self.frame.operands, *self.operands[1:])
 
     def _check_for_previously_created_column(self, child):
         input_columns = []
@@ -1758,7 +1754,7 @@ def _simplify_up(self, parent, dependents):
             for k, v in zip(self.keys, self.vals):
                 if k in columns:
                     new_args.extend([k, v])
         else:
-            new_args = self.argument_operands[1:]
+            new_args = self.operands[1:]
 
         columns = [col for col in self.frame.columns if col in cols]
         return type(parent)(
@@ -1783,12 +1779,12 @@ class CaseWhen(Elemwise):
 
     @functools.cached_property
     def caselist(self):
-        c = self.argument_operands[1:]
+        c = self.operands[1:]
         return [(c[i], c[i + 1]) for i in range(0, len(c), 2)]
 
     @functools.cached_property
     def _meta(self):
-        c = self.argument_operands[1:]
+        c = self.operands[1:]
         caselist = [
             (
                 meta_nonempty(c[i]._meta) if isinstance(c[i], Expr) else c[i],
@@ -3332,7 +3328,7 @@ def __str__(self):
 
     @functools.cached_property
     def args(self):
-        return self.argument_operands[len(self._parameters) :]
+        return self.operands[len(self._parameters) :]
 
    @functools.cached_property
    def _dfs(self):
@@ -3424,7 +3420,7 @@ def __str__(self):
 
    @functools.cached_property
def _name(self): - return f"{str(self)}-{_tokenize_deterministic(self.exprs)}" + return f"{str(self)}-{_tokenize_deterministic(self.exprs, self._branch_id)}" def _divisions(self): return self.exprs[0]._divisions() diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index c6ca9433e..4027b49ff 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -106,7 +106,7 @@ def split_by(self): @functools.cached_property def by(self): - return self.argument_operands[len(self._parameters) :] + return self.operands[len(self._parameters) :] @functools.cached_property def levels(self): diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 6a1ac5cf3..1eef1eadb 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -98,7 +98,7 @@ class Aggregate(Chunk): @functools.cached_property def aggregate_args(self): - return self.argument_operands[len(self._parameters) :] + return self.operands[len(self._parameters) :] @staticmethod def _call_with_list_arg(func, *args, **kwargs): @@ -301,7 +301,7 @@ def _name(self): name = funcname(self.combine.__self__).lower() + "-tree" else: name = funcname(self.combine) - return name + "-" + _tokenize_deterministic(*self.operands) + return name + "-" + _tokenize_deterministic(*self.operands, self._branch_id) def __dask_postcompute__(self): return toolz.first, () @@ -539,11 +539,11 @@ def _reuse_down(self): if not found_io: return b_id = BranchId(counter) - result = type(self)(*self.argument_operands, b_id) + result = type(self)(*self.operands, b_id) out = result._bubble_branch_id_down() if out is None: return result - return type(out)(*out.argument_operands, b_id) + return type(out)(*out.operands, b_id) class Unique(ApplyConcatApply): diff --git a/dask_expr/_str_accessor.py b/dask_expr/_str_accessor.py index cae2a5436..97748c991 100644 --- a/dask_expr/_str_accessor.py +++ b/dask_expr/_str_accessor.py @@ -128,7 +128,7 @@ class CatBlockwise(Blockwise): @property def _args(self) -> list: - return [self.frame] + self.argument_operands[len(self._parameters) :] + return [self.frame] + self.operands[len(self._parameters) :] @staticmethod def operation(ser, *args, **kwargs): diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py index 70881b9c0..9036edf69 100644 --- a/dask_expr/io/_delayed.py +++ b/dask_expr/io/_delayed.py @@ -30,7 +30,7 @@ def dependencies(self): @functools.cached_property def dfs(self): - return self.argument_operands[len(self._parameters) :] + return self.operands[len(self._parameters) :] @functools.cached_property def _meta(self): diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 1b6b34fe7..307132d2c 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -48,7 +48,9 @@ def _divisions(self): @functools.cached_property def _name(self): return ( - self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands) + self.operand("name_prefix") + + "-" + + _tokenize_deterministic(*self.operands, self._branch_id) ) def _layer(self): @@ -103,7 +105,7 @@ def _name(self): return ( funcname(type(self.operand("_expr"))).lower() + "-fused-" - + _tokenize_deterministic(*self.operands) + + _tokenize_deterministic(*self.operands, self._branch_id) ) @functools.cached_property @@ -173,10 +175,14 @@ def _name(self): return ( funcname(self.func).lower() + "-" - + _tokenize_deterministic(*self.operands) + + _tokenize_deterministic(*self.operands, self._branch_id) ) else: - return self.label + "-" + _tokenize_deterministic(*self.operands) + return ( + self.label + + "-" + + 
_tokenize_deterministic(*self.operands, self._branch_id) + ) @functools.cached_property def _meta(self): @@ -448,7 +454,11 @@ class FromPandasDivisions(FromPandas): @functools.cached_property def _name(self): - return "from_pd_divs" + "-" + _tokenize_deterministic(*self.operands) + return ( + "from_pd_divs" + + "-" + + _tokenize_deterministic(*self.operands, self._branch_id) + ) @property def _divisions_and_locations(self): diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 472f58ec7..7724cc53a 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -501,7 +501,7 @@ def _name(self): return ( funcname(type(self)).lower() + "-" - + _tokenize_deterministic(self.checksum, *self.operands) + + _tokenize_deterministic(self.checksum, *self.operands, self._branch_id) ) @property From 68e048c49b3d59272c2b5f7bd50115e91186b1e5 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 20 Feb 2024 18:00:11 +0100 Subject: [PATCH 19/26] Update --- dask_expr/_core.py | 8 +++++--- dask_expr/_expr.py | 3 ++- dask_expr/_reductions.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 4a43926d4..2d5150cf8 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -273,7 +273,7 @@ def rewrite(self, kind: str, cache): new_operands.append(new) if changed: - expr = type(expr)(*new_operands) + expr = type(expr)(*new_operands, _branch_id=expr._branch_id) continue else: break @@ -290,6 +290,8 @@ def _reuse_down(self): def _bubble_branch_id_down(self): b_id = self._branch_id + if b_id.branch_id <= 0: + return if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()): ops = [ op._substitute_branch_id(b_id) if isinstance(op, Expr) else op @@ -366,7 +368,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict): new_operands.append(new) if changed: - expr = type(expr)(*new_operands) + expr = type(expr)(*new_operands, _branch_id=expr._branch_id) break @@ -411,7 +413,7 @@ def lower_once(self): new_operands.append(new) if changed: - out = type(out)(*new_operands) + out = type(out)(*new_operands, _branch_id=out._branch_id) return out diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 94c09899c..c01e5267f 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -2722,7 +2722,8 @@ def __init__(self, obj, _branch_id=None): self.obj = obj if _branch_id is None: _branch_id = BranchId(0) - self.operands = [obj, _branch_id] + self._branch_id = _branch_id + self.operands = [obj] def __str__(self): return f"{type(self).__name__}({str(self.obj)})" diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 1eef1eadb..ebd05e07b 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -543,7 +543,7 @@ def _reuse_down(self): out = result._bubble_branch_id_down() if out is None: return result - return type(out)(*out.operands, b_id) + return type(out)(*out.operands, _branch_id=b_id) class Unique(ApplyConcatApply): From f79155ac84d0c2b90cd08e1a4f288d02015dde92 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 20 Feb 2024 18:56:53 +0100 Subject: [PATCH 20/26] Update --- dask_expr/_core.py | 7 +++++-- dask_expr/io/io.py | 5 +++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 2d5150cf8..7983c756a 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -117,7 +117,10 @@ def _tree_repr_lines(self, indent=0, recursive=True): 
continue op = f" branch_id={op.branch_id}" header = self._tree_repr_argument_construction(i, op, header) - + if self._branch_id.branch_id != 0: + header = self._tree_repr_argument_construction( + i + 1, f" branch_id={self._branch_id.branch_id}", header + ) lines = [header] + lines lines = [" " * indent + line for line in lines] @@ -604,7 +607,7 @@ def substitute_parameters(self, substitutions: dict) -> Expr: else: new_operands.append(operand) if changed: - return type(self)(*new_operands) + return type(self)(*new_operands, _branch_id=self._branch_id) return self def _node_label_args(self): diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 307132d2c..84d5df59a 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -105,7 +105,7 @@ def _name(self): return ( funcname(type(self.operand("_expr"))).lower() + "-fused-" - + _tokenize_deterministic(*self.operands, self._branch_id) + + _tokenize_deterministic(*self.operands, self._expr._branch_id) ) @functools.cached_property @@ -416,7 +416,8 @@ def _simplify_up(self, parent, dependents): return Literal(sum(_lengths)) if isinstance(parent, Projection): - return super()._simplify_up(parent, dependents) + x = super()._simplify_up(parent, dependents) + return x def _divisions(self): return self._divisions_and_locations[0] From 9fcc24650ba6ce59d0782347f86602ed92fd2190 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 20 Feb 2024 20:39:55 +0100 Subject: [PATCH 21/26] Tighten test --- dask_expr/tests/test_reuse.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py index b95bace1a..6d6f35689 100644 --- a/dask_expr/tests/test_reuse.py +++ b/dask_expr/tests/test_reuse.py @@ -23,7 +23,10 @@ def df(pdf): def _check_io_nodes(expr, expected): - assert len(list(expr.find_operations(IO))) == expected + expr = expr.optimize(fuse=False) + io_nodes = list(expr.find_operations(IO)) + assert len(io_nodes) == expected + assert len({node._branch_id.branch_id for node in io_nodes}) == expected def test_reuse_everything_scalar_and_series(df, pdf): @@ -35,7 +38,7 @@ def test_reuse_everything_scalar_and_series(df, pdf): pdf["new2"] = pdf["x"] + 1 pdf["new3"] = pdf.x[pdf.x > 1] + pdf.x[pdf.x > 2] assert_eq(df, pdf) - _check_io_nodes(df.optimize(fuse=False), 1) + _check_io_nodes(df, 1) def test_dont_reuse_reducer(df, pdf): @@ -44,12 +47,12 @@ def test_dont_reuse_reducer(df, pdf): expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.y.sum() assert_eq(result, expected) - _check_io_nodes(result.optimize(fuse=False), 2) + _check_io_nodes(result, 2) result = df + df.sum() expected = pdf + pdf.sum() assert_eq(result, expected, check_names=False) # pandas 2.2 bug - _check_io_nodes(result.optimize(fuse=False), 2) + _check_io_nodes(result, 2) result = df.replace(1, 5) rhs_1 = result.x + result.y.sum() @@ -60,7 +63,7 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result.optimize(fuse=False), 2) + _check_io_nodes(result, 2) result = df.replace(1, 5) result["new"] = result.x + result.y.sum() @@ -69,11 +72,11 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result.optimize(fuse=False), 3) + _check_io_nodes(result, 3) result = df.replace(1, 5) 
result["new"] = result.x + result.sum().dropna().prod() expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.sum().dropna().prod() assert_eq(result, expected) - _check_io_nodes(result.optimize(fuse=False), 2) + _check_io_nodes(result, 2) From 391d8f647b5bcca0b6329e46adc4607bc4d1c6c2 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:22:53 +0100 Subject: [PATCH 22/26] Update --- dask_expr/_core.py | 4 ---- dask_expr/io/parquet.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 7983c756a..d0ff38f88 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -112,10 +112,6 @@ def _tree_repr_lines(self, indent=0, recursive=True): op = "" elif is_arraylike(op): op = "" - elif isinstance(op, BranchId): - if op.branch_id == 0: - continue - op = f" branch_id={op.branch_id}" header = self._tree_repr_argument_construction(i, op, header) if self._branch_id.branch_id != 0: header = self._tree_repr_argument_construction( diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 7724cc53a..a1905331d 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -454,7 +454,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): _absorb_projections = True def _tree_repr_argument_construction(self, i, op, header): - if i < len(self._parameters) and self._parameters[i] == "_dataset_info_cache": + if self._parameters[i] == "_dataset_info_cache": # Don't print this, very ugly return header return super()._tree_repr_argument_construction(i, op, header) From 4326a25cb5384a00b1ec36d8b30a4aae9ea9aa2d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:24:08 +0100 Subject: [PATCH 23/26] Update --- dask_expr/_groupby.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index 4027b49ff..d19787d85 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -741,7 +741,6 @@ def _lower(self): for param in self._parameters ], *self.by, - self._branch_id, ) if is_dataframe_like(s._meta): c = c[s.columns] From c9e0384901b5d04fefc2bb3a0758467f3b602f8d Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:31:33 +0100 Subject: [PATCH 24/26] Update --- dask_expr/io/io.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 84d5df59a..e06682e6b 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -416,8 +416,7 @@ def _simplify_up(self, parent, dependents): return Literal(sum(_lengths)) if isinstance(parent, Projection): - x = super()._simplify_up(parent, dependents) - return x + return super()._simplify_up(parent, dependents) def _divisions(self): return self._divisions_and_locations[0] From cc120eebeb63056b902e3ade9bfc6493a0b0128e Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 24 Feb 2024 15:21:23 +0100 Subject: [PATCH 25/26] Update --- dask_expr/tests/_util.py | 11 +++++++++++ dask_expr/tests/test_reuse.py | 22 +++++++--------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/dask_expr/tests/_util.py b/dask_expr/tests/_util.py index 1f24bfda1..b04a93af0 100644 --- a/dask_expr/tests/_util.py +++ b/dask_expr/tests/_util.py @@ -5,6 +5,8 @@ from dask import config from dask.dataframe.utils import assert_eq as dd_assert_eq +from dask_expr.io import IO + 
def _backend_name() -> str: return config.get("dataframe.backend", "pandas") @@ -39,3 +41,12 @@ def assert_eq(a, b, *args, serialize_graph=True, **kwargs): # Use `dask.dataframe.assert_eq` return dd_assert_eq(a, b, *args, **kwargs) + + +def _check_consumer_node(expr, expected, consumer_node=IO, branch_id_counter=None): + if branch_id_counter is None: + branch_id_counter = expected + expr = expr.optimize(fuse=False) + io_nodes = list(expr.find_operations(consumer_node)) + assert len(io_nodes) == expected + assert len({node._branch_id.branch_id for node in io_nodes}) == branch_id_counter diff --git a/dask_expr/tests/test_reuse.py b/dask_expr/tests/test_reuse.py index 6d6f35689..c9d481a36 100644 --- a/dask_expr/tests/test_reuse.py +++ b/dask_expr/tests/test_reuse.py @@ -3,8 +3,7 @@ import pytest from dask_expr import from_pandas -from dask_expr.io import IO -from dask_expr.tests._util import _backend_library, assert_eq +from dask_expr.tests._util import _backend_library, _check_consumer_node, assert_eq # Set DataFrame backend for this module pd = _backend_library() @@ -22,13 +21,6 @@ def df(pdf): yield from_pandas(pdf, npartitions=10) -def _check_io_nodes(expr, expected): - expr = expr.optimize(fuse=False) - io_nodes = list(expr.find_operations(IO)) - assert len(io_nodes) == expected - assert len({node._branch_id.branch_id for node in io_nodes}) == expected - - def test_reuse_everything_scalar_and_series(df, pdf): df["new"] = 1 df["new2"] = df["x"] + 1 @@ -38,7 +30,7 @@ def test_reuse_everything_scalar_and_series(df, pdf): pdf["new2"] = pdf["x"] + 1 pdf["new3"] = pdf.x[pdf.x > 1] + pdf.x[pdf.x > 2] assert_eq(df, pdf) - _check_io_nodes(df, 1) + _check_consumer_node(df, 1) def test_dont_reuse_reducer(df, pdf): @@ -47,12 +39,12 @@ def test_dont_reuse_reducer(df, pdf): expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.y.sum() assert_eq(result, expected) - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) result = df + df.sum() expected = pdf + pdf.sum() assert_eq(result, expected, check_names=False) # pandas 2.2 bug - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) result = df.replace(1, 5) rhs_1 = result.x + result.y.sum() @@ -63,7 +55,7 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) result = df.replace(1, 5) result["new"] = result.x + result.y.sum() @@ -72,11 +64,11 @@ def test_dont_reuse_reducer(df, pdf): expected["new"] = expected.x + expected.y.sum() expected["new2"] = expected.b + expected.a.sum() assert_eq(result, expected) - _check_io_nodes(result, 3) + _check_consumer_node(result, 3) result = df.replace(1, 5) result["new"] = result.x + result.sum().dropna().prod() expected = pdf.replace(1, 5) expected["new"] = expected.x + expected.sum().dropna().prod() assert_eq(result, expected) - _check_io_nodes(result, 2) + _check_consumer_node(result, 2) From 12432da3719fe7b3c2e264616966bc49072ecd5b Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 24 Feb 2024 15:54:39 +0100 Subject: [PATCH 26/26] Make reuse step easier --- dask_expr/_reductions.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index ebd05e07b..286deafa3 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -518,9 +518,11 @@ def _reuse_down(self): if 
self._branch_id.branch_id != 0: return + from dask_expr.io import IO + seen = set() stack = self.dependencies() - counter, found_io = 1, False + counter, found_consumer = 1, False while stack: node = stack.pop() @@ -529,14 +531,17 @@ def _reuse_down(self): continue seen.add(node._name) + if isinstance(node, IO): + found_consumer = True + continue + if isinstance(node, ApplyConcatApply): counter += 1 continue - deps = node.dependencies() - if not deps: - found_io = True - stack.extend(deps) - if not found_io: + + stack.extend(node.dependencies()) + + if not found_consumer: return b_id = BranchId(counter) result = type(self)(*self.operands, b_id)
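
(Taken together, patches 18-20 move `_branch_id` out of `operands` and into the deterministic token. A standalone model of why that keeps two otherwise identical expressions from collapsing into one cached instance — a toy class, not the real dask-expr one, using `dask.base.tokenize` in place of the internal `_tokenize_deterministic`:)

    from dask.base import tokenize

    class ToyExpr:
        # Mirrors the naming scheme above: branch_id is carried on the
        # instance rather than in `operands`, but still participates in
        # the token that becomes the graph key.
        def __init__(self, *operands, branch_id=0):
            self.operands = list(operands)
            self._branch_id = branch_id

        @property
        def _name(self):
            return "toyexpr-" + tokenize(*self.operands, self._branch_id)

    a = ToyExpr(1, 2)
    b = ToyExpr(1, 2, branch_id=1)
    assert a.operands == b.operands  # same logical operation
    assert a._name != b._name        # but distinct graph keys, so no dedup
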
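
(Patch 26 simplifies the reuse step: the walk now stops at IO expressions — the "consumer" boundary — instead of treating any dependency-less leaf as IO, and counts `ApplyConcatApply` reducers along the way. This means a scalar leaf such as a literal below a reducer no longer trips the old `found_io` heuristic. A self-contained model of that traversal, with the class checks abstracted into predicates; the helper names are made up:)

    from collections import namedtuple

    BranchId = namedtuple("BranchId", ["branch_id"])

    def reuse_branch_id(root, dependencies, is_consumer, is_reducer):
        # Walk the expression below `root`, stopping at consumer (IO)
        # nodes; count every additional reducer on the way so sibling
        # branches can receive distinct ids.
        seen = set()
        stack = list(dependencies(root))
        counter, found_consumer = 1, False
        while stack:
            node = stack.pop()
            if id(node) in seen:
                continue
            seen.add(id(node))
            if is_consumer(node):
                found_consumer = True
                continue  # do not descend past the data source
            if is_reducer(node):
                counter += 1
                continue  # nested reducers get their own reuse pass
            stack.extend(dependencies(node))
        # Only tag the expression when it actually sits on top of a
        # consumer; otherwise there is no reuse to limit.
        return BranchId(counter) if found_consumer else None

    # Toy graph: sum(read()) -> one consumer below, one reducer on top
    graph = {"sum": ["read"], "read": []}
    print(reuse_branch_id("sum",
                          dependencies=lambda n: graph[n],
                          is_consumer=lambda n: n == "read",
                          is_reducer=lambda n: False))  # BranchId(branch_id=1)
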