Add branch_id to distinguish between reusable branches and pipeline breakers #883

Open · wants to merge 30 commits into base main · showing changes from 15 commits
56 changes: 52 additions & 4 deletions dask_expr/_core.py

@@ -5,6 +5,7 @@
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import NamedTuple

import dask
import pandas as pd
@@ -15,6 +16,10 @@
from dask_expr._util import _BackendData, _tokenize_deterministic


class BranchId(NamedTuple):
branch_id: int


def _unpack_collections(o):
if isinstance(o, Expr):
return o
@@ -30,23 +35,36 @@ class Expr:
_defaults = {}
_instances = weakref.WeakValueDictionary()

-    def __new__(cls, *args, **kwargs):
+    def __new__(cls, *args, _branch_id=None, **kwargs):
         operands = list(args)
+        if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId):
+            _branch_id = operands.pop(-1)
+        elif _branch_id is None:
+            _branch_id = BranchId(0)

for parameter in cls._parameters[len(operands) :]:
try:
operands.append(kwargs.pop(parameter))
except KeyError:
operands.append(cls._defaults[parameter])
assert not kwargs, kwargs
inst = object.__new__(cls)
-        inst.operands = [_unpack_collections(o) for o in operands]
+        inst.operands = [_unpack_collections(o) for o in operands] + [_branch_id]
_name = inst._name
if _name in Expr._instances:
return Expr._instances[_name]

Expr._instances[_name] = inst
return inst

@functools.cached_property
def argument_operands(self):
return self.operands[:-1]

@functools.cached_property
def _branch_id(self):
return self.operands[-1]
Review discussion on this change:

Member: How about just storing this as an attribute? Why does this need to be an operand?

Collaborator (author): Having it in the operands makes sure that it is used for `_name`, which is what we mostly care about. If that is overridden and forgotten, the whole thing won't work properly.

Member: I get that. I'm just concerned this is messing with other things, since operands are everywhere. If it works, I'm fine.

Member: I think to "avoid" the kind of confusion I'm thinking about, you introduced `argument_operands`, but this may also be confusing to developers.

def _tune_down(self):
return None

@@ -107,6 +125,10 @@ def _tree_repr_lines(self, indent=0, recursive=True):
op = "<series>"
elif is_arraylike(op):
op = "<array>"
elif isinstance(op, BranchId):
if op.branch_id == 0:
continue
op = f" branch_id={op.branch_id}"
header = self._tree_repr_argument_construction(i, op, header)

lines = [header] + lines
@@ -203,7 +225,7 @@ def _layer(self) -> dict:

return {(self._name, i): self._task(i) for i in range(self.npartitions)}

-    def rewrite(self, kind: str):
+    def rewrite(self, kind: str, cache):
"""Rewrite an expression

This leverages the ``._{kind}_down`` and ``._{kind}_up``
@@ -216,6 +238,9 @@ def rewrite(self, kind: str):
changed:
whether or not any change occurred
"""
if self._name in cache:
return cache[self._name]

expr = self
down_name = f"_{kind}_down"
up_name = f"_{kind}_up"
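The cache threaded through rewrite() memoizes rewritten subtrees by their name, so an expression that appears several times in the graph is rewritten only once. A simplified, standalone sketch of that pattern (hypothetical node type, not the dask-expr implementation):

def rewrite_memoized(node, rewrite_one, cache):
    # Reuse a previously computed result for this subtree, keyed by name.
    if node.name in cache:
        return cache[node.name]
    # Rewrite children first, sharing one cache across the whole walk.
    children = [rewrite_memoized(c, rewrite_one, cache) for c in node.children]
    # with_children: hypothetical copy-with-new-children constructor.
    result = rewrite_one(node.with_children(children))
    cache[node.name] = result
    return result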
@@ -252,7 +277,8 @@ def rewrite(self, kind: str):
changed = False
for operand in expr.operands:
if isinstance(operand, Expr):
-                new = operand.rewrite(kind=kind)
+                new = operand.rewrite(kind=kind, cache=cache)
+                cache[operand._name] = new
if new._name != operand._name:
changed = True
else:
@@ -267,6 +293,28 @@

return expr

def _reuse_up(self, parent):
return

def _reuse_down(self):
if not self.dependencies():
return
return self._bubble_branch_id_down()

def _bubble_branch_id_down(self):
b_id = self._branch_id
if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()):
ops = [
op._substitute_branch_id(b_id) if isinstance(op, Expr) else op
for op in self.argument_operands
]
return type(self)(*ops)

def _substitute_branch_id(self, branch_id):
if self._branch_id.branch_id != 0:
return self
return type(self)(*self.argument_operands, branch_id)
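An illustrative, standalone sketch (hypothetical Node type) of how a non-default branch id bubbles down: a node whose children still carry the default id rebuilds itself with re-tagged children, one level per pass.

class Node:
    def __init__(self, name, *children, branch_id=0):
        self.name, self.children, self.branch_id = name, list(children), branch_id

    def substitute_branch_id(self, branch_id):
        # Mirrors _substitute_branch_id: only re-tag nodes on the default branch.
        if self.branch_id != 0:
            return self
        return Node(self.name, *self.children, branch_id=branch_id)

    def bubble_branch_id_down(self):
        # Mirrors _bubble_branch_id_down: rebuild if any child disagrees.
        if any(c.branch_id != self.branch_id for c in self.children):
            kids = [c.substitute_branch_id(self.branch_id) for c in self.children]
            return Node(self.name, *kids, branch_id=self.branch_id)

leaf = Node("read_parquet")            # default branch_id=0
root = Node("sum", leaf, branch_id=1)  # tagged as a separate branch
rebuilt = root.bubble_branch_id_down()
assert rebuilt.children[0].branch_id == 1  # the tag propagated one level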

def simplify_once(self, dependents: defaultdict, simplified: dict):
"""Simplify an expression

2 changes: 1 addition & 1 deletion dask_expr/_cumulative.py

@@ -47,7 +47,7 @@ def operation(self):

@functools.cached_property
def _args(self) -> list:
-        return self.operands[:-1]
+        return self.argument_operands[:-1]


class TakeLast(Blockwise):
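The mechanical consequence of appending BranchId as the last operand, as a hedged sketch: positional slices over user-facing arguments now go through argument_operands so the trailing BranchId is excluded.

from typing import NamedTuple

class BranchId(NamedTuple):
    branch_id: int

operands = ["frame", "skipna", BranchId(0)]  # hypothetical operand list
argument_operands = operands[:-1]            # drops the trailing BranchId
args = argument_operands[:-1]                # preserves the old operands[:-1] meaning
assert args == ["frame"]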
68 changes: 45 additions & 23 deletions dask_expr/_expr.py

@@ -53,6 +53,7 @@
from tlz import merge_sorted, partition, unique

from dask_expr import _core as core
from dask_expr._core import BranchId
from dask_expr._util import (
_calc_maybe_new_divisions,
_convert_to_list,
@@ -476,9 +477,9 @@ def _args(self) -> list:
if self._keyword_only:
args = [
self.operand(p) for p in self._parameters if p not in self._keyword_only
-            ] + self.operands[len(self._parameters) :]
+            ] + self.argument_operands[len(self._parameters) :]
             return args
-        return self.operands
+        return self.argument_operands

def _broadcast_dep(self, dep: Expr):
# Checks if a dependency should be broadcasted to
@@ -562,7 +563,7 @@ def _broadcast_dep(self, dep: Expr):

@property
def args(self):
-        return [self.frame] + self.operands[len(self._parameters) :]
+        return [self.frame] + self.argument_operands[len(self._parameters) :]

@functools.cached_property
def _meta(self):
@@ -658,7 +659,7 @@ def __str__(self):

@functools.cached_property
def args(self):
-        return self.operands[len(self._parameters) :]
+        return self.argument_operands[len(self._parameters) :]

@functools.cached_property
def _dfs(self):
@@ -725,7 +726,7 @@ def _meta(self):
meta = self.operand("meta")
args = [self.frame._meta] + [
arg._meta if isinstance(arg, Expr) else arg
-            for arg in self.operands[len(self._parameters) :]
+            for arg in self.argument_operands[len(self._parameters) :]
]
return _get_meta_map_partitions(
args,
@@ -737,7 +738,7 @@
)

def _divisions(self):
-        args = [self.frame] + self.operands[len(self._parameters) :]
+        args = [self.frame] + self.argument_operands[len(self._parameters) :]
return calc_divisions_for_align(*args)

def _lower(self):
@@ -792,15 +793,15 @@ def args(self):
return (
[self.frame]
+ [self.func, self.before, self.after]
-            + self.operands[len(self._parameters) :]
+            + self.argument_operands[len(self._parameters) :]
)

@functools.cached_property
def _meta(self):
meta = self.operand("meta")
args = [self.frame._meta] + [
arg._meta if isinstance(arg, Expr) else arg
-            for arg in self.operands[len(self._parameters) :]
+            for arg in self.argument_operands[len(self._parameters) :]
]
return _get_meta_map_partitions(
args,
@@ -1094,7 +1095,11 @@ class Sample(Blockwise):

@functools.cached_property
def _meta(self):
-        args = [self.operands[0]._meta] + [self.operands[1][0]] + self.operands[2:]
+        args = (
+            [self.operands[0]._meta]
+            + [self.operands[1][0]]
+            + self.argument_operands[2:]
+        )
return self.operation(*args)

def _task(self, index: int):
@@ -1696,11 +1701,11 @@ class Assign(Elemwise):

@functools.cached_property
def keys(self):
-        return self.operands[1::2]
+        return self.argument_operands[1::2]

@functools.cached_property
def vals(self):
-        return self.operands[2::2]
+        return self.argument_operands[2::2]

@functools.cached_property
def _meta(self):
@@ -1725,7 +1730,7 @@ def _simplify_down(self):
if self._check_for_previously_created_column(self.frame):
# don't squash if we are using a column that was previously created
return
-        return Assign(*self.frame.operands, *self.operands[1:])
+        return Assign(*self.frame.argument_operands, *self.operands[1:])

def _check_for_previously_created_column(self, child):
input_columns = []
@@ -1753,7 +1758,7 @@ def _simplify_up(self, parent, dependents):
if k in columns:
new_args.extend([k, v])
else:
-            new_args = self.operands[1:]
+            new_args = self.argument_operands[1:]

columns = [col for col in self.frame.columns if col in cols]
return type(parent)(
@@ -1778,12 +1783,12 @@ class CaseWhen(Elemwise):

@functools.cached_property
def caselist(self):
-        c = self.operands[1:]
+        c = self.argument_operands[1:]
return [(c[i], c[i + 1]) for i in range(0, len(c), 2)]

@functools.cached_property
def _meta(self):
-        c = self.operands[1:]
+        c = self.argument_operands[1:]
caselist = [
(
meta_nonempty(c[i]._meta) if isinstance(c[i], Expr) else c[i],
@@ -2714,9 +2719,11 @@ class _DelayedExpr(Expr):
# TODO
_parameters = ["obj"]

-    def __init__(self, obj):
+    def __init__(self, obj, _branch_id=None):
         self.obj = obj
-        self.operands = [obj]
+        if _branch_id is None:
+            _branch_id = BranchId(0)
+        self.operands = [obj, _branch_id]

def __str__(self):
return f"{type(self).__name__}({str(self.obj)})"
@@ -2744,7 +2751,9 @@ def normalize_expression(expr):
return expr._name


-def optimize(expr: Expr, fuse: bool = True) -> Expr:
+def optimize(
+    expr: Expr, fuse: bool = True, common_subplan_elimination: bool = True
+) -> Expr:
"""High level query optimization

This leverages three optimization passes:
@@ -2758,24 +2767,37 @@ def optimize(expr: Expr, fuse: bool = True) -> Expr:
Input expression to optimize
fuse:
whether or not to turn on blockwise fusion
common_subplan_elimination : bool
    whether to reuse common subplans that are found in the graph and are used
    in self-joins or similar operations that require all data to be held in
    memory at some point. Only set this to False if your dataset fits into
    memory.

See Also
--------
simplify
optimize_blockwise_fusion
"""

Review discussion on common_subplan_elimination:

Member: If I'm not mistaken, our default is actually a very radical form of "common subplan elimination". I think this feature is actually quite the opposite, rather something like "replicate common subplan"? I guess other engines are doing it the other way round, aren't they?

Collaborator (author, @phofl, Feb 19, 2024): I took the name from https://docs.pola.rs/py-polars/html/reference/lazyframe/api/polars.LazyFrame.explain.html, but I don't have strong feelings either way. And I should have read the actual documentation: yes, you are correct, this should be the other way round.

Member: Naming is hard. @hendrikmakait, any thoughts?

Member: Sorry that I'm late to the party. dask-expr approaches CSE from the other end, defaulting to as much CSE as possible. This makes it less intuitive. As mentioned in #896, common_subplan_elimination=False does not mean that we don't do any CSE, but instead that we try to reduce it. I'm fine with the argument name now that you've cleaned it up; let's see if this leads to some confusion in the future.

Member (@hendrikmakait, Feb 27, 2024): Regarding Common Subexpression Elimination vs. Common Subplan Elimination: I think Common Subexpression Elimination is more common for data systems and compilers/IR-based systems.

Member: Interestingly, polars seems to have both subexpression elimination (https://github.com/pola-rs/polars/blob/128803b237dc13d0522c22dbccae1257ae30477e/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs) and subplan elimination (https://github.com/pola-rs/polars/blob/128803b237dc13d0522c22dbccae1257ae30477e/crates/polars-plan/src/logical_plan/optimizer/cse.rs). I'm not entirely sure what the difference is, but from what I understand, cse_expr deals with eliminating duplication within a single expression (a.sum() + b.sum() + a.sum()) and cse deals with eliminating subgraphs across the entire graph. That is, cse_expr corresponds to local CSE and cse to global CSE in the link Florian shared.
+    result = expr
+    while True:
+        if common_subplan_elimination:
+            out = result.rewrite("reuse", cache={})
+        else:
+            out = result
+        out = out.simplify()
+        if out._name == result._name or not common_subplan_elimination:
+            break
+        result = out

-    # Simplify
-    result = expr.simplify()
+    result = out
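The loop above alternates the "reuse" rewrite with simplification until a fixed point, detected by an unchanged _name. A standalone sketch of that driver pattern (hypothetical step function, not the dask-expr API):

def to_fixed_point(expr, step, max_iter=100):
    # Apply `step` until the expression name stops changing.
    result = expr
    for _ in range(max_iter):
        out = step(result)
        if out._name == result._name:
            return out
        result = out
    return result

# e.g. to_fixed_point(expr, lambda e: e.rewrite("reuse", cache={}).simplify())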

# Manipulate Expression to make it more efficient
-    result = result.rewrite(kind="tune")
+    result = result.rewrite(kind="tune", cache={})

# Lower
result = result.lower_completely()

# Cull
-    result = result.rewrite(kind="cull")
+    result = result.rewrite(kind="cull", cache={})

# Final graph-specific optimizations
if fuse:
@@ -3307,7 +3329,7 @@ def __str__(self):

@functools.cached_property
def args(self):
-        return self.operands[len(self._parameters) :]
+        return self.argument_operands[len(self._parameters) :]

@functools.cached_property
def _dfs(self):
3 changes: 2 additions & 1 deletion dask_expr/_groupby.py

@@ -106,7 +106,7 @@ def split_by(self):

@functools.cached_property
def by(self):
-        return self.operands[len(self._parameters) :]
+        return self.argument_operands[len(self._parameters) :]

@functools.cached_property
def levels(self):
@@ -741,6 +741,7 @@ def _lower(self):
for param in self._parameters
],
             *self.by,
+            self._branch_id,
         )
if is_dataframe_like(s._meta):
c = c[s.columns]
Expand Down
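The _groupby.py hunk above shows the pattern consumers of this API must follow: expressions rebuilt by hand during lowering forward self._branch_id explicitly as the trailing operand. A standalone sketch (hypothetical helper, mirroring the default applied in Expr.__new__):

from typing import NamedTuple

class BranchId(NamedTuple):
    branch_id: int

def rebuild_operands(operands, branch_id=None):
    # A trailing BranchId wins; otherwise fall back to the default branch 0.
    if operands and isinstance(operands[-1], BranchId):
        return list(operands)
    return list(operands) + [branch_id if branch_id is not None else BranchId(0)]

kept = rebuild_operands(["frame", "by", BranchId(2)])  # branch id forwarded
lost = rebuild_operands(["frame", "by"])               # silently reverts to 0
assert kept[-1].branch_id == 2 and lost[-1].branch_id == 0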