diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py
index a7620262..6f9a0f4f 100644
--- a/dask_expr/_collection.py
+++ b/dask_expr/_collection.py
@@ -48,6 +48,7 @@
     M,
     derived_from,
     get_meta_library,
+    key_split,
     maybe_pluralize,
     memory_repr,
     put_lines,
@@ -444,7 +445,12 @@ def __dask_postcompute__(self):
 
     def __dask_postpersist__(self):
         state = new_collection(self.expr.lower_completely())
-        return from_graph, (state._meta, state.divisions, state._name)
+        return from_graph, (
+            state._meta,
+            state.divisions,
+            state.__dask_keys__(),
+            key_split(state._name),
+        )
 
     def __getattr__(self, key):
         try:
@@ -4473,7 +4479,9 @@ def from_dask_dataframe(ddf: _Frame, optimize: bool = True) -> FrameBase:
     graph = ddf.dask
     if optimize:
         graph = ddf.__dask_optimize__(graph, ddf.__dask_keys__())
-    return from_graph(graph, ddf._meta, ddf.divisions, ddf._name)
+    return from_graph(
+        graph, ddf._meta, ddf.divisions, ddf.__dask_keys__(), key_split(ddf._name)
+    )
 
 
 def from_dask_array(x, columns=None, index=None, meta=None):
diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index 72e8cc1b..4e75e0a4 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -28,17 +28,36 @@ def _unpack_collections(o):
 class Expr:
     _parameters = []
     _defaults = {}
+    _instances = weakref.WeakValueDictionary()
 
-    def __init__(self, *args, **kwargs):
+    def __new__(cls, *args, **kwargs):
         operands = list(args)
-        for parameter in type(self)._parameters[len(operands) :]:
+        for parameter in cls._parameters[len(operands) :]:
             try:
                 operands.append(kwargs.pop(parameter))
             except KeyError:
-                operands.append(type(self)._defaults[parameter])
+                operands.append(cls._defaults[parameter])
         assert not kwargs, kwargs
-        operands = [_unpack_collections(o) for o in operands]
-        self.operands = operands
+        inst = object.__new__(cls)
+        inst.operands = [_unpack_collections(o) for o in operands]
+        _name = inst._name
+        if _name in Expr._instances:
+            return Expr._instances[_name]
+
+        Expr._instances[_name] = inst
+        return inst
+
+    def _tune_down(self):
+        return None
+
+    def _tune_up(self, parent):
+        return None
+
+    def _cull_down(self):
+        return None
+
+    def _cull_up(self, parent):
+        return None
 
     def __str__(self):
         s = ", ".join(
@@ -204,28 +223,26 @@ def rewrite(self, kind: str):
             _continue = False
 
             # Rewrite this node
-            if down_name in expr.__dir__():
-                out = getattr(expr, down_name)()
+            out = getattr(expr, down_name)()
+            if out is None:
+                out = expr
+            if not isinstance(out, Expr):
+                return out
+            if out._name != expr._name:
+                expr = out
+                continue
+
+            # Allow children to rewrite their parents
+            for child in expr.dependencies():
+                out = getattr(child, up_name)(expr)
                 if out is None:
                     out = expr
                 if not isinstance(out, Expr):
                     return out
-                if out._name != expr._name:
+                if out is not expr and out._name != expr._name:
                     expr = out
-                    continue
-
-            # Allow children to rewrite their parents
-            for child in expr.dependencies():
-                if up_name in child.__dir__():
-                    out = getattr(child, up_name)(expr)
-                    if out is None:
-                        out = expr
-                    if not isinstance(out, Expr):
-                        return out
-                    if out is not expr and out._name != expr._name:
-                        expr = out
-                        _continue = True
-                        break
+                    _continue = True
+                    break
 
             if _continue:
                 continue
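
A minimal standalone sketch of the caching pattern the new Expr.__new__ introduces above, assuming a simplified Node class in place of dask_expr's Expr: instances are keyed by their deterministic name in a weakref.WeakValueDictionary, so building an identical expression twice returns the same object.

# Illustration only; `Node` and its `_name` are simplified stand-ins.
import weakref

from dask.base import tokenize


class Node:
    _instances = weakref.WeakValueDictionary()

    def __new__(cls, *operands):
        inst = object.__new__(cls)
        inst.operands = list(operands)
        name = inst._name
        # Reuse an identical, still-alive instance instead of the new one.
        if name in Node._instances:
            return Node._instances[name]
        Node._instances[name] = inst
        return inst

    @property
    def _name(self):
        return type(self).__name__.lower() + "-" + tokenize(*self.operands)


a, b = Node(1, "x"), Node(1, "x")
assert a is b  # deduplicated via the deterministic name
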
diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py
index cc2aa4ed..a0f864f7 100644
--- a/dask_expr/io/_delayed.py
+++ b/dask_expr/io/_delayed.py
@@ -20,6 +20,7 @@ class _DelayedExpr(Expr):
     # Wraps a Delayed object to make it an Expr for now. This is hacky and we should
     # integrate this properly...
     # TODO
+    _parameters = ["obj"]
 
     def __init__(self, obj):
         self.obj = obj
diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index 3d8b020c..25ef1abc 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -36,7 +36,7 @@ class FromGraph(IO):
     conversion from legacy dataframes.
     """
 
-    _parameters = ["layer", "_meta", "divisions", "_name"]
+    _parameters = ["layer", "_meta", "divisions", "keys", "name_prefix"]
 
     @property
     def _meta(self):
@@ -45,12 +45,19 @@ def _meta(self):
     def _divisions(self):
         return self.operand("divisions")
 
-    @property
+    @functools.cached_property
     def _name(self):
-        return self.operand("_name")
+        return (
+            self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands)
+        )
 
     def _layer(self):
-        return dict(self.operand("layer"))
+        dsk = dict(self.operand("layer"))
+        # The name may not actually match the layers name therefore rewrite this
+        # using an alias
+        for part, k in enumerate(self.operand("keys")):
+            dsk[(self._name, part)] = k
+        return dsk
 
 
 class BlockwiseIO(Blockwise, IO):
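
A rough sketch of the aliasing trick used by the new FromGraph._layer above, with hypothetical key names: the already-persisted keys keep their original names, and one alias entry per partition points the expression's own deterministic name at them.

# Hypothetical key names, for illustration only.
persisted = {("add-abc123", 0): "partition-0", ("add-abc123", 1): "partition-1"}
keys = list(persisted)  # what __dask_keys__() reported at persist time

dsk = dict(persisted)
new_name = "fromgraph-" + "deadbeef"  # name_prefix + deterministic token
for part, key in enumerate(keys):
    # Each partition of FromGraph is just an alias to the persisted key.
    dsk[(new_name, part)] = key

assert dsk[(new_name, 1)] == ("add-abc123", 1)
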
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 6d348b25..90c2aa53 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import contextlib
-import functools
 import itertools
 import operator
 import warnings
@@ -26,8 +25,9 @@
 from dask.dataframe.io.parquet.utils import _split_user_options
 from dask.dataframe.io.utils import _is_local_fs
 from dask.delayed import delayed
-from dask.utils import apply, natural_sort_key, typename
+from dask.utils import apply, funcname, natural_sort_key, typename
 from fsspec.utils import stringify_path
+from toolz import identity
 
 from dask_expr._expr import (
     EQ,
@@ -47,26 +47,15 @@
     determine_column_projection,
 )
 from dask_expr._reductions import Len
-from dask_expr._util import _convert_to_list
+from dask_expr._util import _convert_to_list, _tokenize_deterministic
 from dask_expr.io import BlockwiseIO, PartitionsFiltered
 
 NONE_LABEL = "__null_dask_index__"
 
-_cached_dataset_info = {}
-_CACHED_DATASET_SIZE = 10
 _CACHED_PLAN_SIZE = 10
 _cached_plan = {}
 
 
-def _control_cached_dataset_info(key):
-    if (
-        len(_cached_dataset_info) > _CACHED_DATASET_SIZE
-        and key not in _cached_dataset_info
-    ):
-        key_to_pop = list(_cached_dataset_info.keys())[0]
-        _cached_dataset_info.pop(key_to_pop)
-
-
 def _control_cached_plan(key):
     if len(_cached_plan) > _CACHED_PLAN_SIZE and key not in _cached_plan:
         key_to_pop = list(_cached_plan.keys())[0]
@@ -121,7 +110,7 @@ def _lower(self):
 class ToParquetData(Blockwise):
     _parameters = ToParquet._parameters
 
-    @cached_property
+    @property
     def io_func(self):
         return ToParquetFunctionWrapper(
             self.engine,
@@ -257,7 +246,6 @@ def to_parquet(
 
     # Clear read_parquet caches in case we are
     # also reading from the overwritten path
-    _cached_dataset_info.clear()
     _cached_plan.clear()
 
     # Always skip divisions checks if divisions are unknown
@@ -413,6 +401,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "kwargs",
         "_partitions",
         "_series",
+        "_dataset_info_cache",
     ]
     _defaults = {
         "columns": None,
@@ -432,6 +421,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "kwargs": None,
         "_partitions": None,
         "_series": False,
+        "_dataset_info_cache": None,
     }
     _pq_length_stats = None
     _absorb_projections = True
@@ -474,7 +464,21 @@ def _simplify_up(self, parent, dependents):
             return Literal(sum(_lengths))
 
     @cached_property
+    def _name(self):
+        return (
+            funcname(type(self)).lower()
+            + "-"
+            + _tokenize_deterministic(self.checksum, *self.operands)
+        )
+
+    @property
+    def checksum(self):
+        return self._dataset_info["checksum"]
+
+    @property
     def _dataset_info(self):
+        if rv := self.operand("_dataset_info_cache"):
+            return rv
         # Process and split user options
         (
             dataset_options,
@@ -536,13 +540,25 @@ def _dataset_info(self):
                 **other_options,
             },
         )
-        dataset_token = tokenize(*args)
-        if dataset_token not in _cached_dataset_info:
-            _control_cached_dataset_info(dataset_token)
-            _cached_dataset_info[dataset_token] = self.engine._collect_dataset_info(
-                *args
-            )
-        dataset_info = _cached_dataset_info[dataset_token].copy()
+        dataset_info = self.engine._collect_dataset_info(*args)
+        checksum = []
+        files_for_checksum = []
+        if dataset_info["has_metadata_file"]:
+            if isinstance(self.path, list):
+                files_for_checksum = [
+                    next(path for path in self.path if path.endswith("_metadata"))
+                ]
+            else:
+                files_for_checksum = [self.path + fs.sep + "_metadata"]
+        else:
+            files_for_checksum = dataset_info["ds"].files
+
+        for file in files_for_checksum:
+            # The checksum / file info is usually already cached by the fsspec
+            # FileSystem dir_cache since this info was already asked for in
+            # _collect_dataset_info
+            checksum.append(fs.checksum(file))
+        dataset_info["checksum"] = tokenize(checksum)
 
         # Infer meta, accounting for index and columns arguments.
         meta = self.engine._create_dd_meta(dataset_info)
@@ -558,6 +574,9 @@ def _dataset_info(self):
         dataset_info["all_columns"] = all_columns
         dataset_info["calculate_divisions"] = self.calculate_divisions
 
+        self.operands[
+            type(self)._parameters.index("_dataset_info_cache")
+        ] = dataset_info
         return dataset_info
 
     @property
@@ -571,10 +590,10 @@ def _meta(self):
             return meta[columns]
         return meta
 
-    @cached_property
+    @property
     def _io_func(self):
         if self._plan["empty"]:
-            return lambda x: x
+            return identity
         dataset_info = self._dataset_info
         return ParquetFunctionWrapper(
             self.engine,
@@ -662,7 +681,7 @@ def _update_length_statistics(self):
                 stat["num-rows"] for stat in _collect_pq_statistics(self)
             )
 
-    @functools.cached_property
+    @property
     def _fusion_compression_factor(self):
         if self.operand("columns") is None:
             return 1
@@ -767,9 +786,11 @@ def _maybe_list(val):
                 return [val]
 
             return [
-                _maybe_list(val.to_list_tuple())
-                if hasattr(val, "to_list_tuple")
-                else _maybe_list(val)
+                (
+                    _maybe_list(val.to_list_tuple())
+                    if hasattr(val, "to_list_tuple")
+                    else _maybe_list(val)
+                )
                 for val in self
             ]
 
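
An illustrative sketch of the checksum computed in _dataset_info above, using an in-memory fsspec filesystem and made-up file names rather than the real Parquet dataset handling: folding per-file checksums into the token means the expression name changes whenever the underlying files change.

# Sketch only; the filesystem and paths are placeholders.
import fsspec
from dask.base import tokenize

fs = fsspec.filesystem("memory")
files = ["/dataset/part.0.parquet", "/dataset/part.1.parquet"]
for path in files:
    with fs.open(path, "wb") as f:
        f.write(b"fake parquet bytes")

# fsspec filesystems can usually serve checksum() from cached file info,
# so no data needs to be re-read for this step.
checksums = [fs.checksum(path) for path in files]
dataset_checksum = tokenize(checksums)
# The checksum participates in the deterministic expression name.
name = "readparquet-" + tokenize(dataset_checksum, "columns", "filters")
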
diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py
index 4d869ca3..9501c512 100644
--- a/dask_expr/tests/test_collection.py
+++ b/dask_expr/tests/test_collection.py
@@ -11,7 +11,7 @@
 import dask.array as da
 import numpy as np
 import pytest
-from dask.dataframe._compat import PANDAS_GE_210
+from dask.dataframe._compat import PANDAS_GE_210, PANDAS_GE_220
 from dask.dataframe.utils import UNKNOWN_CATEGORIES
 from dask.utils import M
 
@@ -983,14 +983,15 @@ def test_broadcast(pdf, df):
 
 def test_persist(pdf, df):
     a = df + 2
+    a *= 2
     b = a.persist()
 
     assert_eq(a, b)
     assert len(a.__dask_graph__()) > len(b.__dask_graph__())
 
-    assert len(b.__dask_graph__()) == b.npartitions
+    assert len(b.__dask_graph__()) == 2 * b.npartitions
 
-    assert_eq(b.y.sum(), (pdf + 2).y.sum())
+    assert_eq(b.y.sum(), ((pdf + 2) * 2).y.sum())
 
 
 def test_index(pdf, df):
@@ -1035,6 +1036,7 @@ def test_head_down(df):
     assert not isinstance(optimized.expr, expr.Head)
 
 
+@pytest.mark.skipif(not PANDAS_GE_220, reason="not implemented")
 def test_case_when(pdf, df):
     result = df.x.case_when([(df.x.eq(1), 1), (df.y == 10, 2.5)])
     expected = pdf.x.case_when([(pdf.x.eq(1), 1), (pdf.y == 10, 2.5)])
diff --git a/dask_expr/tests/test_datasets.py b/dask_expr/tests/test_datasets.py
index 1541beb2..aa210dce 100644
--- a/dask_expr/tests/test_datasets.py
+++ b/dask_expr/tests/test_datasets.py
@@ -53,7 +53,7 @@ def test_persist():
     b = a.persist()
 
     assert_eq(a, b)
-    assert len(b.dask) == b.npartitions
+    assert len(b.dask) == 2 * b.npartitions
 
 
 def test_lengths():
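
A usage-level sketch of the behaviour the updated persist tests check, assuming a small two-partition frame built with dask_expr.from_pandas.

# Illustration of the expected graph shape after persist; not part of the patch.
import pandas as pd

import dask_expr as dx

pdf = pd.DataFrame({"x": range(100), "y": range(100)})
df = dx.from_pandas(pdf, npartitions=2)

b = ((df + 2) * 2).persist()

# One materialized key per partition plus one FromGraph alias per partition.
assert len(b.__dask_graph__()) == 2 * b.npartitions
assert b.y.sum().compute() == ((pdf + 2) * 2).y.sum()
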