Expr as singleton #798

Merged: 8 commits, Feb 2, 2024
12 changes: 10 additions & 2 deletions dask_expr/_collection.py
@@ -48,6 +48,7 @@
M,
derived_from,
get_meta_library,
key_split,
maybe_pluralize,
memory_repr,
put_lines,
@@ -444,7 +445,12 @@ def __dask_postcompute__(self):

def __dask_postpersist__(self):
state = new_collection(self.expr.lower_completely())
return from_graph, (state._meta, state.divisions, state._name)
return from_graph, (
state._meta,
state.divisions,
state.__dask_keys__(),
key_split(state._name),
)

def __getattr__(self, key):
try:
@@ -4473,7 +4479,9 @@ def from_dask_dataframe(ddf: _Frame, optimize: bool = True) -> FrameBase:
graph = ddf.dask
if optimize:
graph = ddf.__dask_optimize__(graph, ddf.__dask_keys__())
return from_graph(graph, ddf._meta, ddf.divisions, ddf._name)
return from_graph(
graph, ddf._meta, ddf.divisions, ddf.__dask_keys__(), key_split(ddf._name)
)


def from_dask_array(x, columns=None, index=None, meta=None):
61 changes: 39 additions & 22 deletions dask_expr/_core.py
@@ -28,17 +28,36 @@ def _unpack_collections(o):
class Expr:
_parameters = []
_defaults = {}
_instances = weakref.WeakValueDictionary()

def __init__(self, *args, **kwargs):
def __new__(cls, *args, **kwargs):
operands = list(args)
for parameter in type(self)._parameters[len(operands) :]:
for parameter in cls._parameters[len(operands) :]:
try:
operands.append(kwargs.pop(parameter))
except KeyError:
operands.append(type(self)._defaults[parameter])
operands.append(cls._defaults[parameter])
assert not kwargs, kwargs
operands = [_unpack_collections(o) for o in operands]
self.operands = operands
inst = object.__new__(cls)
inst.operands = [_unpack_collections(o) for o in operands]
_name = inst._name
if _name in Expr._instances:
return Expr._instances[_name]

Expr._instances[_name] = inst
return inst

def _tune_down(self):
return None

def _tune_up(self, parent):
return None

def _cull_down(self):
return None

def _cull_up(self, parent):
return None

def __str__(self):
s = ", ".join(
@@ -204,28 +223,26 @@ def rewrite(self, kind: str):
_continue = False

# Rewrite this node
if down_name in expr.__dir__():
out = getattr(expr, down_name)()
out = getattr(expr, down_name)()
if out is None:
out = expr
if not isinstance(out, Expr):
return out
if out._name != expr._name:
expr = out
continue

# Allow children to rewrite their parents
for child in expr.dependencies():
out = getattr(child, up_name)(expr)
if out is None:
out = expr
if not isinstance(out, Expr):
return out
if out._name != expr._name:
if out is not expr and out._name != expr._name:
expr = out
continue

# Allow children to rewrite their parents
for child in expr.dependencies():
if up_name in child.__dir__():
out = getattr(child, up_name)(expr)
if out is None:
out = expr
if not isinstance(out, Expr):
return out
if out is not expr and out._name != expr._name:
expr = out
_continue = True
break
_continue = True
break

if _continue:
continue
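For readers skimming the diff above, here is a minimal, standalone sketch of the interning pattern that the new Expr.__new__ implements: instances are keyed by a deterministic name in a WeakValueDictionary, so constructing an equivalent expression twice returns the same object, and entries disappear once nothing references them. The toy Node/Add classes and the hash-based name below are illustrative assumptions, not the actual dask-expr code.

```python
import hashlib
import weakref


class Node:
    """Toy expression that is deduplicated ("interned") by its deterministic name."""

    _instances = weakref.WeakValueDictionary()

    def __new__(cls, *operands):
        inst = object.__new__(cls)
        inst.operands = list(operands)
        name = inst._name
        existing = Node._instances.get(name)
        if existing is not None:
            # An equivalent instance is still alive: reuse it, discard the new one.
            return existing
        Node._instances[name] = inst
        return inst

    @property
    def _name(self):
        # Deterministic token derived from the type and the operands.
        token = hashlib.sha1(
            repr((type(self).__name__, self.operands)).encode()
        ).hexdigest()[:8]
        return f"{type(self).__name__.lower()}-{token}"


class Add(Node):
    pass


a = Add(1, 2)
b = Add(1, 2)
assert a is b              # the second construction returned the cached instance
assert Add(1, 3) is not a  # different operands -> different name -> new instance
```

The WeakValueDictionary matters here: it deduplicates live expressions without keeping otherwise-unreferenced ones alive forever.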
1 change: 1 addition & 0 deletions dask_expr/io/_delayed.py
@@ -20,6 +20,7 @@ class _DelayedExpr(Expr):
# Wraps a Delayed object to make it an Expr for now. This is hacky and we should
# integrate this properly...
# TODO
_parameters = ["obj"]

def __init__(self, obj):
self.obj = obj
15 changes: 11 additions & 4 deletions dask_expr/io/io.py
@@ -36,7 +36,7 @@ class FromGraph(IO):
conversion from legacy dataframes.
"""

_parameters = ["layer", "_meta", "divisions", "_name"]
_parameters = ["layer", "_meta", "divisions", "keys", "name_prefix"]
Member Author:

The FromGraph expression got a little odd. For context, it is used in two ways:

  1. As a wrapper for futures when persist is called
  2. As a wrapper for legacy dataframes

The previous implementation accepted the _name as an input argument. For the persisted dataframe, this was the name of the original expression; for the wrapped one, it was the name of the legacy dataframe.

The problem is that legacy dataframes do not have the same strict uniqueness guarantees as expressions do, so setting duplicate names is very easy. In fact, our tests were doing just that (and still are). This caused Expr.__new__ to deduplicate and effectively ignore the second dataframe... oops.

For persist it is also a little odd: if the FromGraph expression inherits the exact name of its ancestor, there now exist two expressions of different types with the same name.

Effectively, with this chosen singleton implementation, setting the _name explicitly instead of calculating it from a hash is a cardinal sin and will cause all sorts of weird things to happen.

Now, this can be fixed, but that has a rather big caveat: if I have to redefine the name of the expression, I also have to rewrite the graph. Many places in dask-expr assume (imo incorrectly) that output key names of a dataframe layer/expression are universally built as (df._name, i) and hard-code this when implementing their own layer (instead of iterating over i, iterating over df.__dask_keys__() would maintain the abstraction). This rewrite effectively adds another layer of keys. In practice this is rather ugly, since computing something on top of a persisted dataframe will always go through this dummy key in between.

Alternatively, I could make the singleton deduplication type-aware to give FromGraph an excuse to overwrite the name. However, if we truly stick with singletons that are based on the name, I would prefer the name to actually be unique, which requires all implementations to stop hard-coding keys of another expression/dataframe and to iterate properly over __dask_keys__ (see the sketch after this thread).

phofl (Collaborator), Feb 2, 2024, replying to the point about hard-coded (df._name, i) keys:
This is a good point, we should fix this instead of relying on df._name and i

Member Author:
I will investigate whether it is possible to just use __dask_keys__ everywhere, but I'd prefer doing this in a follow-up.

Collaborator:
Yes, totally agree, this should definitely be a follow-up (it also doesn't have to be you, I could pick this up as well).

Member Author:
I got started on this, see c3c01ed. That commit replaces all occurrences that match \((\w+(_\w+)?)(?<!self)\._name, but something appears to still be missing. It is possible and not as ugly as I thought it would be.
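To make the hard-coded-keys discussion above concrete, here is a hypothetical sketch (invented FakeFrame and task, not dask-expr code) of the two ways a layer can reference its dependency's outputs; only the second keeps working when the dependency, such as FromGraph, lays out its keys differently:

```python
def increment(part):
    return part + 1


def layer_hardcoded(out_name, df):
    # Assumes the dependency's output keys are exactly (df._name, i),
    # which breaks as soon as an expression names its keys differently.
    return {(out_name, i): (increment, (df._name, i)) for i in range(df.npartitions)}


def layer_via_dask_keys(out_name, df):
    # Iterating over __dask_keys__() treats the dependency's key layout as opaque.
    return {
        (out_name, i): (increment, key)
        for i, key in enumerate(df.__dask_keys__())
    }


class FakeFrame:
    """Stand-in for a two-partition collection following the (name, i) convention."""

    _name = "add-abc123"
    npartitions = 2

    def __dask_keys__(self):
        return [(self._name, i) for i in range(self.npartitions)]


df = FakeFrame()
# The two layers only coincide while the dependency follows the convention.
assert layer_hardcoded("inc-xyz", df) == layer_via_dask_keys("inc-xyz", df)
```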


@property
def _meta(self):
@@ -45,12 +45,19 @@ def _meta(self):
def _divisions(self):
return self.operand("divisions")

@property
@functools.cached_property
def _name(self):
return self.operand("_name")
return (
self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands)
)

def _layer(self):
return dict(self.operand("layer"))
dsk = dict(self.operand("layer"))
# The name may not actually match the layer's name, therefore rewrite this
# using an alias
for part, k in enumerate(self.operand("keys")):
dsk[(self._name, part)] = k
return dsk
Member Author, commenting on lines 54 to +60:
The intuitive fix for me would have been to override __dask_keys__ instead, so it points to the appropriate keys, but as I explained above, the key layout of Expr.__dask_keys__ is hard-coded in many places.
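As a hypothetical illustration (key names invented), the extra aliases that the rewritten _layer adds for a persisted two-partition frame look roughly like this; in a dask graph, a task whose value is another key acts as an alias:

```python
# Keys under which the persisted partitions (futures) live.
persisted_keys = [("add-abc123", 0), ("add-abc123", 1)]

# The wrapped layer maps those keys to already-computed data / futures.
layer = {key: f"<partition {i}>" for i, key in enumerate(persisted_keys)}

# One alias per partition lets downstream code address (new_name, i) even
# though the data is stored under the original keys.
new_name = "fromgraph-0123abcd"
for i, key in enumerate(persisted_keys):
    layer[(new_name, i)] = key

print(layer[(new_name, 0)])  # ('add-abc123', 0), resolved to the persisted data by the scheduler
```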



class BlockwiseIO(Blockwise, IO):
79 changes: 50 additions & 29 deletions dask_expr/io/parquet.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import contextlib
import functools
import itertools
import operator
import warnings
@@ -26,8 +25,9 @@
from dask.dataframe.io.parquet.utils import _split_user_options
from dask.dataframe.io.utils import _is_local_fs
from dask.delayed import delayed
from dask.utils import apply, natural_sort_key, typename
from dask.utils import apply, funcname, natural_sort_key, typename
from fsspec.utils import stringify_path
from toolz import identity

from dask_expr._expr import (
EQ,
Expand All @@ -47,26 +47,15 @@
determine_column_projection,
)
from dask_expr._reductions import Len
from dask_expr._util import _convert_to_list
from dask_expr._util import _convert_to_list, _tokenize_deterministic
from dask_expr.io import BlockwiseIO, PartitionsFiltered

NONE_LABEL = "__null_dask_index__"

_cached_dataset_info = {}
_CACHED_DATASET_SIZE = 10
_CACHED_PLAN_SIZE = 10
_cached_plan = {}


def _control_cached_dataset_info(key):
if (
len(_cached_dataset_info) > _CACHED_DATASET_SIZE
and key not in _cached_dataset_info
):
key_to_pop = list(_cached_dataset_info.keys())[0]
_cached_dataset_info.pop(key_to_pop)


def _control_cached_plan(key):
if len(_cached_plan) > _CACHED_PLAN_SIZE and key not in _cached_plan:
key_to_pop = list(_cached_plan.keys())[0]
@@ -121,7 +110,7 @@ def _lower(self):
class ToParquetData(Blockwise):
_parameters = ToParquet._parameters

@cached_property
@property
def io_func(self):
return ToParquetFunctionWrapper(
self.engine,
@@ -257,7 +246,6 @@ def to_parquet(

# Clear read_parquet caches in case we are
# also reading from the overwritten path
_cached_dataset_info.clear()
_cached_plan.clear()

# Always skip divisions checks if divisions are unknown
@@ -413,6 +401,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"kwargs",
"_partitions",
"_series",
"_dataset_info_cache",
]
_defaults = {
"columns": None,
@@ -432,6 +421,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"kwargs": None,
"_partitions": None,
"_series": False,
"_dataset_info_cache": None,
}
_pq_length_stats = None
_absorb_projections = True
@@ -474,7 +464,21 @@ def _simplify_up(self, parent, dependents):
return Literal(sum(_lengths))

@cached_property
def _name(self):
return (
funcname(type(self)).lower()
+ "-"
+ _tokenize_deterministic(self.checksum, *self.operands)
)
Member Author, commenting on lines +467 to +472:
This checksum is part of the _name, allowing us to differentiate expressions that point to modified states of the dataset. It also allows us to reuse already-cached plans / divisions if the dataset did not change, which is the most common case.
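A small illustration of the consequence, using dask's tokenize (the path, columns, and checksum values are invented): the deterministic name only stays stable while the checksum does, so a changed dataset produces a new name and the stale cached plan is simply never looked up again.

```python
from dask.base import tokenize

path, columns = "s3://bucket/table/", ["x", "y"]

name_v1 = "readparquet-" + tokenize("etag-v1", path, columns)
name_v1_again = "readparquet-" + tokenize("etag-v1", path, columns)
name_v2 = "readparquet-" + tokenize("etag-v2", path, columns)  # dataset was overwritten

assert name_v1 == name_v1_again  # unchanged dataset: cached plan / divisions can be reused
assert name_v1 != name_v2        # changed dataset: a fresh plan is computed under the new name
```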


@property
def checksum(self):
return self._dataset_info["checksum"]

@property
def _dataset_info(self):
if rv := self.operand("_dataset_info_cache"):
return rv
# Process and split user options
(
dataset_options,
@@ -536,13 +540,25 @@ def _dataset_info(self):
**other_options,
},
)
dataset_token = tokenize(*args)
if dataset_token not in _cached_dataset_info:
_control_cached_dataset_info(dataset_token)
_cached_dataset_info[dataset_token] = self.engine._collect_dataset_info(
*args
)
dataset_info = _cached_dataset_info[dataset_token].copy()
dataset_info = self.engine._collect_dataset_info(*args)
checksum = []
files_for_checksum = []
if dataset_info["has_metadata_file"]:
if isinstance(self.path, list):
files_for_checksum = [
next(path for path in self.path if path.endswith("_metadata"))
]
else:
files_for_checksum = [self.path + fs.sep + "_metadata"]
else:
files_for_checksum = dataset_info["ds"].files

for file in files_for_checksum:
# The checksum / file info is usually already cached by the fsspec
# FileSystem dir_cache since this info was already asked for in
# _collect_dataset_info
checksum.append(fs.checksum(file))
dataset_info["checksum"] = tokenize(checksum)
Member Author, commenting on lines +556 to +561:
To deal with the cache consistency problem described in #800, I am calculating a checksum here. For S3 this falls back to using the ETag provided in the listdir response. This should not add any overhead since this information is already cached by fsspec.
We take either the checksum of the metadata file or of all files that we iterate over. At this point the listdir operation has already happened, so the checksum identifies every dataset uniquely. Since the checksum is added to dataset_info, this also guarantees that the cached plan is invalidated if the dataset changes.
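For illustration, a minimal standalone sketch of that idea (dataset_checksum is a hypothetical helper, not the dask-expr implementation): fingerprint either the _metadata file or every data file via fsspec's checksum, which on object stores is typically derived from listing information such as ETags that is already cached.

```python
from dask.base import tokenize
from fsspec.core import url_to_fs


def dataset_checksum(path: str) -> str:
    """Hypothetical helper: token over per-file checksums of a parquet dataset."""
    fs, root = url_to_fs(path)
    files = fs.find(root)
    metadata = [f for f in files if f.endswith("_metadata")]
    # Prefer the _metadata file if present; otherwise checksum every data file.
    files_for_checksum = metadata or files
    return tokenize([fs.checksum(f) for f in files_for_checksum])
```

A rewritten or added file changes its checksum, so the resulting token, and with it the expression name discussed above, changes as well.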


# Infer meta, accounting for index and columns arguments.
meta = self.engine._create_dd_meta(dataset_info)
@@ -558,6 +574,9 @@
dataset_info["all_columns"] = all_columns
dataset_info["calculate_divisions"] = self.calculate_divisions

self.operands[
type(self)._parameters.index("_dataset_info_cache")
] = dataset_info
return dataset_info

@property
@@ -571,10 +590,10 @@ def _meta(self):
return meta[columns]
return meta

@cached_property
@property
def _io_func(self):
if self._plan["empty"]:
return lambda x: x
return identity
dataset_info = self._dataset_info
return ParquetFunctionWrapper(
self.engine,
@@ -662,7 +681,7 @@ def _update_length_statistics(self):
stat["num-rows"] for stat in _collect_pq_statistics(self)
)

@functools.cached_property
@property
def _fusion_compression_factor(self):
if self.operand("columns") is None:
return 1
@@ -767,9 +786,11 @@ def _maybe_list(val):
return [val]

return [
_maybe_list(val.to_list_tuple())
if hasattr(val, "to_list_tuple")
else _maybe_list(val)
(
_maybe_list(val.to_list_tuple())
if hasattr(val, "to_list_tuple")
else _maybe_list(val)
)
for val in self
]

8 changes: 5 additions & 3 deletions dask_expr/tests/test_collection.py
@@ -11,7 +11,7 @@
import dask.array as da
import numpy as np
import pytest
from dask.dataframe._compat import PANDAS_GE_210
from dask.dataframe._compat import PANDAS_GE_210, PANDAS_GE_220
from dask.dataframe.utils import UNKNOWN_CATEGORIES
from dask.utils import M

@@ -983,14 +983,15 @@ def test_broadcast(pdf, df):

def test_persist(pdf, df):
a = df + 2
a *= 2
b = a.persist()

assert_eq(a, b)
assert len(a.__dask_graph__()) > len(b.__dask_graph__())

assert len(b.__dask_graph__()) == b.npartitions
assert len(b.__dask_graph__()) == 2 * b.npartitions

assert_eq(b.y.sum(), (pdf + 2).y.sum())
assert_eq(b.y.sum(), ((pdf + 2) * 2).y.sum())


def test_index(pdf, df):
@@ -1035,6 +1036,7 @@ def test_head_down(df):
assert not isinstance(optimized.expr, expr.Head)


@pytest.mark.skipif(not PANDAS_GE_220, reason="not implemented")
def test_case_when(pdf, df):
result = df.x.case_when([(df.x.eq(1), 1), (df.y == 10, 2.5)])
expected = pdf.x.case_when([(pdf.x.eq(1), 1), (pdf.y == 10, 2.5)])
2 changes: 1 addition & 1 deletion dask_expr/tests/test_datasets.py
@@ -53,7 +53,7 @@ def test_persist():
b = a.persist()

assert_eq(a, b)
assert len(b.dask) == b.npartitions
assert len(b.dask) == 2 * b.npartitions


def test_lengths():