Skip to content

Commit

Permalink
Allow expressions to be shipped to the scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Dec 19, 2023
1 parent c770840 commit 93a0fd1
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 9 deletions.
15 changes: 7 additions & 8 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ class FrameBase(DaskMethodsMixin):
)
__dask_optimize__ = staticmethod(lambda dsk, keys, **kwargs: dsk)

def __dask_tokenize__(self):
return self.expr._name

def __init__(self, expr):
self._expr = expr

Expand Down Expand Up @@ -262,14 +265,7 @@ def compute(self, fuse=True, combine_similar=True, **kwargs):
return DaskMethodsMixin.compute(out, **kwargs)

def __dask_graph__(self):
out = self.expr
out = out.lower_completely()
return out.__dask_graph__()

def __dask_keys__(self):
out = self.expr
out = out.lower_completely()
return out.__dask_keys__()
return self.expr

def simplify(self):
return new_collection(self.expr.simplify())
Expand All @@ -286,6 +282,9 @@ def optimize(self, combine_similar: bool = True, fuse: bool = True):
def dask(self):
return self.__dask_graph__()

def finalize_compute(self):
return new_collection(Repartition(self.expr, 1))

def __dask_postcompute__(self):
state = new_collection(self.expr.lower_completely())
if type(self) != type(state):
Expand Down
288 changes: 288 additions & 0 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,283 @@ def name(self):
def dtypes(self):
return self._meta.dtypes

@property
def _meta(self):
raise NotImplementedError()

def materialize(self):
"""Traverse expression tree, collect layers"""
stack = [self]
seen = set()
layers = []
while stack:
expr = stack.pop()

if expr._name in seen:
continue
seen.add(expr._name)

layers.append(expr._layer())
for operand in expr.dependencies():
stack.append(operand)

return toolz.merge(layers)

def __dask_keys__(self):
return [(self._name, i) for i in range(self.npartitions)]

def substitute(self, old, new) -> Expr:
"""Substitute a specific term within the expression
Note that replacing non-`Expr` terms may produce
unexpected results, and is not recommended.
Substituting boolean values is not allowed.
Parameters
----------
old:
Old term to find and replace.
new:
New term to replace instances of `old` with.
Examples
--------
>>> (df + 10).substitute(10, 20)
df + 20
"""

# Check if we are replacing a literal
if isinstance(old, Expr):
substitute_literal = False
if self._name == old._name:
return new
else:
substitute_literal = True
if isinstance(old, bool):
raise TypeError("Arguments to `substitute` cannot be bool.")

new_exprs = []
update = False
for operand in self.operands:
if isinstance(operand, Expr):
val = operand.substitute(old, new)
if operand._name != val._name:
update = True
new_exprs.append(val)
elif (
isinstance(self, Fused)
and isinstance(operand, list)
and all(isinstance(op, Expr) for op in operand)
):
# Special handling for `Fused`.
# We make no promise to dive through a
# list operand in general, but NEED to
# do so for the `Fused.exprs` operand.
val = []
for op in operand:
val.append(op.substitute(old, new))
if val[-1]._name != op._name:
update = True
new_exprs.append(val)
elif (
substitute_literal
and not isinstance(operand, bool)
and isinstance(operand, type(old))
and operand == old
):
new_exprs.append(new)
update = True
else:
new_exprs.append(operand)

if update: # Only recreate if something changed
return type(self)(*new_exprs)
return self

def substitute_parameters(self, substitutions: dict) -> Expr:
"""Substitute specific `Expr` parameters
Parameters
----------
substitutions:
Mapping of parameter keys to new values. Keys that
are not found in ``self._parameters`` will be ignored.
"""
if not substitutions:
return self

changed = False
new_operands = []
for i, operand in enumerate(self.operands):
if i < len(self._parameters) and self._parameters[i] in substitutions:
new_operands.append(substitutions[self._parameters[i]])
changed = True
else:
new_operands.append(operand)
if changed:
return type(self)(*new_operands)
return self

def _find_similar_operations(self, root: Expr, ignore: list | None = None):
# Find operations with the same type and operands.
# Parameter keys specified by `ignore` will not be
# included in the operand comparison
alike = [
op for op in root.find_operations(type(self)) if op._name != self._name
]
if not alike:
# No other operations of the same type. Early return
return []

# Return subset of `alike` with the same "token"
token = _tokenize_partial(self, ignore)
return [item for item in alike if _tokenize_partial(item, ignore) == token]

def _node_label_args(self):
"""Operands to include in the node label by `visualize`"""
return self.dependencies()

def _to_graphviz(
self,
rankdir="BT",
graph_attr=None,
node_attr=None,
edge_attr=None,
**kwargs,
):
from dask.dot import label, name

graphviz = import_required(
"graphviz",
"Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
"python library and the `graphviz` system library.\n\n"
"Please either conda or pip install as follows:\n\n"
" conda install python-graphviz # either conda install\n"
" python -m pip install graphviz # or pip install and follow installation instructions",
)

graph_attr = graph_attr or {}
node_attr = node_attr or {}
edge_attr = edge_attr or {}

graph_attr["rankdir"] = rankdir
node_attr["shape"] = "box"
node_attr["fontname"] = "helvetica"

graph_attr.update(kwargs)
g = graphviz.Digraph(
graph_attr=graph_attr,
node_attr=node_attr,
edge_attr=edge_attr,
)

stack = [self]
seen = set()
dependencies = {}
while stack:
expr = stack.pop()

if expr._name in seen:
continue
seen.add(expr._name)

dependencies[expr] = set(expr.dependencies())
for dep in expr.dependencies():
stack.append(dep)

cache = {}
for expr in dependencies:
expr_name = name(expr)
attrs = {}

# Make node label
deps = [
funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
for dep in expr._node_label_args()
]
_label = funcname(type(expr))
if deps:
_label = f"{_label}({', '.join(deps)})" if deps else _label
node_label = label(_label, cache=cache)

attrs.setdefault("label", str(node_label))
attrs.setdefault("fontsize", "20")
g.node(expr_name, **attrs)

for expr, deps in dependencies.items():
expr_name = name(expr)
for dep in deps:
dep_name = name(dep)
g.edge(dep_name, expr_name)

return g

def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
"""
Visualize the expression graph.
Requires ``graphviz`` to be installed.
Parameters
----------
filename : str or None, optional
The name of the file to write to disk. If the provided `filename`
doesn't include an extension, '.png' will be used by default.
If `filename` is None, no file will be written, and the graph is
rendered in the Jupyter notebook only.
format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
Format in which to write output file. Default is 'svg'.
**kwargs
Additional keyword arguments to forward to ``to_graphviz``.
"""
from dask.dot import graphviz_to_file

g = self._to_graphviz(**kwargs)
graphviz_to_file(g, filename, format)
return g

def walk(self) -> Generator[Expr]:
"""Iterate through all expressions in the tree
Returns
-------
nodes
Generator of Expr instances in the graph.
Ordering is a depth-first search of the expression tree
"""
stack = [self]
seen = set()
while stack:
node = stack.pop()
if node._name in seen:
continue
seen.add(node._name)

for dep in node.dependencies():
stack.append(dep)

yield node

def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
"""Search the expression graph for a specific operation type
Parameters
----------
operation
The operation type to search for.
Returns
-------
nodes
Generator of `operation` instances. Ordering corresponds
to a depth-first search of the expression graph.
"""
assert (
isinstance(operation, tuple)
and all(issubclass(e, Expr) for e in operation)
or issubclass(operation, Expr)
), "`operation` must be`Expr` subclass)"
return (expr for expr in self.walk() if isinstance(expr, operation))


class Literal(Expr):
"""Represent a literal (known) value as an `Expr`"""
Expand Down Expand Up @@ -1927,6 +2204,17 @@ class Pos(Unaryop):
_operator_repr = "+"


class Tuple(Expr):
def __getitem__(self, other):
return self.operands[other]

def __len__(self):
return len(self.operands)

def __iter__(self):
return iter(self.operands)


class Partitions(Expr):
"""Select one or more partitions"""

Expand Down
14 changes: 13 additions & 1 deletion dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
to_numeric,
to_timedelta,
)
from dask_expr._expr import are_co_aligned
from dask_expr._expr import Tuple, are_co_aligned
from dask_expr._reductions import Len
from dask_expr._shuffle import Shuffle
from dask_expr.datasets import timeseries
Expand Down Expand Up @@ -1903,6 +1903,18 @@ def test_items(df, pdf):
assert_eq(expect_col, actual_col)


def test_combine_expr_with_tuple(pdf):
ddf1 = from_pandas(pdf, npartitions=2) + 1
ddf2 = from_pandas(pdf, npartitions=3) + 2

t = Tuple(ddf1.expr, ddf2.expr)
assert t[0]._name == ddf1._name
assert t[0].optimize()._name == t.optimize()[0]._name

assert t[1]._name == ddf2._name
assert t[1].optimize()._name == t.optimize()[1]._name


def test_index_index(df):
with pytest.raises(NotImplementedError, match="has no"):
df.index.index
Expand Down

0 comments on commit 93a0fd1

Please sign in to comment.