From 7cbb6834dba5c7fedb204f1981a1e4e2fb318fd0 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 22 Jan 2024 09:08:34 -0800 Subject: [PATCH 1/7] try improving caching in simplify_once --- dask_expr/_core.py | 67 ++++++++++++++++++++++++---------------------- dask_expr/_expr.py | 4 +-- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 67f70a8b..bb895e33 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -30,6 +30,7 @@ class Expr: _defaults = {} def __init__(self, *args, **kwargs): + self._simplified = {} operands = list(args) for parameter in type(self)._parameters[len(operands) :]: try: @@ -279,47 +280,49 @@ def simplify_once(self, dependents: defaultdict): expr: output expression """ - expr = self - - while True: - out = expr._simplify_down() - if out is None: - out = expr - if not isinstance(out, Expr): - return out - if out._name != expr._name: - expr = out - - # Allow children to simplify their parents - for child in expr.dependencies(): - out = child._simplify_up(expr, dependents) + key = _tokenize_deterministic(sorted(dependents.keys())) + if key not in self._simplified: + expr = self + while True: + out = expr._simplify_down() if out is None: out = expr - if not isinstance(out, Expr): return out - if out is not expr and out._name != expr._name: + if out._name != expr._name: expr = out - break - # Rewrite all of the children - new_operands = [] - changed = False - for operand in expr.operands: - if isinstance(operand, Expr): - new = operand.simplify_once(dependents=dependents) - if new._name != operand._name: - changed = True - else: - new = operand - new_operands.append(new) + # Allow children to simplify their parents + for child in expr.dependencies(): + out = child._simplify_up(expr, dependents) + if out is None: + out = expr - if changed: - expr = type(expr)(*new_operands) + if not isinstance(out, Expr): + return out + if out is not expr and out._name != expr._name: + expr = out + break - break + # Rewrite all of the children + new_operands = [] + changed = False + for operand in expr.operands: + if isinstance(operand, Expr): + new = operand.simplify_once(dependents=dependents) + if new._name != operand._name: + changed = True + else: + new = operand + new_operands.append(new) + + if changed: + expr = type(expr)(*new_operands) - return expr + break + self._simplified[key] = expr + + return self._simplified[key] def simplify(self) -> Expr: expr = self diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 1b56800f..717bc5b1 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -1768,9 +1768,7 @@ def vals(self): @functools.cached_property def _meta(self): - args = [ - meta_nonempty(op._meta) if isinstance(op, Expr) else op for op in self._args - ] + args = [op._meta if isinstance(op, Expr) else op for op in self._args] return make_meta(self.operation(*args, **self._kwargs)) def _tree_repr_argument_construction(self, i, op, header): From d344895739c6b03ac68b12f809c306a50dc60393 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 22 Jan 2024 09:37:45 -0800 Subject: [PATCH 2/7] add _simplified cache to _DelayedExpr --- dask_expr/_core.py | 69 +++++++++++++++++++++------------------- dask_expr/io/_delayed.py | 1 + 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index bb895e33..073d3370 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -281,48 +281,51 @@ def simplify_once(self, dependents: defaultdict): output expression """ key = _tokenize_deterministic(sorted(dependents.keys())) - if key not in self._simplified: - expr = self - while True: - out = expr._simplify_down() + if key in self._simplified: + return self._simplified[key] + + expr = self + + while True: + out = expr._simplify_down() + if out is None: + out = expr + if not isinstance(out, Expr): + return out + if out._name != expr._name: + expr = out + + # Allow children to simplify their parents + for child in expr.dependencies(): + out = child._simplify_up(expr, dependents) if out is None: out = expr + if not isinstance(out, Expr): return out - if out._name != expr._name: + if out is not expr and out._name != expr._name: expr = out + break - # Allow children to simplify their parents - for child in expr.dependencies(): - out = child._simplify_up(expr, dependents) - if out is None: - out = expr - - if not isinstance(out, Expr): - return out - if out is not expr and out._name != expr._name: - expr = out - break + # Rewrite all of the children + new_operands = [] + changed = False + for operand in expr.operands: + if isinstance(operand, Expr): + new = operand.simplify_once(dependents=dependents) + if new._name != operand._name: + changed = True + else: + new = operand + new_operands.append(new) - # Rewrite all of the children - new_operands = [] - changed = False - for operand in expr.operands: - if isinstance(operand, Expr): - new = operand.simplify_once(dependents=dependents) - if new._name != operand._name: - changed = True - else: - new = operand - new_operands.append(new) - - if changed: - expr = type(expr)(*new_operands) + if changed: + expr = type(expr)(*new_operands) - break - self._simplified[key] = expr + break - return self._simplified[key] + self._simplified[key] = expr + return expr def simplify(self) -> Expr: expr = self diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py index cc2aa4ed..4e9f84ea 100644 --- a/dask_expr/io/_delayed.py +++ b/dask_expr/io/_delayed.py @@ -22,6 +22,7 @@ class _DelayedExpr(Expr): # TODO def __init__(self, obj): + self._simplified = {} self.obj = obj self.operands = [obj] From 6e7d0ddcbe842381d2f1d9ea2dddc643be3095a0 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 22 Jan 2024 09:42:11 -0800 Subject: [PATCH 3/7] add comment --- dask_expr/_core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 073d3370..33886945 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -280,6 +280,7 @@ def simplify_once(self, dependents: defaultdict): expr: output expression """ + # Check if we've already simplified for these dependents key = _tokenize_deterministic(sorted(dependents.keys())) if key in self._simplified: return self._simplified[key] @@ -324,7 +325,7 @@ def simplify_once(self, dependents: defaultdict): break - self._simplified[key] = expr + self._simplified[key] = expr # Cache the result return expr def simplify(self) -> Expr: From ebcc644a18734a53518f2c2950f8d7f9dd88ed04 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 5 Feb 2024 09:09:19 -0800 Subject: [PATCH 4/7] fix test --- dask_expr/tests/test_shuffle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 63afeca4..3a82d1b4 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -5,6 +5,7 @@ import pytest from dask_expr import from_pandas, new_collection +from dask_expr._core import Expr from dask_expr._expr import Assign, Blockwise from dask_expr._reductions import NFirst, NLast from dask_expr._repartition import RepartitionToFewer @@ -647,6 +648,7 @@ def test_set_index_sort_values_one_partition(pdf): def test_set_index_triggers_calc_when_accessing_divisions(pdf, df): + Expr._instances = {} # Divisions can be cached in the instance divisions_lru.data = OrderedDict() query = df.set_index("x") assert len(divisions_lru.data) == 0 From b07ef8b6de11b649c15f0b3c9ac3376f5c86d3fc Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 5 Feb 2024 09:09:44 -0800 Subject: [PATCH 5/7] formatting --- dask_expr/tests/test_shuffle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 3a82d1b4..6b1e87e4 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -648,7 +648,7 @@ def test_set_index_sort_values_one_partition(pdf): def test_set_index_triggers_calc_when_accessing_divisions(pdf, df): - Expr._instances = {} # Divisions can be cached in the instance + Expr._instances = {} # divisions can be cached in the instance divisions_lru.data = OrderedDict() query = df.set_index("x") assert len(divisions_lru.data) == 0 From 69d8f2b8c70766ae0deb054eb55e797d092b1a66 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 5 Feb 2024 09:46:39 -0800 Subject: [PATCH 6/7] remove Expr-instance-level caching --- dask_expr/_core.py | 19 +++++++++++-------- dask_expr/tests/test_shuffle.py | 2 -- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 2f3aaf7b..940187d8 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -40,7 +40,6 @@ def __new__(cls, *args, **kwargs): assert not kwargs, kwargs inst = object.__new__(cls) inst.operands = [_unpack_collections(o) for o in operands] - inst._simplified = {} _name = inst._name if _name in Expr._instances: return Expr._instances[_name] @@ -268,7 +267,7 @@ def rewrite(self, kind: str): return expr - def simplify_once(self, dependents: defaultdict): + def simplify_once(self, dependents: defaultdict, simplified: dict): """Simplify an expression This leverages the ``._simplify_down`` and ``._simplify_up`` @@ -279,6 +278,8 @@ def simplify_once(self, dependents: defaultdict): dependents: defaultdict[list] The dependents for every node. + simplified: dict + Cache of simplified expressions for these dependents. Returns ------- @@ -286,9 +287,8 @@ def simplify_once(self, dependents: defaultdict): output expression """ # Check if we've already simplified for these dependents - key = _tokenize_deterministic(sorted(dependents.keys())) - if key in self._simplified: - return self._simplified[key] + if self._name in simplified: + return simplified[self._name] expr = self @@ -320,7 +320,10 @@ def simplify_once(self, dependents: defaultdict): if isinstance(operand, Expr): # Bandaid for now, waiting for Singleton dependents[operand._name].append(weakref.ref(expr)) - new = operand.simplify_once(dependents=dependents) + new = operand.simplify_once( + dependents=dependents, simplified=simplified + ) + simplified[operand._name] = new if new._name != operand._name: changed = True else: @@ -332,14 +335,14 @@ def simplify_once(self, dependents: defaultdict): break - self._simplified = {key: expr} # Cache the last result return expr def simplify(self) -> Expr: expr = self while True: dependents = collect_dependents(expr) - new = expr.simplify_once(dependents=dependents) + simplified = {} + new = expr.simplify_once(dependents=dependents, simplified=simplified) if new._name == expr._name: break expr = new diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 6b1e87e4..63afeca4 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -5,7 +5,6 @@ import pytest from dask_expr import from_pandas, new_collection -from dask_expr._core import Expr from dask_expr._expr import Assign, Blockwise from dask_expr._reductions import NFirst, NLast from dask_expr._repartition import RepartitionToFewer @@ -648,7 +647,6 @@ def test_set_index_sort_values_one_partition(pdf): def test_set_index_triggers_calc_when_accessing_divisions(pdf, df): - Expr._instances = {} # divisions can be cached in the instance divisions_lru.data = OrderedDict() query = df.set_index("x") assert len(divisions_lru.data) == 0 From 8594759f425b55a5763d34d2da4666c64cbdccb6 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 5 Feb 2024 16:19:56 -0800 Subject: [PATCH 7/7] fixup --- dask_expr/_core.py | 3 +-- dask_expr/io/_delayed.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 940187d8..3aac656c 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -341,8 +341,7 @@ def simplify(self) -> Expr: expr = self while True: dependents = collect_dependents(expr) - simplified = {} - new = expr.simplify_once(dependents=dependents, simplified=simplified) + new = expr.simplify_once(dependents=dependents, simplified={}) if new._name == expr._name: break expr = new diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py index 9dc24e04..a0f864f7 100644 --- a/dask_expr/io/_delayed.py +++ b/dask_expr/io/_delayed.py @@ -23,7 +23,6 @@ class _DelayedExpr(Expr): _parameters = ["obj"] def __init__(self, obj): - self._simplified = {} self.obj = obj self.operands = [obj]