Skip to content

Commit

Permalink
Add caching to recursive simplify_once calls (#797)
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora authored Feb 6, 2024
1 parent 3c1992f commit b5d17ad
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
15 changes: 12 additions & 3 deletions dask_expr/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def rewrite(self, kind: str):

return expr

def simplify_once(self, dependents: defaultdict):
def simplify_once(self, dependents: defaultdict, simplified: dict):
"""Simplify an expression
This leverages the ``._simplify_down`` and ``._simplify_up``
Expand All @@ -278,12 +278,18 @@ def simplify_once(self, dependents: defaultdict):
dependents: defaultdict[list]
The dependents for every node.
simplified: dict
Cache of simplified expressions for these dependents.
Returns
-------
expr:
output expression
"""
# Check if we've already simplified for these dependents
if self._name in simplified:
return simplified[self._name]

expr = self

while True:
Expand Down Expand Up @@ -314,7 +320,10 @@ def simplify_once(self, dependents: defaultdict):
if isinstance(operand, Expr):
# Bandaid for now, waiting for Singleton
dependents[operand._name].append(weakref.ref(expr))
new = operand.simplify_once(dependents=dependents)
new = operand.simplify_once(
dependents=dependents, simplified=simplified
)
simplified[operand._name] = new
if new._name != operand._name:
changed = True
else:
Expand All @@ -332,7 +341,7 @@ def simplify(self) -> Expr:
expr = self
while True:
dependents = collect_dependents(expr)
new = expr.simplify_once(dependents=dependents)
new = expr.simplify_once(dependents=dependents, simplified={})
if new._name == expr._name:
break
expr = new
Expand Down
4 changes: 1 addition & 3 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1742,9 +1742,7 @@ def vals(self):

@functools.cached_property
def _meta(self):
args = [
meta_nonempty(op._meta) if isinstance(op, Expr) else op for op in self._args
]
args = [op._meta if isinstance(op, Expr) else op for op in self._args]
return make_meta(self.operation(*args, **self._kwargs))

def _tree_repr_argument_construction(self, i, op, header):
Expand Down

0 comments on commit b5d17ad

Please sign in to comment.