diff --git a/python/dask_cudf/dask_cudf/backends.py b/python/dask_cudf/dask_cudf/backends.py index a65ae819b44..16b2c8959e2 100644 --- a/python/dask_cudf/dask_cudf/backends.py +++ b/python/dask_cudf/dask_cudf/backends.py @@ -537,6 +537,12 @@ def to_cudf_dispatch_from_pandas(data, nan_as_null=None, **kwargs): return cudf.from_pandas(data, nan_as_null=nan_as_null) +@to_cudf_dispatch.register((cudf.DataFrame, cudf.Series, cudf.Index)) +def to_cudf_dispatch_from_cudf(data, **kwargs): + _unsupported_kwargs("cudf", "cudf", kwargs) + return data + + # Define "cudf" backend engine to be registered with Dask class CudfBackendEntrypoint(DataFrameBackendEntrypoint): """Backend-entrypoint class for Dask-DataFrame @@ -643,20 +649,20 @@ class CudfDXBackendEntrypoint(DataFrameBackendEntrypoint): Examples -------- >>> import dask - >>> import dask_expr + >>> import dask_expr as dx >>> with dask.config.set({"dataframe.backend": "cudf"}): ... ddf = dx.from_dict({"a": range(10)}) >>> type(ddf._meta) """ - @classmethod - def to_backend_dispatch(cls): - return CudfBackendEntrypoint.to_backend_dispatch() + @staticmethod + def to_backend(data, **kwargs): + import dask_expr as dx - @classmethod - def to_backend(cls, *args, **kwargs): - return CudfBackendEntrypoint.to_backend(*args, **kwargs) + from dask_cudf.expr._expr import ToCudfBackend + + return dx.new_collection(ToCudfBackend(data, kwargs)) @staticmethod def from_dict( diff --git a/python/dask_cudf/dask_cudf/expr/_expr.py b/python/dask_cudf/dask_cudf/expr/_expr.py index 8fccaccb695..8a2c50d3fe7 100644 --- a/python/dask_cudf/dask_cudf/expr/_expr.py +++ b/python/dask_cudf/dask_cudf/expr/_expr.py @@ -4,12 +4,41 @@ import dask_expr._shuffle as _shuffle_module from dask_expr import new_collection from dask_expr._cumulative import CumulativeBlockwise -from dask_expr._expr import Expr, VarColumns +from dask_expr._expr import Elemwise, Expr, VarColumns from dask_expr._reductions import Reduction, Var from dask.dataframe.core import is_dataframe_like, make_meta, meta_nonempty from dask.dataframe.dispatch import is_categorical_dtype +import cudf + +## +## Custom expressions +## + + +class ToCudfBackend(Elemwise): + # TODO: Inherit from ToBackend when rapids-dask-dependency + # is pinned to dask>=2024.8.1 + _parameters = ["frame", "options"] + _projection_passthrough = True + _filter_passthrough = True + _preserves_partitioning_information = True + + @staticmethod + def operation(df, options): + from dask_cudf.backends import to_cudf_dispatch + + return to_cudf_dispatch(df, **options) + + def _simplify_down(self): + if isinstance( + self.frame._meta, (cudf.DataFrame, cudf.Series, cudf.Index) + ): + # We already have cudf data + return self.frame + + ## ## Custom expression patching ## diff --git a/python/dask_cudf/dask_cudf/tests/test_core.py b/python/dask_cudf/dask_cudf/tests/test_core.py index 174923c2c7e..905d8c08135 100644 --- a/python/dask_cudf/dask_cudf/tests/test_core.py +++ b/python/dask_cudf/dask_cudf/tests/test_core.py @@ -15,7 +15,11 @@ import cudf import dask_cudf -from dask_cudf.tests.utils import skip_dask_expr, xfail_dask_expr +from dask_cudf.tests.utils import ( + require_dask_expr, + skip_dask_expr, + xfail_dask_expr, +) def test_from_dict_backend_dispatch(): @@ -993,3 +997,13 @@ def test_series_isin_error(): ser.isin([1, 5, "a"]) with pytest.raises(TypeError): ddf.isin([1, 5, "a"]).compute() + + +@require_dask_expr() +def test_to_backend_simplify(): + # Check that column projection is not blocked by to_backend + with dask.config.set({"dataframe.backend": "pandas"}): + df = dd.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}, npartitions=2) + df2 = df.to_backend("cudf")[["y"]].simplify() + df3 = df[["y"]].to_backend("cudf").to_backend("cudf").simplify() + assert df2._name == df3._name diff --git a/python/dask_cudf/dask_cudf/tests/utils.py b/python/dask_cudf/dask_cudf/tests/utils.py index c7dedbb6b4a..cc0c6899804 100644 --- a/python/dask_cudf/dask_cudf/tests/utils.py +++ b/python/dask_cudf/dask_cudf/tests/utils.py @@ -48,3 +48,7 @@ def xfail_dask_expr(reason=_default_reason, lt_version=None): else: xfail = QUERY_PLANNING_ON return pytest.mark.xfail(xfail, reason=reason) + + +def require_dask_expr(reason="requires dask-expr"): + return pytest.mark.skipif(not QUERY_PLANNING_ON, reason=reason)