-
-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DNM: Full branch_id implementation #896
base: main
Are you sure you want to change the base?
Changes from all commits
fae5c6e
d598734
2345cd4
d88270d
948cd83
045bbef
93e0d28
7ddda99
8c2d977
7184bcf
5ac9394
fb2aa9f
061de6f
e486590
d28f906
366415a
ee523ea
369c142
5ee43dd
7379a01
68e048c
f79155a
4801a93
9fcc246
391d8f6
4326a25
1b6b090
c9e0384
5531985
6891478
d70ba0f
cc120ee
8ba433c
9257b72
b665ce1
2bbeb2e
451dca0
150f99c
12432da
9cbcec3
8bc45b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -5,7 +5,7 @@ | |||||
import weakref | ||||||
from collections import defaultdict | ||||||
from collections.abc import Generator | ||||||
from typing import TYPE_CHECKING, Literal | ||||||
from typing import TYPE_CHECKING, Literal, NamedTuple | ||||||
|
||||||
import dask | ||||||
import pandas as pd | ||||||
|
@@ -29,6 +29,10 @@ | |||||
] | ||||||
|
||||||
|
||||||
class BranchId(NamedTuple): | ||||||
branch_id: int | ||||||
|
||||||
|
||||||
def _unpack_collections(o): | ||||||
if isinstance(o, Expr): | ||||||
return o | ||||||
|
@@ -43,9 +47,17 @@ class Expr: | |||||
_parameters = [] | ||||||
_defaults = {} | ||||||
_instances = weakref.WeakValueDictionary() | ||||||
_branch_id_required = False | ||||||
_reuse_consumer = False | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This name is ambiguous. Can we come up with something more descriptive? |
||||||
|
||||||
def __new__(cls, *args, **kwargs): | ||||||
def __new__(cls, *args, _branch_id=None, **kwargs): | ||||||
cls._check_branch_id_given(args, _branch_id) | ||||||
operands = list(args) | ||||||
if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId): | ||||||
_branch_id = operands.pop(-1) | ||||||
elif _branch_id is None: | ||||||
_branch_id = BranchId(0) | ||||||
|
||||||
for parameter in cls._parameters[len(operands) :]: | ||||||
try: | ||||||
operands.append(kwargs.pop(parameter)) | ||||||
|
@@ -54,13 +66,23 @@ def __new__(cls, *args, **kwargs): | |||||
assert not kwargs, kwargs | ||||||
inst = object.__new__(cls) | ||||||
inst.operands = [_unpack_collections(o) for o in operands] | ||||||
inst._branch_id = _branch_id | ||||||
_name = inst._name | ||||||
if _name in Expr._instances: | ||||||
return Expr._instances[_name] | ||||||
|
||||||
Expr._instances[_name] = inst | ||||||
return inst | ||||||
|
||||||
@classmethod | ||||||
def _check_branch_id_given(cls, args, _branch_id): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
Suggested change
|
||||||
if not cls._branch_id_required: | ||||||
return | ||||||
operands = list(args) | ||||||
if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId): | ||||||
_branch_id = operands.pop(-1) | ||||||
assert _branch_id is not None, "BranchId not found" | ||||||
|
||||||
def _tune_down(self): | ||||||
return None | ||||||
|
||||||
|
@@ -116,7 +138,10 @@ def _tree_repr_lines(self, indent=0, recursive=True): | |||||
elif is_arraylike(op): | ||||||
op = "<array>" | ||||||
header = self._tree_repr_argument_construction(i, op, header) | ||||||
|
||||||
if self._branch_id.branch_id != 0: | ||||||
header = self._tree_repr_argument_construction( | ||||||
i + 1, f" branch_id={self._branch_id.branch_id}", header | ||||||
) | ||||||
lines = [header] + lines | ||||||
lines = [" " * indent + line for line in lines] | ||||||
|
||||||
|
@@ -218,7 +243,7 @@ def _layer(self) -> dict: | |||||
|
||||||
return {(self._name, i): self._task(i) for i in range(self.npartitions)} | ||||||
|
||||||
def rewrite(self, kind: str): | ||||||
def rewrite(self, kind: str, cache): | ||||||
"""Rewrite an expression | ||||||
|
||||||
This leverages the ``._{kind}_down`` and ``._{kind}_up`` | ||||||
|
@@ -231,6 +256,9 @@ def rewrite(self, kind: str): | |||||
changed: | ||||||
whether or not any change occured | ||||||
""" | ||||||
if self._name in cache: | ||||||
return cache[self._name] | ||||||
|
||||||
expr = self | ||||||
down_name = f"_{kind}_down" | ||||||
up_name = f"_{kind}_up" | ||||||
|
@@ -267,21 +295,46 @@ def rewrite(self, kind: str): | |||||
changed = False | ||||||
for operand in expr.operands: | ||||||
if isinstance(operand, Expr): | ||||||
new = operand.rewrite(kind=kind) | ||||||
new = operand.rewrite(kind=kind, cache=cache) | ||||||
cache[operand._name] = new | ||||||
if new._name != operand._name: | ||||||
changed = True | ||||||
else: | ||||||
new = operand | ||||||
new_operands.append(new) | ||||||
|
||||||
if changed: | ||||||
expr = type(expr)(*new_operands) | ||||||
expr = type(expr)(*new_operands, _branch_id=expr._branch_id) | ||||||
continue | ||||||
else: | ||||||
break | ||||||
|
||||||
return expr | ||||||
|
||||||
def _reuse_up(self, parent): | ||||||
return | ||||||
|
||||||
def _reuse_down(self): | ||||||
if not self.dependencies(): | ||||||
return | ||||||
return self._bubble_branch_id_down() | ||||||
|
||||||
def _bubble_branch_id_down(self): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
Suggested change
|
||||||
b_id = self._branch_id | ||||||
if b_id.branch_id <= 0: | ||||||
return | ||||||
if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()): | ||||||
ops = [ | ||||||
op._substitute_branch_id(b_id) if isinstance(op, Expr) else op | ||||||
for op in self.operands | ||||||
] | ||||||
return type(self)(*ops) | ||||||
|
||||||
def _substitute_branch_id(self, branch_id): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
Suggested change
or something else that highlights the conditionality. |
||||||
if self._branch_id.branch_id != 0: | ||||||
return self | ||||||
return type(self)(*self.operands, branch_id) | ||||||
|
||||||
def simplify_once(self, dependents: defaultdict, simplified: dict): | ||||||
"""Simplify an expression | ||||||
|
||||||
|
@@ -346,7 +399,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict): | |||||
new_operands.append(new) | ||||||
|
||||||
if changed: | ||||||
expr = type(expr)(*new_operands) | ||||||
expr = type(expr)(*new_operands, _branch_id=expr._branch_id) | ||||||
|
||||||
break | ||||||
|
||||||
|
@@ -391,7 +444,7 @@ def lower_once(self): | |||||
new_operands.append(new) | ||||||
|
||||||
if changed: | ||||||
out = type(out)(*new_operands) | ||||||
out = type(out)(*new_operands, _branch_id=out._branch_id) | ||||||
|
||||||
return out | ||||||
|
||||||
|
@@ -426,6 +479,23 @@ def _lower(self): | |||||
|
||||||
@functools.cached_property | ||||||
def _name(self): | ||||||
return ( | ||||||
funcname(type(self)).lower() | ||||||
+ "-" | ||||||
+ _tokenize_deterministic(*self.operands, self._branch_id) | ||||||
) | ||||||
|
||||||
@functools.cached_property | ||||||
def _dep_name(self): | ||||||
# The name identifies every expression uniquely. The dependents name | ||||||
# is used during optimization to capture the dependents of any given | ||||||
# expression. A reuse consumer will have the same dependents independently | ||||||
# of the branch_id parameter, since we want to reuse everything that comes | ||||||
# before us and split branches up everything that is processed after | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there's a word missing, maybe
Suggested change
|
||||||
# us. So we have to ignore the branch_id from tokenization for those | ||||||
# nodes. | ||||||
if not self._reuse_consumer: | ||||||
return self._name | ||||||
return ( | ||||||
funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands) | ||||||
) | ||||||
Comment on lines
+482
to
501
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This feels prone to errors/inconsistencies when subclassing. Would it make sense to define a property For example, |
||||||
|
@@ -554,7 +624,7 @@ def _substitute(self, old, new, _seen): | |||||
new_exprs.append(operand) | ||||||
|
||||||
if update: # Only recreate if something changed | ||||||
return type(self)(*new_exprs) | ||||||
return type(self)(*new_exprs, _branch_id=self._branch_id) | ||||||
else: | ||||||
_seen.add(self._name) | ||||||
return self | ||||||
|
@@ -580,7 +650,7 @@ def substitute_parameters(self, substitutions: dict) -> Expr: | |||||
else: | ||||||
new_operands.append(operand) | ||||||
if changed: | ||||||
return type(self)(*new_operands) | ||||||
return type(self)(*new_operands, _branch_id=self._branch_id) | ||||||
return self | ||||||
|
||||||
def _node_label_args(self): | ||||||
|
@@ -741,5 +811,5 @@ def collect_dependents(expr) -> defaultdict: | |||||
|
||||||
for dep in node.dependencies(): | ||||||
stack.append(dep) | ||||||
dependents[dep._name].append(weakref.ref(node)) | ||||||
dependents[dep._dep_name].append(weakref.ref(node)) | ||||||
return dependents |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason you prefer a
NamedTuple
over aNewType('BranchId', int)
here?