diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py
index e42a0388c..dd34c7171 100644
--- a/dask_expr/_collection.py
+++ b/dask_expr/_collection.py
@@ -193,6 +193,9 @@ class FrameBase(DaskMethodsMixin):
     )
     __dask_optimize__ = staticmethod(lambda dsk, keys, **kwargs: dsk)
 
+    def __dask_tokenize__(self):
+        return self.expr._name
+
     def __init__(self, expr):
         self._expr = expr
 
@@ -262,14 +265,7 @@ def compute(self, fuse=True, combine_similar=True, **kwargs):
         return DaskMethodsMixin.compute(out, **kwargs)
 
     def __dask_graph__(self):
-        out = self.expr
-        out = out.lower_completely()
-        return out.__dask_graph__()
-
-    def __dask_keys__(self):
-        out = self.expr
-        out = out.lower_completely()
-        return out.__dask_keys__()
+        return self.expr
 
     def simplify(self):
         return new_collection(self.expr.simplify())
@@ -286,6 +282,9 @@ def optimize(self, combine_similar: bool = True, fuse: bool = True):
     def dask(self):
         return self.__dask_graph__()
 
+    def finalize_compute(self):
+        return new_collection(Repartition(self.expr, 1))
+
     def __dask_postcompute__(self):
         state = new_collection(self.expr.lower_completely())
         if type(self) != type(state):
diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 8671c05ef..61a587e5d 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -497,6 +497,283 @@ def name(self):
     def dtypes(self):
         return self._meta.dtypes
 
+    @property
+    def _meta(self):
+        raise NotImplementedError()
+
+    def materialize(self):
+        """Traverse expression tree, collect layers"""
+        stack = [self]
+        seen = set()
+        layers = []
+        while stack:
+            expr = stack.pop()
+
+            if expr._name in seen:
+                continue
+            seen.add(expr._name)
+
+            layers.append(expr._layer())
+            for operand in expr.dependencies():
+                stack.append(operand)
+
+        return toolz.merge(layers)
+
+    def __dask_keys__(self):
+        return [(self._name, i) for i in range(self.npartitions)]
+
+    def substitute(self, old, new) -> Expr:
+        """Substitute a specific term within the expression
+
+        Note that replacing non-`Expr` terms may produce
+        unexpected results, and is not recommended.
+        Substituting boolean values is not allowed.
+
+        Parameters
+        ----------
+        old:
+            Old term to find and replace.
+        new:
+            New term to replace instances of `old` with.
+
+        Examples
+        --------
+        >>> (df + 10).substitute(10, 20)
+        df + 20
+        """
+
+        # Check if we are replacing a literal
+        if isinstance(old, Expr):
+            substitute_literal = False
+            if self._name == old._name:
+                return new
+        else:
+            substitute_literal = True
+            if isinstance(old, bool):
+                raise TypeError("Arguments to `substitute` cannot be bool.")
+
+        new_exprs = []
+        update = False
+        for operand in self.operands:
+            if isinstance(operand, Expr):
+                val = operand.substitute(old, new)
+                if operand._name != val._name:
+                    update = True
+                new_exprs.append(val)
+            elif (
+                isinstance(self, Fused)
+                and isinstance(operand, list)
+                and all(isinstance(op, Expr) for op in operand)
+            ):
+                # Special handling for `Fused`.
+                # We make no promise to dive through a
+                # list operand in general, but NEED to
+                # do so for the `Fused.exprs` operand.
+                val = []
+                for op in operand:
+                    val.append(op.substitute(old, new))
+                    if val[-1]._name != op._name:
+                        update = True
+                new_exprs.append(val)
+            elif (
+                substitute_literal
+                and not isinstance(operand, bool)
+                and isinstance(operand, type(old))
+                and operand == old
+            ):
+                new_exprs.append(new)
+                update = True
+            else:
+                new_exprs.append(operand)
+
+        if update:  # Only recreate if something changed
+            return type(self)(*new_exprs)
+        return self
+
+    def substitute_parameters(self, substitutions: dict) -> Expr:
+        """Substitute specific `Expr` parameters
+
+        Parameters
+        ----------
+        substitutions:
+            Mapping of parameter keys to new values. Keys that
+            are not found in ``self._parameters`` will be ignored.
+        """
+        if not substitutions:
+            return self
+
+        changed = False
+        new_operands = []
+        for i, operand in enumerate(self.operands):
+            if i < len(self._parameters) and self._parameters[i] in substitutions:
+                new_operands.append(substitutions[self._parameters[i]])
+                changed = True
+            else:
+                new_operands.append(operand)
+        if changed:
+            return type(self)(*new_operands)
+        return self
+
+    def _find_similar_operations(self, root: Expr, ignore: list | None = None):
+        # Find operations with the same type and operands.
+        # Parameter keys specified by `ignore` will not be
+        # included in the operand comparison
+        alike = [
+            op for op in root.find_operations(type(self)) if op._name != self._name
+        ]
+        if not alike:
+            # No other operations of the same type. Early return
+            return []
+
+        # Return subset of `alike` with the same "token"
+        token = _tokenize_partial(self, ignore)
+        return [item for item in alike if _tokenize_partial(item, ignore) == token]
+
+    def _node_label_args(self):
+        """Operands to include in the node label by `visualize`"""
+        return self.dependencies()
+
+    def _to_graphviz(
+        self,
+        rankdir="BT",
+        graph_attr=None,
+        node_attr=None,
+        edge_attr=None,
+        **kwargs,
+    ):
+        from dask.dot import label, name
+
+        graphviz = import_required(
+            "graphviz",
+            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
+            "python library and the `graphviz` system library.\n\n"
+            "Please either conda or pip install as follows:\n\n"
+            "  conda install python-graphviz     # either conda install\n"
+            "  python -m pip install graphviz    # or pip install and follow installation instructions",
+        )
+
+        graph_attr = graph_attr or {}
+        node_attr = node_attr or {}
+        edge_attr = edge_attr or {}
+
+        graph_attr["rankdir"] = rankdir
+        node_attr["shape"] = "box"
+        node_attr["fontname"] = "helvetica"
+
+        graph_attr.update(kwargs)
+        g = graphviz.Digraph(
+            graph_attr=graph_attr,
+            node_attr=node_attr,
+            edge_attr=edge_attr,
+        )
+
+        stack = [self]
+        seen = set()
+        dependencies = {}
+        while stack:
+            expr = stack.pop()
+
+            if expr._name in seen:
+                continue
+            seen.add(expr._name)
+
+            dependencies[expr] = set(expr.dependencies())
+            for dep in expr.dependencies():
+                stack.append(dep)
+
+        cache = {}
+        for expr in dependencies:
+            expr_name = name(expr)
+            attrs = {}
+
+            # Make node label
+            deps = [
+                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
+                for dep in expr._node_label_args()
+            ]
+            _label = funcname(type(expr))
+            if deps:
+                _label = f"{_label}({', '.join(deps)})" if deps else _label
+            node_label = label(_label, cache=cache)
+
+            attrs.setdefault("label", str(node_label))
+            attrs.setdefault("fontsize", "20")
+            g.node(expr_name, **attrs)
+
+        for expr, deps in dependencies.items():
+            expr_name = name(expr)
+            for dep in deps:
+                dep_name = name(dep)
+                g.edge(dep_name, expr_name)
+
+        return g
+
+    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
+        """
+        Visualize the expression graph.
+        Requires ``graphviz`` to be installed.
+
+        Parameters
+        ----------
+        filename : str or None, optional
+            The name of the file to write to disk. If the provided `filename`
+            doesn't include an extension, '.png' will be used by default.
+            If `filename` is None, no file will be written, and the graph is
+            rendered in the Jupyter notebook only.
+        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
+            Format in which to write output file. Default is 'svg'.
+        **kwargs
+            Additional keyword arguments to forward to ``_to_graphviz``.
+        """
+        from dask.dot import graphviz_to_file
+
+        g = self._to_graphviz(**kwargs)
+        graphviz_to_file(g, filename, format)
+        return g
+
+    def walk(self) -> Generator[Expr]:
+        """Iterate through all expressions in the tree
+
+        Returns
+        -------
+        nodes
+            Generator of Expr instances in the graph.
+            Ordering is a depth-first search of the expression tree
+        """
+        stack = [self]
+        seen = set()
+        while stack:
+            node = stack.pop()
+            if node._name in seen:
+                continue
+            seen.add(node._name)
+
+            for dep in node.dependencies():
+                stack.append(dep)
+
+            yield node
+
+    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
+        """Search the expression graph for a specific operation type
+
+        Parameters
+        ----------
+        operation
+            The operation type to search for.
+
+        Returns
+        -------
+        nodes
+            Generator of `operation` instances. Ordering corresponds
+            to a depth-first search of the expression graph.
+        """
+        assert (
+            isinstance(operation, tuple)
+            and all(issubclass(e, Expr) for e in operation)
+            or issubclass(operation, Expr)
+        ), "`operation` must be an `Expr` subclass or tuple of `Expr` subclasses"
+        return (expr for expr in self.walk() if isinstance(expr, operation))
+
 
 class Literal(Expr):
     """Represent a literal (known) value as an `Expr`"""
@@ -1927,6 +2204,17 @@ class Pos(Unaryop):
     _operator_repr = "+"
 
 
+class Tuple(Expr):
+    def __getitem__(self, other):
+        return self.operands[other]
+
+    def __len__(self):
+        return len(self.operands)
+
+    def __iter__(self):
+        return iter(self.operands)
+
+
 class Partitions(Expr):
     """Select one or more partitions"""
 
diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py
index 8fd6690c7..c50874e55 100644
--- a/dask_expr/tests/test_collection.py
+++ b/dask_expr/tests/test_collection.py
@@ -23,7 +23,7 @@
     to_numeric,
     to_timedelta,
 )
-from dask_expr._expr import are_co_aligned
+from dask_expr._expr import Tuple, are_co_aligned
 from dask_expr._reductions import Len
 from dask_expr._shuffle import Shuffle
 from dask_expr.datasets import timeseries
@@ -1903,6 +1903,18 @@ def test_items(df, pdf):
         assert_eq(expect_col, actual_col)
 
 
+def test_combine_expr_with_tuple(pdf):
+    ddf1 = from_pandas(pdf, npartitions=2) + 1
+    ddf2 = from_pandas(pdf, npartitions=3) + 2
+
+    t = Tuple(ddf1.expr, ddf2.expr)
+    assert t[0]._name == ddf1._name
+    assert t[0].optimize()._name == t.optimize()[0]._name
+
+    assert t[1]._name == ddf2._name
+    assert t[1].optimize()._name == t.optimize()[1]._name
+
+
 def test_index_index(df):
     with pytest.raises(NotImplementedError, match="has no"):
         df.index.index
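
The traversal helpers added to `Expr` above (`walk`, `find_operations`, `substitute`) are easiest to see on a tiny expression tree. A minimal sketch, not part of the patch; the pandas frame and the top-level `from_pandas` import path are assumptions made for illustration.

    import pandas as pd

    from dask_expr import from_pandas          # assumed import path
    from dask_expr._expr import Expr

    df = from_pandas(pd.DataFrame({"x": range(10)}), npartitions=2)
    expr = (df["x"] + 10).expr

    # walk(): depth-first generator over every node in the expression tree
    node_types = [type(node).__name__ for node in expr.walk()]

    # find_operations(): the same traversal, filtered by expression type;
    # filtering on the base class keeps every node, in the same order
    walked = [node._name for node in expr.walk()]
    found = [node._name for node in expr.find_operations(Expr)]
    assert found == walked

    # substitute(): replace a literal operand throughout the tree,
    # mirroring the docstring example `(df + 10).substitute(10, 20)`
    expr_20 = expr.substitute(10, 20)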
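
The new `Tuple` expression simply stores several expression trees as operands of one node, which is what `test_combine_expr_with_tuple` exercises. A sketch of the same pattern outside the test suite, under the same illustrative assumptions as above:

    import pandas as pd

    from dask_expr import from_pandas          # assumed import path
    from dask_expr._expr import Tuple

    pdf = pd.DataFrame({"x": range(10)})
    ddf1 = from_pandas(pdf, npartitions=2) + 1
    ddf2 = from_pandas(pdf, npartitions=3) + 2

    # Group both expression trees under a single Tuple node
    t = Tuple(ddf1.expr, ddf2.expr)

    assert len(t) == 2                        # __len__ counts the wrapped expressions
    assert t[0]._name == ddf1._name           # __getitem__ gives positional access
    assert [e._name for e in t] == [ddf1._name, ddf2._name]  # __iter__ yields them in order

Because both trees hang off one expression, they can be optimized as a unit while staying individually addressable, which is what the `t.optimize()[0]` assertions in the added test check.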