diff --git a/python/dask_cudf/dask_cudf/__init__.py b/python/dask_cudf/dask_cudf/__init__.py index f9df22cc436..cc17e71039a 100644 --- a/python/dask_cudf/dask_cudf/__init__.py +++ b/python/dask_cudf/dask_cudf/__init__.py @@ -1,21 +1,19 @@ # Copyright (c) 2018-2024, NVIDIA CORPORATION. -from dask import config - -# For dask>2024.2.0, we can silence the loud deprecation -# warning before importing `dask.dataframe` (this won't -# do anything for dask==2024.2.0) -config.set({"dataframe.query-planning-warning": False}) +import warnings +from importlib import import_module -import dask.dataframe as dd # noqa: E402 +from dask import config +import dask.dataframe as dd from dask.dataframe import from_delayed # noqa: E402 import cudf # noqa: E402 from . import backends # noqa: E402, F401 from ._version import __git_commit__, __version__ # noqa: E402, F401 -from .core import concat, from_cudf, from_dask_dataframe # noqa: E402 -from .expr import QUERY_PLANNING_ON # noqa: E402 +from .core import concat, from_cudf, DataFrame, Index, Series # noqa: F401 + +QUERY_PLANNING_ON = dd.DASK_EXPR_ENABLED def read_csv(*args, **kwargs): @@ -38,26 +36,44 @@ def read_parquet(*args, **kwargs): return dd.read_parquet(*args, **kwargs) -def raise_not_implemented_error(attr_name): +def _deprecated_api(old_api, new_api=None, rec=None): def inner_func(*args, **kwargs): + if new_api: + # Use alternative + msg = f"{old_api} is now deprecated. " + msg += rec or f"Please use {new_api} instead." + warnings.warn(msg, FutureWarning) + new_attr = new_api.split(".") + module = import_module(".".join(new_attr[:-1])) + return getattr(module, new_attr[-1])(*args, **kwargs) + + # No alternative - raise an error raise NotImplementedError( - f"Top-level {attr_name} API is not available for dask-expr." + f"{old_api} is no longer supported. " + (rec or "") ) return inner_func if QUERY_PLANNING_ON: - from .expr._collection import DataFrame, Index, Series + from ._expr.expr import _patch_dask_expr + from . import io # noqa: F401 - groupby_agg = raise_not_implemented_error("groupby_agg") + groupby_agg = _deprecated_api("dask_cudf.groupby_agg") read_text = DataFrame.read_text - to_orc = raise_not_implemented_error("to_orc") + _patch_dask_expr() else: - from .core import DataFrame, Index, Series # noqa: F401 - from .groupby import groupby_agg # noqa: F401 - from .io import read_text, to_orc # noqa: F401 + from ._legacy.groupby import groupby_agg # noqa: F401 + from ._legacy.io import read_text # noqa: F401 + from . import io # noqa: F401 + + +to_orc = _deprecated_api( + "dask_cudf.to_orc", + new_api="dask_cudf._legacy.io.to_orc", + rec="Please use DataFrame.to_orc instead.", +) __all__ = [ @@ -65,7 +81,6 @@ def inner_func(*args, **kwargs): "Series", "Index", "from_cudf", - "from_dask_dataframe", "concat", "from_delayed", ] diff --git a/python/dask_cudf/dask_cudf/_expr/__init__.py b/python/dask_cudf/dask_cudf/_expr/__init__.py new file mode 100644 index 00000000000..3c827d4ff59 --- /dev/null +++ b/python/dask_cudf/dask_cudf/_expr/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. diff --git a/python/dask_cudf/dask_cudf/accessors.py b/python/dask_cudf/dask_cudf/_expr/accessors.py similarity index 100% rename from python/dask_cudf/dask_cudf/accessors.py rename to python/dask_cudf/dask_cudf/_expr/accessors.py diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/_expr/collection.py similarity index 88% rename from python/dask_cudf/dask_cudf/expr/_collection.py rename to python/dask_cudf/dask_cudf/_expr/collection.py index 907abaa2bfc..fdf7d8630e9 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/_expr/collection.py @@ -34,22 +34,6 @@ class CudfFrameBase(FrameBase): - def to_dask_dataframe(self, **kwargs): - """Create a dask.dataframe object from a dask_cudf object - - WARNING: This API is deprecated, and may not work properly. - Please use `*.to_backend("pandas")` to convert the - underlying data to pandas. - """ - - warnings.warn( - "The `to_dask_dataframe` API is now deprecated. " - "Please use `*.to_backend('pandas')` instead.", - FutureWarning, - ) - - return self.to_backend("pandas", **kwargs) - def _prepare_cov_corr(self, min_periods, numeric_only): # Upstream version of this method sets min_periods # to 2 by default (which is not supported by cudf) @@ -94,7 +78,7 @@ def var( def rename_axis( self, mapper=no_default, index=no_default, columns=no_default, axis=0 ): - from dask_cudf.expr._expr import RenameAxisCudf + from dask_cudf._expr.expr import RenameAxisCudf return new_collection( RenameAxisCudf( @@ -136,7 +120,7 @@ def groupby( dropna=None, **kwargs, ): - from dask_cudf.expr._groupby import GroupBy + from dask_cudf._expr.groupby import GroupBy if isinstance(by, FrameBase) and not isinstance(by, DXSeries): raise ValueError( @@ -169,13 +153,16 @@ def groupby( ) def to_orc(self, *args, **kwargs): - return self.to_legacy_dataframe().to_orc(*args, **kwargs) + from dask_cudf._legacy.io import to_orc + + return to_orc(self, *args, **kwargs) + # return self.to_legacy_dataframe().to_orc(*args, **kwargs) @staticmethod def read_text(*args, **kwargs): from dask_expr import from_legacy_dataframe - from dask_cudf.io.text import read_text as legacy_read_text + from dask_cudf._legacy.io.text import read_text as legacy_read_text ddf = legacy_read_text(*args, **kwargs) return from_legacy_dataframe(ddf) @@ -183,19 +170,19 @@ def read_text(*args, **kwargs): class Series(DXSeries, CudfFrameBase): def groupby(self, by, **kwargs): - from dask_cudf.expr._groupby import SeriesGroupBy + from dask_cudf._expr.groupby import SeriesGroupBy return SeriesGroupBy(self, by, **kwargs) @cached_property def list(self): - from dask_cudf.accessors import ListMethods + from dask_cudf._expr.accessors import ListMethods return ListMethods(self) @cached_property def struct(self): - from dask_cudf.accessors import StructMethods + from dask_cudf._expr.accessors import StructMethods return StructMethods(self) diff --git a/python/dask_cudf/dask_cudf/_expr/expr.py b/python/dask_cudf/dask_cudf/_expr/expr.py new file mode 100644 index 00000000000..8b91e53604c --- /dev/null +++ b/python/dask_cudf/dask_cudf/_expr/expr.py @@ -0,0 +1,210 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +import functools + +import dask_expr._shuffle as _shuffle_module +from dask_expr import new_collection +from dask_expr._cumulative import CumulativeBlockwise +from dask_expr._expr import Elemwise, Expr, RenameAxis, VarColumns +from dask_expr._reductions import Reduction, Var + +from dask.dataframe.core import ( + is_dataframe_like, + make_meta, + meta_nonempty, +) +from dask.dataframe.dispatch import is_categorical_dtype +from dask.typing import no_default + +import cudf + +## +## Custom expressions +## + + +class RenameAxisCudf(RenameAxis): + # TODO: Remove this after rename_axis is supported in cudf + # (See: https://github.com/rapidsai/cudf/issues/16895) + @staticmethod + def operation(df, index=no_default, **kwargs): + if index != no_default: + df.index.name = index + return df + raise NotImplementedError( + "Only `index` is supported for the cudf backend" + ) + + +class ToCudfBackend(Elemwise): + # TODO: Inherit from ToBackend when rapids-dask-dependency + # is pinned to dask>=2024.8.1 + _parameters = ["frame", "options"] + _projection_passthrough = True + _filter_passthrough = True + _preserves_partitioning_information = True + + @staticmethod + def operation(df, options): + from dask_cudf.backends import to_cudf_dispatch + + return to_cudf_dispatch(df, **options) + + def _simplify_down(self): + if isinstance( + self.frame._meta, (cudf.DataFrame, cudf.Series, cudf.Index) + ): + # We already have cudf data + return self.frame + + +## +## Custom expression patching +## + + +# This can be removed after cudf#15176 is addressed. +# See: https://github.com/rapidsai/cudf/issues/15176 +class PatchCumulativeBlockwise(CumulativeBlockwise): + @property + def _args(self) -> list: + return self.operands[:1] + + @property + def _kwargs(self) -> dict: + # Must pass axis and skipna as kwargs in cudf + return {"axis": self.axis, "skipna": self.skipna} + + +# The upstream Var code uses `Series.values`, and relies on numpy +# for most of the logic. Unfortunately, cudf -> cupy conversion +# is not supported for data containing null values. Therefore, +# we must implement our own version of Var for now. This logic +# is mostly copied from dask-cudf. + + +class VarCudf(Reduction): + # Uses the parallel version of Welford's online algorithm (Chan '79) + # (http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf) + _parameters = [ + "frame", + "skipna", + "ddof", + "numeric_only", + "split_every", + ] + _defaults = { + "skipna": True, + "ddof": 1, + "numeric_only": False, + "split_every": False, + } + + @functools.cached_property + def _meta(self): + return make_meta( + meta_nonempty(self.frame._meta).var( + skipna=self.skipna, numeric_only=self.numeric_only + ) + ) + + @property + def chunk_kwargs(self): + return dict(skipna=self.skipna, numeric_only=self.numeric_only) + + @property + def combine_kwargs(self): + return {} + + @property + def aggregate_kwargs(self): + return dict(ddof=self.ddof) + + @classmethod + def reduction_chunk(cls, x, skipna=True, numeric_only=False): + kwargs = {"numeric_only": numeric_only} if is_dataframe_like(x) else {} + if skipna or numeric_only: + n = x.count(**kwargs) + kwargs["skipna"] = skipna + avg = x.mean(**kwargs) + else: + # Not skipping nulls, so might as well + # avoid the full `count` operation + n = len(x) + kwargs["skipna"] = skipna + avg = x.sum(**kwargs) / n + if numeric_only: + # Workaround for cudf bug + # (see: https://github.com/rapidsai/cudf/issues/13731) + x = x[n.index] + m2 = ((x - avg) ** 2).sum(**kwargs) + return n, avg, m2 + + @classmethod + def reduction_combine(cls, parts): + n, avg, m2 = parts[0] + for i in range(1, len(parts)): + n_a, avg_a, m2_a = n, avg, m2 + n_b, avg_b, m2_b = parts[i] + n = n_a + n_b + avg = (n_a * avg_a + n_b * avg_b) / n + delta = avg_b - avg_a + m2 = m2_a + m2_b + delta**2 * n_a * n_b / n + return n, avg, m2 + + @classmethod + def reduction_aggregate(cls, vals, ddof=1): + vals = cls.reduction_combine(vals) + n, _, m2 = vals + return m2 / (n - ddof) + + +def _patched_var( + self, + axis=0, + skipna=True, + ddof=1, + numeric_only=False, + split_every=False, +): + if axis == 0: + if hasattr(self._meta, "to_pandas"): + return VarCudf(self, skipna, ddof, numeric_only, split_every) + else: + return Var(self, skipna, ddof, numeric_only, split_every) + elif axis == 1: + return VarColumns(self, skipna, ddof, numeric_only) + else: + raise ValueError(f"axis={axis} not supported. Please specify 0 or 1") + + +# Temporary work-around for missing cudf + categorical support +# See: https://github.com/rapidsai/cudf/issues/11795 +# TODO: Fix RepartitionQuantiles and remove this in cudf>24.06 + +_original_get_divisions = _shuffle_module._get_divisions + + +def _patched_get_divisions(frame, other, *args, **kwargs): + # NOTE: The following two lines contains the "patch" + # (we simply convert the partitioning column to pandas) + if is_categorical_dtype(other._meta.dtype) and hasattr( + other.frame._meta, "to_pandas" + ): + other = new_collection(other).to_backend("pandas")._expr + + # Call "original" function + return _original_get_divisions(frame, other, *args, **kwargs) + + +_PATCHED = False + + +def _patch_dask_expr(): + global _PATCHED + + if not _PATCHED: + CumulativeBlockwise._args = PatchCumulativeBlockwise._args + CumulativeBlockwise._kwargs = PatchCumulativeBlockwise._kwargs + Expr.var = _patched_var + _shuffle_module._get_divisions = _patched_get_divisions + _PATCHED = True diff --git a/python/dask_cudf/dask_cudf/_expr/groupby.py b/python/dask_cudf/dask_cudf/_expr/groupby.py new file mode 100644 index 00000000000..0242fac6e72 --- /dev/null +++ b/python/dask_cudf/dask_cudf/_expr/groupby.py @@ -0,0 +1,335 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +import functools + +import pandas as pd +from dask_expr._collection import new_collection +from dask_expr._groupby import ( + DecomposableGroupbyAggregation, + GroupBy as DXGroupBy, + GroupbyAggregation, + SeriesGroupBy as DXSeriesGroupBy, + SingleAggregation, +) +from dask_expr._util import is_scalar + +from dask.dataframe.core import _concat +from dask.dataframe.groupby import Aggregation + +from cudf.core.groupby.groupby import _deprecate_collect + +## +## Fused groupby aggregations +## + + +def _get_spec_info(gb): + if isinstance(gb.arg, (dict, list)): + aggs = gb.arg.copy() + else: + aggs = gb.arg + + if gb._slice and not isinstance(aggs, dict): + aggs = {gb._slice: aggs} + + gb_cols = gb._by_columns + if isinstance(gb_cols, str): + gb_cols = [gb_cols] + columns = [c for c in gb.frame.columns if c not in gb_cols] + if not isinstance(aggs, dict): + aggs = {col: aggs for col in columns} + + # Assert if our output will have a MultiIndex; this will be the case if + # any value in the `aggs` dict is not a string (i.e. multiple/named + # aggregations per column) + str_cols_out = True + aggs_renames = {} + for col in aggs: + if isinstance(aggs[col], str) or callable(aggs[col]): + aggs[col] = [aggs[col]] + elif isinstance(aggs[col], dict): + str_cols_out = False + col_aggs = [] + for k, v in aggs[col].items(): + aggs_renames[col, v] = k + col_aggs.append(v) + aggs[col] = col_aggs + else: + str_cols_out = False + if col in gb_cols: + columns.append(col) + + return { + "aggs": aggs, + "columns": columns, + "str_cols_out": str_cols_out, + "aggs_renames": aggs_renames, + } + + +def _get_meta(gb): + spec_info = gb.spec_info + gb_cols = gb._by_columns + aggs = spec_info["aggs"].copy() + aggs_renames = spec_info["aggs_renames"] + if spec_info["str_cols_out"]: + # Metadata should use `str` for dict values if that is + # what the user originally specified (column names will + # be str, rather than tuples). + for col in aggs: + aggs[col] = aggs[col][0] + _meta = gb.frame._meta.groupby(gb_cols).agg(aggs) + if aggs_renames: + col_array = [] + agg_array = [] + for col, agg in _meta.columns: + col_array.append(col) + agg_array.append(aggs_renames.get((col, agg), agg)) + _meta.columns = pd.MultiIndex.from_arrays([col_array, agg_array]) + return _meta + + +class DecomposableCudfGroupbyAgg(DecomposableGroupbyAggregation): + sep = "___" + + @functools.cached_property + def spec_info(self): + return _get_spec_info(self) + + @functools.cached_property + def _meta(self): + return _get_meta(self) + + @property + def shuffle_by_index(self): + return False # We always group by column(s) + + @classmethod + def chunk(cls, df, *by, **kwargs): + from dask_cudf._legacy.groupby import _groupby_partition_agg + + return _groupby_partition_agg(df, **kwargs) + + @classmethod + def combine(cls, inputs, **kwargs): + from dask_cudf._legacy.groupby import _tree_node_agg + + return _tree_node_agg(_concat(inputs), **kwargs) + + @classmethod + def aggregate(cls, inputs, **kwargs): + from dask_cudf._legacy.groupby import _finalize_gb_agg + + return _finalize_gb_agg(_concat(inputs), **kwargs) + + @property + def chunk_kwargs(self) -> dict: + dropna = True if self.dropna is None else self.dropna + return { + "gb_cols": self._by_columns, + "aggs": self.spec_info["aggs"], + "columns": self.spec_info["columns"], + "dropna": dropna, + "sort": self.sort, + "sep": self.sep, + } + + @property + def combine_kwargs(self) -> dict: + dropna = True if self.dropna is None else self.dropna + return { + "gb_cols": self._by_columns, + "dropna": dropna, + "sort": self.sort, + "sep": self.sep, + } + + @property + def aggregate_kwargs(self) -> dict: + dropna = True if self.dropna is None else self.dropna + final_columns = self._slice or self._meta.columns + return { + "gb_cols": self._by_columns, + "aggs": self.spec_info["aggs"], + "columns": self.spec_info["columns"], + "final_columns": final_columns, + "as_index": True, + "dropna": dropna, + "sort": self.sort, + "sep": self.sep, + "str_cols_out": self.spec_info["str_cols_out"], + "aggs_renames": self.spec_info["aggs_renames"], + } + + +class CudfGroupbyAgg(GroupbyAggregation): + @functools.cached_property + def spec_info(self): + return _get_spec_info(self) + + @functools.cached_property + def _meta(self): + return _get_meta(self) + + def _lower(self): + return DecomposableCudfGroupbyAgg( + self.frame, + self.arg, + self.observed, + self.dropna, + self.split_every, + self.split_out, + self.sort, + self.shuffle_method, + self._slice, + *self.by, + ) + + +def _maybe_get_custom_expr( + gb, + aggs, + split_every=None, + split_out=None, + shuffle_method=None, + **kwargs, +): + from dask_cudf._legacy.groupby import ( + OPTIMIZED_AGGS, + _aggs_optimized, + _redirect_aggs, + ) + + if kwargs: + # Unsupported key-word arguments + return None + + if not hasattr(gb.obj._meta, "to_pandas"): + # Not cuDF-backed data + return None + + _aggs = _redirect_aggs(aggs) + if not _aggs_optimized(_aggs, OPTIMIZED_AGGS): + # One or more aggregations are unsupported + return None + + return CudfGroupbyAgg( + gb.obj.expr, + _aggs, + gb.observed, + gb.dropna, + split_every, + split_out, + gb.sort, + shuffle_method, + gb._slice, + *gb.by, + ) + + +## +## Custom groupby classes +## + + +class ListAgg(SingleAggregation): + @staticmethod + def groupby_chunk(arg): + return arg.agg(list) + + @staticmethod + def groupby_aggregate(arg): + gb = arg.agg(list) + if gb.ndim > 1: + for col in gb.columns: + gb[col] = gb[col].list.concat() + return gb + else: + return gb.list.concat() + + +list_aggregation = Aggregation( + name="list", + chunk=ListAgg.groupby_chunk, + agg=ListAgg.groupby_aggregate, +) + + +def _translate_arg(arg): + # Helper function to translate args so that + # they can be processed correctly by upstream + # dask & dask-expr. Right now, the only necessary + # translation is list aggregations. + if isinstance(arg, dict): + return {k: _translate_arg(v) for k, v in arg.items()} + elif isinstance(arg, list): + return [_translate_arg(x) for x in arg] + elif arg in ("collect", "list", list): + return list_aggregation + else: + return arg + + +# We define our own GroupBy classes in Dask cuDF for +# the following reasons: +# (1) We want to use a custom `aggregate` algorithm +# that performs multiple aggregations on the +# same dataframe partition at once. The upstream +# algorithm breaks distinct aggregations into +# separate tasks. +# (2) We need to work around missing `observed=False` +# support: +# https://github.com/rapidsai/cudf/issues/15173 + + +class GroupBy(DXGroupBy): + def __init__(self, *args, observed=None, **kwargs): + observed = observed if observed is not None else True + super().__init__(*args, observed=observed, **kwargs) + + def __getitem__(self, key): + if is_scalar(key): + return SeriesGroupBy( + self.obj, + by=self.by, + slice=key, + sort=self.sort, + dropna=self.dropna, + observed=self.observed, + ) + g = GroupBy( + self.obj, + by=self.by, + slice=key, + sort=self.sort, + dropna=self.dropna, + observed=self.observed, + group_keys=self.group_keys, + ) + return g + + def collect(self, **kwargs): + _deprecate_collect() + return self._single_agg(ListAgg, **kwargs) + + def aggregate(self, arg, fused=True, **kwargs): + if ( + fused + and (expr := _maybe_get_custom_expr(self, arg, **kwargs)) + is not None + ): + return new_collection(expr) + else: + return super().aggregate(_translate_arg(arg), **kwargs) + + +class SeriesGroupBy(DXSeriesGroupBy): + def __init__(self, *args, observed=None, **kwargs): + observed = observed if observed is not None else True + super().__init__(*args, observed=observed, **kwargs) + + def collect(self, **kwargs): + _deprecate_collect() + return self._single_agg(ListAgg, **kwargs) + + def aggregate(self, arg, **kwargs): + return super().aggregate(_translate_arg(arg), **kwargs) diff --git a/python/dask_cudf/dask_cudf/_legacy/__init__.py b/python/dask_cudf/dask_cudf/_legacy/__init__.py new file mode 100644 index 00000000000..3c827d4ff59 --- /dev/null +++ b/python/dask_cudf/dask_cudf/_legacy/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. diff --git a/python/dask_cudf/dask_cudf/_legacy/core.py b/python/dask_cudf/dask_cudf/_legacy/core.py new file mode 100644 index 00000000000..d6beb775a5e --- /dev/null +++ b/python/dask_cudf/dask_cudf/_legacy/core.py @@ -0,0 +1,711 @@ +# Copyright (c) 2018-2024, NVIDIA CORPORATION. + +import math +import warnings + +import numpy as np +import pandas as pd +from tlz import partition_all + +from dask import dataframe as dd +from dask.base import normalize_token, tokenize +from dask.dataframe.core import ( + Scalar, + handle_out, + make_meta as dask_make_meta, + map_partitions, +) +from dask.dataframe.utils import raise_on_meta_error +from dask.highlevelgraph import HighLevelGraph +from dask.utils import M, OperatorMethodMixin, apply, derived_from, funcname + +import cudf +from cudf import _lib as libcudf +from cudf.utils.performance_tracking import _dask_cudf_performance_tracking + +from dask_cudf._expr.accessors import ListMethods, StructMethods +from dask_cudf._legacy import sorting +from dask_cudf._legacy.sorting import ( + _deprecate_shuffle_kwarg, + _get_shuffle_method, +) + + +class _Frame(dd.core._Frame, OperatorMethodMixin): + """Superclass for DataFrame and Series + + Parameters + ---------- + dsk : dict + The dask graph to compute this DataFrame + name : str + The key prefix that specifies which keys in the dask comprise this + particular DataFrame / Series + meta : cudf.DataFrame, cudf.Series, or cudf.Index + An empty cudf object with names, dtypes, and indices matching the + expected output. + divisions : tuple of index values + Values along which we partition our blocks on the index + """ + + def _is_partition_type(self, meta): + return isinstance(meta, self._partition_type) + + def __repr__(self): + s = "" + return s % (type(self).__name__, len(self.dask), self.npartitions) + + +normalize_token.register(_Frame, lambda a: a._name) + + +class DataFrame(_Frame, dd.core.DataFrame): + """ + A distributed Dask DataFrame where the backing dataframe is a + :class:`cuDF DataFrame `. + + Typically you would not construct this object directly, but rather + use one of Dask-cuDF's IO routines. + + Most operations on :doc:`Dask DataFrames ` are + supported, with many of the same caveats. + + """ + + _partition_type = cudf.DataFrame + + @_dask_cudf_performance_tracking + def _assign_column(self, k, v): + def assigner(df, k, v): + out = df.copy() + out[k] = v + return out + + meta = assigner(self._meta, k, dask_make_meta(v)) + return self.map_partitions(assigner, k, v, meta=meta) + + @_dask_cudf_performance_tracking + def apply_rows(self, func, incols, outcols, kwargs=None, cache_key=None): + import uuid + + if kwargs is None: + kwargs = {} + + if cache_key is None: + cache_key = uuid.uuid4() + + def do_apply_rows(df, func, incols, outcols, kwargs): + return df.apply_rows( + func, incols, outcols, kwargs, cache_key=cache_key + ) + + meta = do_apply_rows(self._meta, func, incols, outcols, kwargs) + return self.map_partitions( + do_apply_rows, func, incols, outcols, kwargs, meta=meta + ) + + @_deprecate_shuffle_kwarg + @_dask_cudf_performance_tracking + def merge(self, other, shuffle_method=None, **kwargs): + on = kwargs.pop("on", None) + if isinstance(on, tuple): + on = list(on) + return super().merge( + other, + on=on, + shuffle_method=_get_shuffle_method(shuffle_method), + **kwargs, + ) + + @_deprecate_shuffle_kwarg + @_dask_cudf_performance_tracking + def join(self, other, shuffle_method=None, **kwargs): + # CuDF doesn't support "right" join yet + how = kwargs.pop("how", "left") + if how == "right": + return other.join(other=self, how="left", **kwargs) + + on = kwargs.pop("on", None) + if isinstance(on, tuple): + on = list(on) + return super().join( + other, + how=how, + on=on, + shuffle_method=_get_shuffle_method(shuffle_method), + **kwargs, + ) + + @_deprecate_shuffle_kwarg + @_dask_cudf_performance_tracking + def set_index( + self, + other, + sorted=False, + divisions=None, + shuffle_method=None, + **kwargs, + ): + pre_sorted = sorted + del sorted + + if divisions == "quantile": + warnings.warn( + "Using divisions='quantile' is now deprecated. " + "Please raise an issue on github if you believe " + "this feature is necessary.", + FutureWarning, + ) + + if ( + divisions == "quantile" + or isinstance(divisions, (cudf.DataFrame, cudf.Series)) + or ( + isinstance(other, str) + and cudf.api.types.is_string_dtype(self[other].dtype) + ) + ): + # Let upstream-dask handle "pre-sorted" case + if pre_sorted: + return dd.shuffle.set_sorted_index( + self, other, divisions=divisions, **kwargs + ) + + by = other + if not isinstance(other, list): + by = [by] + if len(by) > 1: + raise ValueError("Dask does not support MultiIndex (yet).") + if divisions == "quantile": + divisions = None + + # Use dask_cudf's sort_values + df = self.sort_values( + by, + max_branch=kwargs.get("max_branch", None), + divisions=divisions, + set_divisions=True, + ignore_index=True, + shuffle_method=shuffle_method, + ) + + # Ignore divisions if its a dataframe + if isinstance(divisions, cudf.DataFrame): + divisions = None + + # Set index and repartition + df2 = df.map_partitions( + sorting.set_index_post, + index_name=other, + drop=kwargs.get("drop", True), + column_dtype=df.columns.dtype, + ) + npartitions = kwargs.get("npartitions", self.npartitions) + partition_size = kwargs.get("partition_size", None) + if partition_size: + return df2.repartition(partition_size=partition_size) + if not divisions and df2.npartitions != npartitions: + return df2.repartition(npartitions=npartitions) + if divisions and df2.npartitions != len(divisions) - 1: + return df2.repartition(divisions=divisions) + return df2 + + return super().set_index( + other, + sorted=pre_sorted, + shuffle_method=_get_shuffle_method(shuffle_method), + divisions=divisions, + **kwargs, + ) + + @_deprecate_shuffle_kwarg + @_dask_cudf_performance_tracking + def sort_values( + self, + by, + ignore_index=False, + max_branch=None, + divisions=None, + set_divisions=False, + ascending=True, + na_position="last", + sort_function=None, + sort_function_kwargs=None, + shuffle_method=None, + **kwargs, + ): + if kwargs: + raise ValueError( + f"Unsupported input arguments passed : {list(kwargs.keys())}" + ) + + df = sorting.sort_values( + self, + by, + max_branch=max_branch, + divisions=divisions, + set_divisions=set_divisions, + ignore_index=ignore_index, + ascending=ascending, + na_position=na_position, + shuffle_method=shuffle_method, + sort_function=sort_function, + sort_function_kwargs=sort_function_kwargs, + ) + + if ignore_index: + return df.reset_index(drop=True) + return df + + @_dask_cudf_performance_tracking + def to_parquet(self, path, *args, **kwargs): + """Calls dask.dataframe.io.to_parquet with CudfEngine backend""" + from dask_cudf._legacy.io import to_parquet + + return to_parquet(self, path, *args, **kwargs) + + @_dask_cudf_performance_tracking + def to_orc(self, path, **kwargs): + """Calls dask_cudf._legacy.io.to_orc""" + from dask_cudf._legacy.io import to_orc + + return to_orc(self, path, **kwargs) + + @derived_from(pd.DataFrame) + @_dask_cudf_performance_tracking + def var( + self, + axis=None, + skipna=True, + ddof=1, + split_every=False, + dtype=None, + out=None, + naive=False, + numeric_only=False, + ): + axis = self._validate_axis(axis) + meta = self._meta_nonempty.var( + axis=axis, skipna=skipna, numeric_only=numeric_only + ) + if axis == 1: + result = map_partitions( + M.var, + self, + meta=meta, + token=self._token_prefix + "var", + axis=axis, + skipna=skipna, + ddof=ddof, + numeric_only=numeric_only, + ) + return handle_out(out, result) + elif naive: + return _naive_var(self, meta, skipna, ddof, split_every, out) + else: + return _parallel_var(self, meta, skipna, split_every, out) + + @_deprecate_shuffle_kwarg + @_dask_cudf_performance_tracking + def shuffle(self, *args, shuffle_method=None, **kwargs): + """Wraps dask.dataframe DataFrame.shuffle method""" + return super().shuffle( + *args, shuffle_method=_get_shuffle_method(shuffle_method), **kwargs + ) + + @_dask_cudf_performance_tracking + def groupby(self, by=None, **kwargs): + from .groupby import CudfDataFrameGroupBy + + return CudfDataFrameGroupBy(self, by=by, **kwargs) + + +@_dask_cudf_performance_tracking +def sum_of_squares(x): + x = x.astype("f8")._column + outcol = libcudf.reduce.reduce("sum_of_squares", x) + return cudf.Series._from_column(outcol) + + +@_dask_cudf_performance_tracking +def var_aggregate(x2, x, n, ddof): + try: + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + result = (x2 / n) - (x / n) ** 2 + if ddof != 0: + result = result * n / (n - ddof) + return result + except ZeroDivisionError: + return np.float64(np.nan) + + +@_dask_cudf_performance_tracking +def nlargest_agg(x, **kwargs): + return cudf.concat(x).nlargest(**kwargs) + + +@_dask_cudf_performance_tracking +def nsmallest_agg(x, **kwargs): + return cudf.concat(x).nsmallest(**kwargs) + + +class Series(_Frame, dd.core.Series): + _partition_type = cudf.Series + + @_dask_cudf_performance_tracking + def count(self, split_every=False): + return reduction( + [self], + chunk=M.count, + aggregate=np.sum, + split_every=split_every, + meta="i8", + ) + + @_dask_cudf_performance_tracking + def mean(self, split_every=False): + sum = self.sum(split_every=split_every) + n = self.count(split_every=split_every) + return sum / n + + @derived_from(pd.DataFrame) + @_dask_cudf_performance_tracking + def var( + self, + axis=None, + skipna=True, + ddof=1, + split_every=False, + dtype=None, + out=None, + naive=False, + ): + axis = self._validate_axis(axis) + meta = self._meta_nonempty.var(axis=axis, skipna=skipna) + if axis == 1: + result = map_partitions( + M.var, + self, + meta=meta, + token=self._token_prefix + "var", + axis=axis, + skipna=skipna, + ddof=ddof, + ) + return handle_out(out, result) + elif naive: + return _naive_var(self, meta, skipna, ddof, split_every, out) + else: + return _parallel_var(self, meta, skipna, split_every, out) + + @_dask_cudf_performance_tracking + def groupby(self, *args, **kwargs): + from .groupby import CudfSeriesGroupBy + + return CudfSeriesGroupBy(self, *args, **kwargs) + + @property # type: ignore + @_dask_cudf_performance_tracking + def list(self): + return ListMethods(self) + + @property # type: ignore + @_dask_cudf_performance_tracking + def struct(self): + return StructMethods(self) + + +class Index(Series, dd.core.Index): + _partition_type = cudf.Index # type: ignore + + +@_dask_cudf_performance_tracking +def _naive_var(ddf, meta, skipna, ddof, split_every, out): + num = ddf._get_numeric_data() + x = 1.0 * num.sum(skipna=skipna, split_every=split_every) + x2 = 1.0 * (num**2).sum(skipna=skipna, split_every=split_every) + n = num.count(split_every=split_every) + name = ddf._token_prefix + "var" + result = map_partitions( + var_aggregate, x2, x, n, token=name, meta=meta, ddof=ddof + ) + if isinstance(ddf, DataFrame): + result.divisions = (min(ddf.columns), max(ddf.columns)) + return handle_out(out, result) + + +@_dask_cudf_performance_tracking +def _parallel_var(ddf, meta, skipna, split_every, out): + def _local_var(x, skipna): + if skipna: + n = x.count() + avg = x.mean(skipna=skipna) + else: + # Not skipping nulls, so might as well + # avoid the full `count` operation + n = len(x) + avg = x.sum(skipna=skipna) / n + m2 = ((x - avg) ** 2).sum(skipna=skipna) + return n, avg, m2 + + def _aggregate_var(parts): + n, avg, m2 = parts[0] + for i in range(1, len(parts)): + n_a, avg_a, m2_a = n, avg, m2 + n_b, avg_b, m2_b = parts[i] + n = n_a + n_b + avg = (n_a * avg_a + n_b * avg_b) / n + delta = avg_b - avg_a + m2 = m2_a + m2_b + delta**2 * n_a * n_b / n + return n, avg, m2 + + def _finalize_var(vals): + n, _, m2 = vals + return m2 / (n - 1) + + # Build graph + nparts = ddf.npartitions + if not split_every: + split_every = nparts + name = "var-" + tokenize(skipna, split_every, out) + local_name = "local-" + name + num = ddf._get_numeric_data() + dsk = { + (local_name, n, 0): (_local_var, (num._name, n), skipna) + for n in range(nparts) + } + + # Use reduction tree + widths = [nparts] + while nparts > 1: + nparts = math.ceil(nparts / split_every) + widths.append(nparts) + height = len(widths) + for depth in range(1, height): + for group in range(widths[depth]): + p_max = widths[depth - 1] + lstart = split_every * group + lstop = min(lstart + split_every, p_max) + node_list = [ + (local_name, p, depth - 1) for p in range(lstart, lstop) + ] + dsk[(local_name, group, depth)] = (_aggregate_var, node_list) + if height == 1: + group = depth = 0 + dsk[(name, 0)] = (_finalize_var, (local_name, group, depth)) + + graph = HighLevelGraph.from_collections(name, dsk, dependencies=[num, ddf]) + result = dd.core.new_dd_object(graph, name, meta, (None, None)) + if isinstance(ddf, DataFrame): + result.divisions = (min(ddf.columns), max(ddf.columns)) + return handle_out(out, result) + + +@_dask_cudf_performance_tracking +def _extract_meta(x): + """ + Extract internal cache data (``_meta``) from dask_cudf objects + """ + if isinstance(x, (Scalar, _Frame)): + return x._meta + elif isinstance(x, list): + return [_extract_meta(_x) for _x in x] + elif isinstance(x, tuple): + return tuple(_extract_meta(_x) for _x in x) + elif isinstance(x, dict): + return {k: _extract_meta(v) for k, v in x.items()} + return x + + +@_dask_cudf_performance_tracking +def _emulate(func, *args, **kwargs): + """ + Apply a function using args / kwargs. If arguments contain dd.DataFrame / + dd.Series, using internal cache (``_meta``) for calculation + """ + with raise_on_meta_error(funcname(func)): + return func(*_extract_meta(args), **_extract_meta(kwargs)) + + +@_dask_cudf_performance_tracking +def align_partitions(args): + """Align partitions between dask_cudf objects. + + Note that if all divisions are unknown, but have equal npartitions, then + they will be passed through unchanged. + """ + dfs = [df for df in args if isinstance(df, _Frame)] + if not dfs: + return args + + divisions = dfs[0].divisions + if not all(df.divisions == divisions for df in dfs): + raise NotImplementedError("Aligning mismatched partitions") + return args + + +@_dask_cudf_performance_tracking +def reduction( + args, + chunk=None, + aggregate=None, + combine=None, + meta=None, + token=None, + chunk_kwargs=None, + aggregate_kwargs=None, + combine_kwargs=None, + split_every=None, + **kwargs, +): + """Generic tree reduction operation. + + Parameters + ---------- + args : + Positional arguments for the `chunk` function. All `dask.dataframe` + objects should be partitioned and indexed equivalently. + chunk : function [block-per-arg] -> block + Function to operate on each block of data + aggregate : function list-of-blocks -> block + Function to operate on the list of results of chunk + combine : function list-of-blocks -> block, optional + Function to operate on intermediate lists of results of chunk + in a tree-reduction. If not provided, defaults to aggregate. + $META + token : str, optional + The name to use for the output keys. + chunk_kwargs : dict, optional + Keywords for the chunk function only. + aggregate_kwargs : dict, optional + Keywords for the aggregate function only. + combine_kwargs : dict, optional + Keywords for the combine function only. + split_every : int, optional + Group partitions into groups of this size while performing a + tree-reduction. If set to False, no tree-reduction will be used, + and all intermediates will be concatenated and passed to ``aggregate``. + Default is 8. + kwargs : + All remaining keywords will be passed to ``chunk``, ``aggregate``, and + ``combine``. + """ + if chunk_kwargs is None: + chunk_kwargs = dict() + if aggregate_kwargs is None: + aggregate_kwargs = dict() + chunk_kwargs.update(kwargs) + aggregate_kwargs.update(kwargs) + + if combine is None: + if combine_kwargs: + raise ValueError("`combine_kwargs` provided with no `combine`") + combine = aggregate + combine_kwargs = aggregate_kwargs + else: + if combine_kwargs is None: + combine_kwargs = dict() + combine_kwargs.update(kwargs) + + if not isinstance(args, (tuple, list)): + args = [args] + + npartitions = {arg.npartitions for arg in args if isinstance(arg, _Frame)} + if len(npartitions) > 1: + raise ValueError("All arguments must have same number of partitions") + npartitions = npartitions.pop() + + if split_every is None: + split_every = 8 + elif split_every is False: + split_every = npartitions + elif split_every < 2 or not isinstance(split_every, int): + raise ValueError("split_every must be an integer >= 2") + + token_key = tokenize( + token or (chunk, aggregate), + meta, + args, + chunk_kwargs, + aggregate_kwargs, + combine_kwargs, + split_every, + ) + + # Chunk + a = f"{token or funcname(chunk)}-chunk-{token_key}" + if len(args) == 1 and isinstance(args[0], _Frame) and not chunk_kwargs: + dsk = { + (a, 0, i): (chunk, key) + for i, key in enumerate(args[0].__dask_keys__()) + } + else: + dsk = { + (a, 0, i): ( + apply, + chunk, + [(x._name, i) if isinstance(x, _Frame) else x for x in args], + chunk_kwargs, + ) + for i in range(args[0].npartitions) + } + + # Combine + b = f"{token or funcname(combine)}-combine-{token_key}" + k = npartitions + depth = 0 + while k > split_every: + for part_i, inds in enumerate(partition_all(split_every, range(k))): + conc = (list, [(a, depth, i) for i in inds]) + dsk[(b, depth + 1, part_i)] = ( + (apply, combine, [conc], combine_kwargs) + if combine_kwargs + else (combine, conc) + ) + k = part_i + 1 + a = b + depth += 1 + + # Aggregate + b = f"{token or funcname(aggregate)}-agg-{token_key}" + conc = (list, [(a, depth, i) for i in range(k)]) + if aggregate_kwargs: + dsk[(b, 0)] = (apply, aggregate, [conc], aggregate_kwargs) + else: + dsk[(b, 0)] = (aggregate, conc) + + if meta is None: + meta_chunk = _emulate(apply, chunk, args, chunk_kwargs) + meta = _emulate(apply, aggregate, [[meta_chunk]], aggregate_kwargs) + meta = dask_make_meta(meta) + + graph = HighLevelGraph.from_collections(b, dsk, dependencies=args) + return dd.core.new_dd_object(graph, b, meta, (None, None)) + + +for name in ( + "add", + "sub", + "mul", + "truediv", + "floordiv", + "mod", + "pow", + "radd", + "rsub", + "rmul", + "rtruediv", + "rfloordiv", + "rmod", + "rpow", +): + meth = getattr(cudf.DataFrame, name) + DataFrame._bind_operator_method(name, meth, original=cudf.Series) + + meth = getattr(cudf.Series, name) + Series._bind_operator_method(name, meth, original=cudf.Series) + +for name in ("lt", "gt", "le", "ge", "ne", "eq"): + meth = getattr(cudf.Series, name) + Series._bind_comparison_method(name, meth, original=cudf.Series) diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/_legacy/groupby.py similarity index 99% rename from python/dask_cudf/dask_cudf/groupby.py rename to python/dask_cudf/dask_cudf/_legacy/groupby.py index bbbcde17b51..7e01e91476d 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/_legacy/groupby.py @@ -18,7 +18,7 @@ from cudf.core.groupby.groupby import _deprecate_collect from cudf.utils.performance_tracking import _dask_cudf_performance_tracking -from dask_cudf.sorting import _deprecate_shuffle_kwarg +from dask_cudf._legacy.sorting import _deprecate_shuffle_kwarg # aggregations that are dask-cudf optimized OPTIMIZED_AGGS = ( diff --git a/python/dask_cudf/dask_cudf/_legacy/io/__init__.py b/python/dask_cudf/dask_cudf/_legacy/io/__init__.py new file mode 100644 index 00000000000..0421bd755f4 --- /dev/null +++ b/python/dask_cudf/dask_cudf/_legacy/io/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2018-2024, NVIDIA CORPORATION. + +from .csv import read_csv # noqa: F401 +from .json import read_json # noqa: F401 +from .orc import read_orc, to_orc # noqa: F401 +from .text import read_text # noqa: F401 + +try: + from .parquet import read_parquet, to_parquet # noqa: F401 +except ImportError: + pass diff --git a/python/dask_cudf/dask_cudf/_legacy/io/csv.py b/python/dask_cudf/dask_cudf/_legacy/io/csv.py new file mode 100644 index 00000000000..fa5400344f9 --- /dev/null +++ b/python/dask_cudf/dask_cudf/_legacy/io/csv.py @@ -0,0 +1,222 @@ +# Copyright (c) 2020-2023, NVIDIA CORPORATION. + +import os +from glob import glob +from warnings import warn + +from fsspec.utils import infer_compression + +from dask import dataframe as dd +from dask.base import tokenize +from dask.dataframe.io.csv import make_reader +from dask.utils import apply, parse_bytes + +import cudf + + +def read_csv(path, blocksize="default", **kwargs): + """ + Read CSV files into a :class:`.DataFrame`. + + This API parallelizes the :func:`cudf:cudf.read_csv` function in + the following ways: + + It supports loading many files at once using globstrings: + + >>> import dask_cudf + >>> df = dask_cudf.read_csv("myfiles.*.csv") + + In some cases it can break up large files: + + >>> df = dask_cudf.read_csv("largefile.csv", blocksize="256 MiB") + + It can read CSV files from external resources (e.g. S3, HTTP, FTP) + + >>> df = dask_cudf.read_csv("s3://bucket/myfiles.*.csv") + >>> df = dask_cudf.read_csv("https://www.mycloud.com/sample.csv") + + Internally ``read_csv`` uses :func:`cudf:cudf.read_csv` and + supports many of the same keyword arguments with the same + performance guarantees. See the docstring for + :func:`cudf:cudf.read_csv` for more information on available + keyword arguments. + + Parameters + ---------- + path : str, path object, or file-like object + Either a path to a file (a str, :py:class:`pathlib.Path`, or + py._path.local.LocalPath), URL (including http, ftp, and S3 + locations), or any object with a read() method (such as + builtin :py:func:`open` file handler function or + :py:class:`~io.StringIO`). + blocksize : int or str, default "256 MiB" + The target task partition size. If ``None``, a single block + is used for each file. + **kwargs : dict + Passthrough key-word arguments that are sent to + :func:`cudf:cudf.read_csv`. + + Notes + ----- + If any of `skipfooter`/`skiprows`/`nrows` are passed, + `blocksize` will default to None. + + Examples + -------- + >>> import dask_cudf + >>> ddf = dask_cudf.read_csv("sample.csv", usecols=["a", "b"]) + >>> ddf.compute() + a b + 0 1 hi + 1 2 hello + 2 3 ai + + """ + + # Handle `chunksize` deprecation + if "chunksize" in kwargs: + chunksize = kwargs.pop("chunksize", "default") + warn( + "`chunksize` is deprecated and will be removed in the future. " + "Please use `blocksize` instead.", + FutureWarning, + ) + if blocksize == "default": + blocksize = chunksize + + # Set default `blocksize` + if blocksize == "default": + if ( + kwargs.get("skipfooter", 0) != 0 + or kwargs.get("skiprows", 0) != 0 + or kwargs.get("nrows", None) is not None + ): + # Cannot read in blocks if skipfooter, + # skiprows or nrows is passed. + blocksize = None + else: + blocksize = "256 MiB" + + if "://" in str(path): + func = make_reader(cudf.read_csv, "read_csv", "CSV") + return func(path, blocksize=blocksize, **kwargs) + else: + return _internal_read_csv(path=path, blocksize=blocksize, **kwargs) + + +def _internal_read_csv(path, blocksize="256 MiB", **kwargs): + if isinstance(blocksize, str): + blocksize = parse_bytes(blocksize) + + if isinstance(path, list): + filenames = path + elif isinstance(path, str): + filenames = sorted(glob(path)) + elif hasattr(path, "__fspath__"): + filenames = sorted(glob(path.__fspath__())) + else: + raise TypeError(f"Path type not understood:{type(path)}") + + if not filenames: + msg = f"A file in: {filenames} does not exist." + raise FileNotFoundError(msg) + + name = "read-csv-" + tokenize( + path, tokenize, **kwargs + ) # TODO: get last modified time + + compression = kwargs.get("compression", "infer") + + if compression == "infer": + # Infer compression from first path by default + compression = infer_compression(filenames[0]) + + if compression and blocksize: + # compressed CSVs reading must read the entire file + kwargs.pop("byte_range", None) + warn( + "Warning %s compression does not support breaking apart files\n" + "Please ensure that each individual file can fit in memory and\n" + "use the keyword ``blocksize=None to remove this message``\n" + "Setting ``blocksize=(size of file)``" % compression + ) + blocksize = None + + if blocksize is None: + return read_csv_without_blocksize(path, **kwargs) + + # Let dask.dataframe generate meta + dask_reader = make_reader(cudf.read_csv, "read_csv", "CSV") + kwargs1 = kwargs.copy() + usecols = kwargs1.pop("usecols", None) + dtype = kwargs1.pop("dtype", None) + meta = dask_reader(filenames[0], **kwargs1)._meta + names = meta.columns + if usecols or dtype: + # Regenerate meta with original kwargs if + # `usecols` or `dtype` was specified + meta = dask_reader(filenames[0], **kwargs)._meta + + dsk = {} + i = 0 + dtypes = meta.dtypes.values + + for fn in filenames: + size = os.path.getsize(fn) + for start in range(0, size, blocksize): + kwargs2 = kwargs.copy() + kwargs2["byte_range"] = ( + start, + blocksize, + ) # specify which chunk of the file we care about + if start != 0: + kwargs2["names"] = names # no header in the middle of the file + kwargs2["header"] = None + dsk[(name, i)] = (apply, _read_csv, [fn, dtypes], kwargs2) + + i += 1 + + divisions = [None] * (len(dsk) + 1) + return dd.core.new_dd_object(dsk, name, meta, divisions) + + +def _read_csv(fn, dtypes=None, **kwargs): + return cudf.read_csv(fn, **kwargs) + + +def read_csv_without_blocksize(path, **kwargs): + """Read entire CSV with optional compression (gzip/zip) + + Parameters + ---------- + path : str + path to files (support for glob) + """ + if isinstance(path, list): + filenames = path + elif isinstance(path, str): + filenames = sorted(glob(path)) + elif hasattr(path, "__fspath__"): + filenames = sorted(glob(path.__fspath__())) + else: + raise TypeError(f"Path type not understood:{type(path)}") + + name = "read-csv-" + tokenize(path, **kwargs) + + meta_kwargs = kwargs.copy() + if "skipfooter" in meta_kwargs: + meta_kwargs.pop("skipfooter") + if "nrows" in meta_kwargs: + meta_kwargs.pop("nrows") + # Read "head" of first file (first 5 rows). + # Convert to empty df for metadata. + meta = cudf.read_csv(filenames[0], nrows=5, **meta_kwargs).iloc[:0] + + graph = { + (name, i): (apply, cudf.read_csv, [fn], kwargs) + for i, fn in enumerate(filenames) + } + + divisions = [None] * (len(filenames) + 1) + + return dd.core.new_dd_object(graph, name, meta, divisions) diff --git a/python/dask_cudf/dask_cudf/_legacy/io/json.py b/python/dask_cudf/dask_cudf/_legacy/io/json.py new file mode 100644 index 00000000000..98c5ceedb76 --- /dev/null +++ b/python/dask_cudf/dask_cudf/_legacy/io/json.py @@ -0,0 +1,209 @@ +# Copyright (c) 2019-2024, NVIDIA CORPORATION. + +from functools import partial + +import numpy as np +from fsspec.core import get_compression, get_fs_token_paths + +import dask +from dask.utils import parse_bytes + +import cudf +from cudf.core.column import as_column +from cudf.utils.ioutils import _is_local_filesystem + +from dask_cudf.backends import _default_backend + + +def _read_json_partition( + paths, + fs=None, + include_path_column=False, + path_converter=None, + **kwargs, +): + # Transfer all data up front for remote storage + sources = ( + paths + if fs is None + else fs.cat_ranges( + paths, + [0] * len(paths), + fs.sizes(paths), + ) + ) + + if include_path_column: + # Add "path" column. + # Must iterate over sources sequentially + if not isinstance(include_path_column, str): + include_path_column = "path" + converted_paths = ( + paths + if path_converter is None + else [path_converter(path) for path in paths] + ) + dfs = [] + for i, source in enumerate(sources): + df = cudf.read_json(source, **kwargs) + df[include_path_column] = as_column( + converted_paths[i], length=len(df) + ) + dfs.append(df) + return cudf.concat(dfs) + else: + # Pass sources directly to cudf + return cudf.read_json(sources, **kwargs) + + +def read_json( + url_path, + engine="auto", + blocksize=None, + orient="records", + lines=None, + compression="infer", + aggregate_files=True, + **kwargs, +): + """Read JSON data into a :class:`.DataFrame`. + + This function wraps :func:`dask.dataframe.read_json`, and passes + ``engine=partial(cudf.read_json, engine="auto")`` by default. + + Parameters + ---------- + url_path : str, list of str + Location to read from. If a string, can include a glob character to + find a set of file names. + Supports protocol specifications such as ``"s3://"``. + engine : str or Callable, default "auto" + + If str, this value will be used as the ``engine`` argument + when :func:`cudf.read_json` is used to create each partition. + If a :obj:`~collections.abc.Callable`, this value will be used as the + underlying function used to create each partition from JSON + data. The default value is "auto", so that + ``engine=partial(cudf.read_json, engine="auto")`` will be + passed to :func:`dask.dataframe.read_json` by default. + aggregate_files : bool or int + Whether to map multiple files to each output partition. If True, + the `blocksize` argument will be used to determine the number of + files in each partition. If any one file is larger than `blocksize`, + the `aggregate_files` argument will be ignored. If an integer value + is specified, the `blocksize` argument will be ignored, and that + number of files will be mapped to each partition. Default is True. + **kwargs : + Key-word arguments to pass through to :func:`dask.dataframe.read_json`. + + Returns + ------- + :class:`.DataFrame` + + Examples + -------- + Load single file + + >>> from dask_cudf import read_json + >>> read_json('myfile.json') # doctest: +SKIP + + Load large line-delimited JSON files using partitions of approx + 256MB size + + >>> read_json('data/file*.csv', blocksize=2**28) # doctest: +SKIP + + Load nested JSON data + + >>> read_json('myfile.json') # doctest: +SKIP + + See Also + -------- + dask.dataframe.read_json + + """ + + if lines is None: + lines = orient == "records" + if orient != "records" and lines: + raise ValueError( + 'Line-delimited JSON is only available with orient="records".' + ) + if blocksize and (orient != "records" or not lines): + raise ValueError( + "JSON file chunking only allowed for JSON-lines" + "input (orient='records', lines=True)." + ) + + inputs = [] + if aggregate_files and blocksize or int(aggregate_files) > 1: + # Attempt custom read if we are mapping multiple files + # to each output partition. Otherwise, upstream logic + # is sufficient. + + storage_options = kwargs.get("storage_options", {}) + fs, _, paths = get_fs_token_paths( + url_path, mode="rb", storage_options=storage_options + ) + if isinstance(aggregate_files, int) and aggregate_files > 1: + # Map a static file count to each partition + inputs = [ + paths[offset : offset + aggregate_files] + for offset in range(0, len(paths), aggregate_files) + ] + elif aggregate_files is True and blocksize: + # Map files dynamically (using blocksize) + file_sizes = fs.sizes(paths) # NOTE: This can be slow + blocksize = parse_bytes(blocksize) + if all([file_size <= blocksize for file_size in file_sizes]): + counts = np.unique( + np.floor(np.cumsum(file_sizes) / blocksize), + return_counts=True, + )[1] + offsets = np.concatenate([[0], counts.cumsum()]) + inputs = [ + paths[offsets[i] : offsets[i + 1]] + for i in range(len(offsets) - 1) + ] + + if inputs: + # Inputs were successfully populated. + # Use custom _read_json_partition function + # to generate each partition. + + compression = get_compression( + url_path[0] if isinstance(url_path, list) else url_path, + compression, + ) + _kwargs = dict( + orient=orient, + lines=lines, + compression=compression, + include_path_column=kwargs.get("include_path_column", False), + path_converter=kwargs.get("path_converter"), + ) + if not _is_local_filesystem(fs): + _kwargs["fs"] = fs + # TODO: Generate meta more efficiently + meta = _read_json_partition(inputs[0][:1], **_kwargs) + return dask.dataframe.from_map( + _read_json_partition, + inputs, + meta=meta, + **_kwargs, + ) + + # Fall back to dask.dataframe.read_json + return _default_backend( + dask.dataframe.read_json, + url_path, + engine=( + partial(cudf.read_json, engine=engine) + if isinstance(engine, str) + else engine + ), + blocksize=blocksize, + orient=orient, + lines=lines, + compression=compression, + **kwargs, + ) diff --git a/python/dask_cudf/dask_cudf/_legacy/io/orc.py b/python/dask_cudf/dask_cudf/_legacy/io/orc.py new file mode 100644 index 00000000000..bed69f038b0 --- /dev/null +++ b/python/dask_cudf/dask_cudf/_legacy/io/orc.py @@ -0,0 +1,199 @@ +# Copyright (c) 2020-2024, NVIDIA CORPORATION. + +from io import BufferedWriter, IOBase + +from fsspec.core import get_fs_token_paths +from fsspec.utils import stringify_path +from pyarrow import orc as orc + +from dask import dataframe as dd +from dask.base import tokenize +from dask.dataframe.io.utils import _get_pyarrow_dtypes + +import cudf + + +def _read_orc_stripe(fs, path, stripe, columns, kwargs=None): + """Pull out specific columns from specific stripe""" + if kwargs is None: + kwargs = {} + with fs.open(path, "rb") as f: + df_stripe = cudf.read_orc( + f, stripes=[stripe], columns=columns, **kwargs + ) + return df_stripe + + +def read_orc(path, columns=None, filters=None, storage_options=None, **kwargs): + """Read ORC files into a :class:`.DataFrame`. + + Note that this function is mostly borrowed from upstream Dask. + + Parameters + ---------- + path : str or list[str] + Location of file(s), which can be a full URL with protocol specifier, + and may include glob character if a single string. + columns : None or list[str] + Columns to load. If None, loads all. + filters : None or list of tuple or list of lists of tuples + If not None, specifies a filter predicate used to filter out + row groups using statistics stored for each row group as + Parquet metadata. Row groups that do not match the given + filter predicate are not read. The predicate is expressed in + `disjunctive normal form (DNF) + `__ + like ``[[('x', '=', 0), ...], ...]``. DNF allows arbitrary + boolean logical combinations of single column predicates. The + innermost tuples each describe a single column predicate. The + list of inner predicates is interpreted as a conjunction + (AND), forming a more selective and multiple column predicate. + Finally, the outermost list combines these filters as a + disjunction (OR). Predicates may also be passed as a list of + tuples. This form is interpreted as a single conjunction. To + express OR in predicates, one must use the (preferred) + notation of list of lists of tuples. + storage_options : None or dict + Further parameters to pass to the bytes backend. + + See Also + -------- + dask.dataframe.read_orc + + Returns + ------- + dask_cudf.DataFrame + + """ + + storage_options = storage_options or {} + fs, fs_token, paths = get_fs_token_paths( + path, mode="rb", storage_options=storage_options + ) + schema = None + nstripes_per_file = [] + for path in paths: + with fs.open(path, "rb") as f: + o = orc.ORCFile(f) + if schema is None: + schema = o.schema + elif schema != o.schema: + raise ValueError( + "Incompatible schemas while parsing ORC files" + ) + nstripes_per_file.append(o.nstripes) + schema = _get_pyarrow_dtypes(schema, categories=None) + if columns is not None: + ex = set(columns) - set(schema) + if ex: + raise ValueError( + f"Requested columns ({ex}) not in schema ({set(schema)})" + ) + else: + columns = list(schema) + + with fs.open(paths[0], "rb") as f: + meta = cudf.read_orc( + f, + stripes=[0] if nstripes_per_file[0] else None, + columns=columns, + **kwargs, + ) + + name = "read-orc-" + tokenize(fs_token, path, columns, filters, **kwargs) + dsk = {} + N = 0 + for path, n in zip(paths, nstripes_per_file): + for stripe in ( + range(n) + if filters is None + else cudf.io.orc._filter_stripes(filters, path) + ): + dsk[(name, N)] = ( + _read_orc_stripe, + fs, + path, + stripe, + columns, + kwargs, + ) + N += 1 + + divisions = [None] * (len(dsk) + 1) + return dd.core.new_dd_object(dsk, name, meta, divisions) + + +def write_orc_partition(df, path, fs, filename, compression="snappy"): + full_path = fs.sep.join([path, filename]) + with fs.open(full_path, mode="wb") as out_file: + if not isinstance(out_file, IOBase): + out_file = BufferedWriter(out_file) + cudf.io.to_orc(df, out_file, compression=compression) + return full_path + + +def to_orc( + df, + path, + write_index=True, + storage_options=None, + compression="snappy", + compute=True, + **kwargs, +): + """ + Write a :class:`.DataFrame` to ORC file(s) (one file per partition). + + Parameters + ---------- + df : DataFrame + path : str or pathlib.Path + Destination directory for data. Prepend with protocol like ``s3://`` + or ``hdfs://`` for remote data. + write_index : boolean, optional + Whether or not to write the index. Defaults to True. + storage_options : None or dict + Further parameters to pass to the bytes backend. + compression : string or dict, optional + compute : bool, optional + If True (default) then the result is computed immediately. If + False then a :class:`~dask.delayed.Delayed` object is returned + for future computation. + + """ + + from dask import compute as dask_compute, delayed + + # TODO: Use upstream dask implementation once available + # (see: Dask Issue#5596) + + if hasattr(path, "name"): + path = stringify_path(path) + fs, _, _ = get_fs_token_paths( + path, mode="wb", storage_options=storage_options + ) + # Trim any protocol information from the path before forwarding + path = fs._strip_protocol(path) + + if write_index: + df = df.reset_index() + else: + # Not writing index - might as well drop it + df = df.reset_index(drop=True) + + fs.mkdirs(path, exist_ok=True) + + # Use i_offset and df.npartitions to define file-name list + filenames = ["part.%i.orc" % i for i in range(df.npartitions)] + + # write parts + dwrite = delayed(write_orc_partition) + parts = [ + dwrite(d, path, fs, filename, compression=compression) + for d, filename in zip(df.to_delayed(), filenames) + ] + + if compute: + return dask_compute(*parts) + + return delayed(list)(parts) diff --git a/python/dask_cudf/dask_cudf/_legacy/io/parquet.py b/python/dask_cudf/dask_cudf/_legacy/io/parquet.py new file mode 100644 index 00000000000..39ac6474958 --- /dev/null +++ b/python/dask_cudf/dask_cudf/_legacy/io/parquet.py @@ -0,0 +1,513 @@ +# Copyright (c) 2019-2024, NVIDIA CORPORATION. +import itertools +import warnings +from functools import partial +from io import BufferedWriter, BytesIO, IOBase + +import numpy as np +import pandas as pd +from pyarrow import dataset as pa_ds, parquet as pq + +from dask import dataframe as dd +from dask.dataframe.io.parquet.arrow import ArrowDatasetEngine + +try: + from dask.dataframe.io.parquet import ( + create_metadata_file as create_metadata_file_dd, + ) +except ImportError: + create_metadata_file_dd = None + +import cudf +from cudf.core.column import CategoricalColumn, as_column +from cudf.io import write_to_dataset +from cudf.io.parquet import _apply_post_filters, _normalize_filters +from cudf.utils.dtypes import cudf_dtype_from_pa_type + + +class CudfEngine(ArrowDatasetEngine): + @classmethod + def _create_dd_meta(cls, dataset_info, **kwargs): + # Start with pandas-version of meta + meta_pd = super()._create_dd_meta(dataset_info, **kwargs) + + # Convert to cudf + # (drop unsupported timezone information) + for k, v in meta_pd.dtypes.items(): + if isinstance(v, pd.DatetimeTZDtype) and v.tz is not None: + meta_pd[k] = meta_pd[k].dt.tz_localize(None) + meta_cudf = cudf.from_pandas(meta_pd) + + # Re-set "object" dtypes to align with pa schema + kwargs = dataset_info.get("kwargs", {}) + set_object_dtypes_from_pa_schema( + meta_cudf, + kwargs.get("schema", None), + ) + + return meta_cudf + + @classmethod + def multi_support(cls): + # Assert that this class is CudfEngine + # and that multi-part reading is supported + return cls == CudfEngine + + @classmethod + def _read_paths( + cls, + paths, + fs, + columns=None, + row_groups=None, + filters=None, + partitions=None, + partitioning=None, + partition_keys=None, + open_file_options=None, + dataset_kwargs=None, + **kwargs, + ): + # Simplify row_groups if all None + if row_groups == [None for path in paths]: + row_groups = None + + # Make sure we read in the columns needed for row-wise + # filtering after IO. This means that one or more columns + # will be dropped almost immediately after IO. However, + # we do NEED these columns for accurate filtering. + filters = _normalize_filters(filters) + projected_columns = None + if columns and filters: + projected_columns = [c for c in columns if c is not None] + columns = sorted( + set(v[0] for v in itertools.chain.from_iterable(filters)) + | set(projected_columns) + ) + + dataset_kwargs = dataset_kwargs or {} + dataset_kwargs["partitioning"] = partitioning or "hive" + + # Use cudf to read in data + try: + df = cudf.read_parquet( + paths, + engine="cudf", + columns=columns, + row_groups=row_groups if row_groups else None, + dataset_kwargs=dataset_kwargs, + categorical_partitions=False, + filesystem=fs, + **kwargs, + ) + except RuntimeError as err: + # TODO: Remove try/except after null-schema issue is resolved + # (See: https://github.com/rapidsai/cudf/issues/12702) + if len(paths) > 1: + df = cudf.concat( + [ + cudf.read_parquet( + path, + engine="cudf", + columns=columns, + row_groups=row_groups[i] if row_groups else None, + dataset_kwargs=dataset_kwargs, + categorical_partitions=False, + filesystem=fs, + **kwargs, + ) + for i, path in enumerate(paths) + ] + ) + else: + raise err + + # Apply filters (if any are defined) + df = _apply_post_filters(df, filters) + + if projected_columns: + # Elements of `projected_columns` may now be in the index. + # We must filter these names from our projection + projected_columns = [ + col for col in projected_columns if col in df._column_names + ] + df = df[projected_columns] + + if partitions and partition_keys is None: + # Use `HivePartitioning` by default + ds = pa_ds.dataset( + paths, + filesystem=fs, + **dataset_kwargs, + ) + frag = next(ds.get_fragments()) + if frag: + # Extract hive-partition keys, and make sure they + # are ordered the same as they are in `partitions` + raw_keys = pa_ds._get_partition_keys(frag.partition_expression) + partition_keys = [ + (hive_part.name, raw_keys[hive_part.name]) + for hive_part in partitions + ] + + if partition_keys: + if partitions is None: + raise ValueError("Must pass partition sets") + + for i, (name, index2) in enumerate(partition_keys): + if len(partitions[i].keys): + # Build a categorical column from `codes` directly + # (since the category is often a larger dtype) + codes = as_column( + partitions[i].keys.get_loc(index2), + length=len(df), + ) + df[name] = CategoricalColumn( + data=None, + size=codes.size, + dtype=cudf.CategoricalDtype( + categories=partitions[i].keys, ordered=False + ), + offset=codes.offset, + children=(codes,), + ) + elif name not in df.columns: + # Add non-categorical partition column + df[name] = as_column(index2, length=len(df)) + + return df + + @classmethod + def read_partition( + cls, + fs, + pieces, + columns, + index, + categories=(), + partitions=(), + filters=None, + partitioning=None, + schema=None, + open_file_options=None, + **kwargs, + ): + if columns is not None: + columns = [c for c in columns] + if isinstance(index, list): + columns += index + + dataset_kwargs = kwargs.get("dataset", {}) + partitioning = partitioning or dataset_kwargs.get("partitioning", None) + if isinstance(partitioning, dict): + partitioning = pa_ds.partitioning(**partitioning) + + # Check if we are actually selecting any columns + read_columns = columns + if schema and columns: + ignored = set(schema.names) - set(columns) + if not ignored: + read_columns = None + + if not isinstance(pieces, list): + pieces = [pieces] + + # Extract supported kwargs from `kwargs` + read_kwargs = kwargs.get("read", {}) + read_kwargs.update(open_file_options or {}) + check_file_size = read_kwargs.pop("check_file_size", None) + + # Wrap reading logic in a `try` block so that we can + # inform the user that the `read_parquet` partition + # size is too large for the available memory + try: + # Assume multi-piece read + paths = [] + rgs = [] + last_partition_keys = None + dfs = [] + + for i, piece in enumerate(pieces): + (path, row_group, partition_keys) = piece + row_group = None if row_group == [None] else row_group + + # File-size check to help "protect" users from change + # to up-stream `split_row_groups` default. We only + # check the file size if this partition corresponds + # to a full file, and `check_file_size` is defined + if check_file_size and len(pieces) == 1 and row_group is None: + file_size = fs.size(path) + if file_size > check_file_size: + warnings.warn( + f"A large parquet file ({file_size}B) is being " + f"used to create a DataFrame partition in " + f"read_parquet. This may cause out of memory " + f"exceptions in operations downstream. See the " + f"notes on split_row_groups in the read_parquet " + f"documentation. Setting split_row_groups " + f"explicitly will silence this warning." + ) + + if i > 0 and partition_keys != last_partition_keys: + dfs.append( + cls._read_paths( + paths, + fs, + columns=read_columns, + row_groups=rgs if rgs else None, + filters=filters, + partitions=partitions, + partitioning=partitioning, + partition_keys=last_partition_keys, + dataset_kwargs=dataset_kwargs, + **read_kwargs, + ) + ) + paths = [] + rgs = [] + last_partition_keys = None + paths.append(path) + rgs.append( + [row_group] + if not isinstance(row_group, list) + and row_group is not None + else row_group + ) + last_partition_keys = partition_keys + + dfs.append( + cls._read_paths( + paths, + fs, + columns=read_columns, + row_groups=rgs if rgs else None, + filters=filters, + partitions=partitions, + partitioning=partitioning, + partition_keys=last_partition_keys, + dataset_kwargs=dataset_kwargs, + **read_kwargs, + ) + ) + df = cudf.concat(dfs) if len(dfs) > 1 else dfs[0] + + # Re-set "object" dtypes align with pa schema + set_object_dtypes_from_pa_schema(df, schema) + + if index and (index[0] in df.columns): + df = df.set_index(index[0]) + elif index is False and df.index.names != [None]: + # If index=False, we shouldn't have a named index + df.reset_index(inplace=True) + + except MemoryError as err: + raise MemoryError( + "Parquet data was larger than the available GPU memory!\n\n" + "See the notes on split_row_groups in the read_parquet " + "documentation.\n\n" + "Original Error: " + str(err) + ) + raise err + + return df + + @staticmethod + def write_partition( + df, + path, + fs, + filename, + partition_on, + return_metadata, + fmd=None, + compression="snappy", + index_cols=None, + **kwargs, + ): + preserve_index = False + if len(index_cols) and set(index_cols).issubset(set(df.columns)): + df.set_index(index_cols, drop=True, inplace=True) + preserve_index = True + if partition_on: + md = write_to_dataset( + df=df, + root_path=path, + compression=compression, + filename=filename, + partition_cols=partition_on, + fs=fs, + preserve_index=preserve_index, + return_metadata=return_metadata, + statistics=kwargs.get("statistics", "ROWGROUP"), + int96_timestamps=kwargs.get("int96_timestamps", False), + row_group_size_bytes=kwargs.get("row_group_size_bytes", None), + row_group_size_rows=kwargs.get("row_group_size_rows", None), + max_page_size_bytes=kwargs.get("max_page_size_bytes", None), + max_page_size_rows=kwargs.get("max_page_size_rows", None), + storage_options=kwargs.get("storage_options", None), + ) + else: + with fs.open(fs.sep.join([path, filename]), mode="wb") as out_file: + if not isinstance(out_file, IOBase): + out_file = BufferedWriter(out_file) + md = df.to_parquet( + path=out_file, + engine=kwargs.get("engine", "cudf"), + index=kwargs.get("index", None), + partition_cols=kwargs.get("partition_cols", None), + partition_file_name=kwargs.get( + "partition_file_name", None + ), + partition_offsets=kwargs.get("partition_offsets", None), + statistics=kwargs.get("statistics", "ROWGROUP"), + int96_timestamps=kwargs.get("int96_timestamps", False), + row_group_size_bytes=kwargs.get( + "row_group_size_bytes", None + ), + row_group_size_rows=kwargs.get( + "row_group_size_rows", None + ), + storage_options=kwargs.get("storage_options", None), + metadata_file_path=filename if return_metadata else None, + ) + # Return the schema needed to write the metadata + if return_metadata: + return [{"meta": md}] + else: + return [] + + @staticmethod + def write_metadata(parts, fmd, fs, path, append=False, **kwargs): + if parts: + # Aggregate metadata and write to _metadata file + metadata_path = fs.sep.join([path, "_metadata"]) + _meta = [] + if append and fmd is not None: + # Convert to bytes: + if isinstance(fmd, pq.FileMetaData): + with BytesIO() as myio: + fmd.write_metadata_file(myio) + myio.seek(0) + fmd = np.frombuffer(myio.read(), dtype="uint8") + _meta = [fmd] + _meta.extend([parts[i][0]["meta"] for i in range(len(parts))]) + _meta = ( + cudf.io.merge_parquet_filemetadata(_meta) + if len(_meta) > 1 + else _meta[0] + ) + with fs.open(metadata_path, "wb") as fil: + fil.write(memoryview(_meta)) + + @classmethod + def collect_file_metadata(cls, path, fs, file_path): + with fs.open(path, "rb") as f: + meta = pq.ParquetFile(f).metadata + if file_path: + meta.set_file_path(file_path) + with BytesIO() as myio: + meta.write_metadata_file(myio) + myio.seek(0) + meta = np.frombuffer(myio.read(), dtype="uint8") + return meta + + @classmethod + def aggregate_metadata(cls, meta_list, fs, out_path): + meta = ( + cudf.io.merge_parquet_filemetadata(meta_list) + if len(meta_list) > 1 + else meta_list[0] + ) + if out_path: + metadata_path = fs.sep.join([out_path, "_metadata"]) + with fs.open(metadata_path, "wb") as fil: + fil.write(memoryview(meta)) + return None + else: + return meta + + +def set_object_dtypes_from_pa_schema(df, schema): + # Simple utility to modify cudf DataFrame + # "object" dtypes to agree with a specific + # pyarrow schema. + if schema: + for col_name, col in df._data.items(): + if col_name is None: + # Pyarrow cannot handle `None` as a field name. + # However, this should be a simple range index that + # we can ignore anyway + continue + typ = cudf_dtype_from_pa_type(schema.field(col_name).type) + if ( + col_name in schema.names + and not isinstance(typ, (cudf.ListDtype, cudf.StructDtype)) + and isinstance(col, cudf.core.column.StringColumn) + ): + df._data[col_name] = col.astype(typ) + + +def read_parquet(path, columns=None, **kwargs): + """ + Read parquet files into a :class:`.DataFrame`. + + Calls :func:`dask.dataframe.read_parquet` with ``engine=CudfEngine`` + to coordinate the execution of :func:`cudf.read_parquet`, and to + ultimately create a :class:`.DataFrame` collection. + + See the :func:`dask.dataframe.read_parquet` documentation for + all available options. + + Examples + -------- + >>> from dask_cudf import read_parquet + >>> df = read_parquet("/path/to/dataset/") # doctest: +SKIP + + When dealing with one or more large parquet files having an + in-memory footprint >15% device memory, the ``split_row_groups`` + argument should be used to map Parquet **row-groups** to DataFrame + partitions (instead of **files** to partitions). For example, the + following code will map each row-group to a distinct partition: + + >>> df = read_parquet(..., split_row_groups=True) # doctest: +SKIP + + To map **multiple** row-groups to each partition, an integer can be + passed to ``split_row_groups`` to specify the **maximum** number of + row-groups allowed in each output partition: + + >>> df = read_parquet(..., split_row_groups=10) # doctest: +SKIP + + See Also + -------- + cudf.read_parquet + dask.dataframe.read_parquet + """ + if isinstance(columns, str): + columns = [columns] + + # Set "check_file_size" option to determine whether we + # should check the parquet-file size. This check is meant + # to "protect" users from `split_row_groups` default changes + check_file_size = kwargs.pop("check_file_size", 500_000_000) + if ( + check_file_size + and ("split_row_groups" not in kwargs) + and ("chunksize" not in kwargs) + ): + # User is not specifying `split_row_groups` or `chunksize`, + # so we should warn them if/when a file is ~>0.5GB on disk. + # They can set `split_row_groups` explicitly to silence/skip + # this check + if "read" not in kwargs: + kwargs["read"] = {} + kwargs["read"]["check_file_size"] = check_file_size + + return dd.read_parquet(path, columns=columns, engine=CudfEngine, **kwargs) + + +to_parquet = partial(dd.to_parquet, engine=CudfEngine) + +if create_metadata_file_dd is None: + create_metadata_file = create_metadata_file_dd +else: + create_metadata_file = partial(create_metadata_file_dd, engine=CudfEngine) diff --git a/python/dask_cudf/dask_cudf/_legacy/io/text.py b/python/dask_cudf/dask_cudf/_legacy/io/text.py new file mode 100644 index 00000000000..9cdb7c5220b --- /dev/null +++ b/python/dask_cudf/dask_cudf/_legacy/io/text.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION. + +import os +from glob import glob + +import dask.dataframe as dd +from dask.base import tokenize +from dask.utils import apply, parse_bytes + +import cudf + + +def read_text(path, chunksize="256 MiB", **kwargs): + if isinstance(chunksize, str): + chunksize = parse_bytes(chunksize) + + if isinstance(path, list): + filenames = path + elif isinstance(path, str): + filenames = sorted(glob(path)) + elif hasattr(path, "__fspath__"): + filenames = sorted(glob(path.__fspath__())) + else: + raise TypeError(f"Path type not understood:{type(path)}") + + if not filenames: + msg = f"A file in: {filenames} does not exist." + raise FileNotFoundError(msg) + + name = "read-text-" + tokenize(path, tokenize, **kwargs) + + if chunksize: + dsk = {} + i = 0 + for fn in filenames: + size = os.path.getsize(fn) + for start in range(0, size, chunksize): + kwargs1 = kwargs.copy() + kwargs1["byte_range"] = ( + start, + chunksize, + ) # specify which chunk of the file we care about + + dsk[(name, i)] = (apply, cudf.read_text, [fn], kwargs1) + i += 1 + else: + dsk = { + (name, i): (apply, cudf.read_text, [fn], kwargs) + for i, fn in enumerate(filenames) + } + + meta = cudf.Series([], dtype="O") + divisions = [None] * (len(dsk) + 1) + return dd.core.new_dd_object(dsk, name, meta, divisions) diff --git a/python/dask_cudf/dask_cudf/sorting.py b/python/dask_cudf/dask_cudf/_legacy/sorting.py similarity index 100% rename from python/dask_cudf/dask_cudf/sorting.py rename to python/dask_cudf/dask_cudf/_legacy/sorting.py diff --git a/python/dask_cudf/dask_cudf/backends.py b/python/dask_cudf/dask_cudf/backends.py index bead964a0ef..fb02e0ac772 100644 --- a/python/dask_cudf/dask_cudf/backends.py +++ b/python/dask_cudf/dask_cudf/backends.py @@ -46,7 +46,7 @@ from cudf.api.types import is_string_dtype from cudf.utils.performance_tracking import _dask_cudf_performance_tracking -from .core import DataFrame, Index, Series +from ._legacy.core import DataFrame, Index, Series get_parallel_type.register(cudf.DataFrame, lambda _: DataFrame) get_parallel_type.register(cudf.Series, lambda _: Series) @@ -574,7 +574,7 @@ class CudfBackendEntrypoint(DataFrameBackendEntrypoint): >>> with dask.config.set({"dataframe.backend": "cudf"}): ... ddf = dd.from_dict({"a": range(10)}) >>> type(ddf) - + """ @classmethod @@ -610,7 +610,7 @@ def from_dict( @staticmethod def read_parquet(*args, engine=None, **kwargs): - from dask_cudf.io.parquet import CudfEngine + from dask_cudf._legacy.io.parquet import CudfEngine _raise_unsupported_parquet_kwargs(**kwargs) return _default_backend( @@ -622,19 +622,19 @@ def read_parquet(*args, engine=None, **kwargs): @staticmethod def read_json(*args, **kwargs): - from dask_cudf.io.json import read_json + from dask_cudf._legacy.io.json import read_json return read_json(*args, **kwargs) @staticmethod def read_orc(*args, **kwargs): - from dask_cudf.io import read_orc + from dask_cudf._legacy.io import read_orc return read_orc(*args, **kwargs) @staticmethod def read_csv(*args, **kwargs): - from dask_cudf.io import read_csv + from dask_cudf._legacy.io import read_csv return read_csv(*args, **kwargs) @@ -674,7 +674,7 @@ class CudfDXBackendEntrypoint(DataFrameBackendEntrypoint): def to_backend(data, **kwargs): import dask_expr as dx - from dask_cudf.expr._expr import ToCudfBackend + from dask_cudf._expr.expr import ToCudfBackend return dx.new_collection(ToCudfBackend(data, kwargs)) @@ -710,7 +710,7 @@ def read_parquet(path, *args, filesystem="fsspec", engine=None, **kwargs): and filesystem.lower() == "fsspec" ): # Default "fsspec" filesystem - from dask_cudf.io.parquet import CudfEngine + from dask_cudf._legacy.io.parquet import CudfEngine _raise_unsupported_parquet_kwargs(**kwargs) return _default_backend( @@ -736,7 +736,7 @@ def read_parquet(path, *args, filesystem="fsspec", engine=None, **kwargs): from dask.core import flatten from dask.dataframe.utils import pyarrow_strings_enabled - from dask_cudf.expr._expr import CudfReadParquetPyarrowFS + from dask_cudf.io.parquet import CudfReadParquetPyarrowFS if args: raise ValueError(f"Unexpected positional arguments: {args}") @@ -862,7 +862,7 @@ def read_csv( @staticmethod def read_json(*args, **kwargs): - from dask_cudf.io.json import read_json as read_json_impl + from dask_cudf._legacy.io.json import read_json as read_json_impl return read_json_impl(*args, **kwargs) @@ -870,14 +870,7 @@ def read_json(*args, **kwargs): def read_orc(*args, **kwargs): from dask_expr import from_legacy_dataframe - from dask_cudf.io.orc import read_orc as legacy_read_orc + from dask_cudf._legacy.io.orc import read_orc as legacy_read_orc ddf = legacy_read_orc(*args, **kwargs) return from_legacy_dataframe(ddf) - - -# Import/register cudf-specific classes for dask-expr -try: - import dask_cudf.expr # noqa: F401 -except ImportError: - pass diff --git a/python/dask_cudf/dask_cudf/core.py b/python/dask_cudf/dask_cudf/core.py index 3181c8d69ec..7d6d5c05cbe 100644 --- a/python/dask_cudf/dask_cudf/core.py +++ b/python/dask_cudf/dask_cudf/core.py @@ -1,705 +1,25 @@ -# Copyright (c) 2018-2024, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. -import math import textwrap -import warnings -import numpy as np -import pandas as pd -from tlz import partition_all - -from dask import dataframe as dd -from dask.base import normalize_token, tokenize -from dask.dataframe.core import ( - Scalar, - handle_out, - make_meta as dask_make_meta, - map_partitions, -) -from dask.dataframe.utils import raise_on_meta_error -from dask.highlevelgraph import HighLevelGraph -from dask.utils import M, OperatorMethodMixin, apply, derived_from, funcname +import dask.dataframe as dd +from dask.tokenize import tokenize import cudf -from cudf import _lib as libcudf from cudf.utils.performance_tracking import _dask_cudf_performance_tracking -from dask_cudf import sorting -from dask_cudf.accessors import ListMethods, StructMethods -from dask_cudf.sorting import _deprecate_shuffle_kwarg, _get_shuffle_method - - -class _Frame(dd.core._Frame, OperatorMethodMixin): - """Superclass for DataFrame and Series - - Parameters - ---------- - dsk : dict - The dask graph to compute this DataFrame - name : str - The key prefix that specifies which keys in the dask comprise this - particular DataFrame / Series - meta : cudf.DataFrame, cudf.Series, or cudf.Index - An empty cudf object with names, dtypes, and indices matching the - expected output. - divisions : tuple of index values - Values along which we partition our blocks on the index - """ - - def _is_partition_type(self, meta): - return isinstance(meta, self._partition_type) - - def __repr__(self): - s = "" - return s % (type(self).__name__, len(self.dask), self.npartitions) - - @_dask_cudf_performance_tracking - def to_dask_dataframe(self, **kwargs): - """Create a dask.dataframe object from a dask_cudf object - - WARNING: This API is deprecated, and may not work properly - when query-planning is active. Please use `*.to_backend("pandas")` - to convert the underlying data to pandas. - """ - - warnings.warn( - "The `to_dask_dataframe` API is now deprecated. " - "Please use `*.to_backend('pandas')` instead.", - FutureWarning, - ) - - return self.to_backend("pandas", **kwargs) - - -concat = dd.concat - - -normalize_token.register(_Frame, lambda a: a._name) - - -class DataFrame(_Frame, dd.core.DataFrame): - """ - A distributed Dask DataFrame where the backing dataframe is a - :class:`cuDF DataFrame `. - - Typically you would not construct this object directly, but rather - use one of Dask-cuDF's IO routines. - - Most operations on :doc:`Dask DataFrames ` are - supported, with many of the same caveats. - - """ - - _partition_type = cudf.DataFrame - - @_dask_cudf_performance_tracking - def _assign_column(self, k, v): - def assigner(df, k, v): - out = df.copy() - out[k] = v - return out - - meta = assigner(self._meta, k, dask_make_meta(v)) - return self.map_partitions(assigner, k, v, meta=meta) - - @_dask_cudf_performance_tracking - def apply_rows(self, func, incols, outcols, kwargs=None, cache_key=None): - import uuid - - if kwargs is None: - kwargs = {} - - if cache_key is None: - cache_key = uuid.uuid4() - - def do_apply_rows(df, func, incols, outcols, kwargs): - return df.apply_rows( - func, incols, outcols, kwargs, cache_key=cache_key - ) - - meta = do_apply_rows(self._meta, func, incols, outcols, kwargs) - return self.map_partitions( - do_apply_rows, func, incols, outcols, kwargs, meta=meta - ) - - @_deprecate_shuffle_kwarg - @_dask_cudf_performance_tracking - def merge(self, other, shuffle_method=None, **kwargs): - on = kwargs.pop("on", None) - if isinstance(on, tuple): - on = list(on) - return super().merge( - other, - on=on, - shuffle_method=_get_shuffle_method(shuffle_method), - **kwargs, - ) - - @_deprecate_shuffle_kwarg - @_dask_cudf_performance_tracking - def join(self, other, shuffle_method=None, **kwargs): - # CuDF doesn't support "right" join yet - how = kwargs.pop("how", "left") - if how == "right": - return other.join(other=self, how="left", **kwargs) - - on = kwargs.pop("on", None) - if isinstance(on, tuple): - on = list(on) - return super().join( - other, - how=how, - on=on, - shuffle_method=_get_shuffle_method(shuffle_method), - **kwargs, - ) - - @_deprecate_shuffle_kwarg - @_dask_cudf_performance_tracking - def set_index( - self, - other, - sorted=False, - divisions=None, - shuffle_method=None, - **kwargs, - ): - pre_sorted = sorted - del sorted - - if divisions == "quantile": - warnings.warn( - "Using divisions='quantile' is now deprecated. " - "Please raise an issue on github if you believe " - "this feature is necessary.", - FutureWarning, - ) - - if ( - divisions == "quantile" - or isinstance(divisions, (cudf.DataFrame, cudf.Series)) - or ( - isinstance(other, str) - and cudf.api.types.is_string_dtype(self[other].dtype) - ) - ): - # Let upstream-dask handle "pre-sorted" case - if pre_sorted: - return dd.shuffle.set_sorted_index( - self, other, divisions=divisions, **kwargs - ) - - by = other - if not isinstance(other, list): - by = [by] - if len(by) > 1: - raise ValueError("Dask does not support MultiIndex (yet).") - if divisions == "quantile": - divisions = None - - # Use dask_cudf's sort_values - df = self.sort_values( - by, - max_branch=kwargs.get("max_branch", None), - divisions=divisions, - set_divisions=True, - ignore_index=True, - shuffle_method=shuffle_method, - ) - - # Ignore divisions if its a dataframe - if isinstance(divisions, cudf.DataFrame): - divisions = None - - # Set index and repartition - df2 = df.map_partitions( - sorting.set_index_post, - index_name=other, - drop=kwargs.get("drop", True), - column_dtype=df.columns.dtype, - ) - npartitions = kwargs.get("npartitions", self.npartitions) - partition_size = kwargs.get("partition_size", None) - if partition_size: - return df2.repartition(partition_size=partition_size) - if not divisions and df2.npartitions != npartitions: - return df2.repartition(npartitions=npartitions) - if divisions and df2.npartitions != len(divisions) - 1: - return df2.repartition(divisions=divisions) - return df2 - - return super().set_index( - other, - sorted=pre_sorted, - shuffle_method=_get_shuffle_method(shuffle_method), - divisions=divisions, - **kwargs, - ) - - @_deprecate_shuffle_kwarg - @_dask_cudf_performance_tracking - def sort_values( - self, - by, - ignore_index=False, - max_branch=None, - divisions=None, - set_divisions=False, - ascending=True, - na_position="last", - sort_function=None, - sort_function_kwargs=None, - shuffle_method=None, - **kwargs, - ): - if kwargs: - raise ValueError( - f"Unsupported input arguments passed : {list(kwargs.keys())}" - ) - - df = sorting.sort_values( - self, - by, - max_branch=max_branch, - divisions=divisions, - set_divisions=set_divisions, - ignore_index=ignore_index, - ascending=ascending, - na_position=na_position, - shuffle_method=shuffle_method, - sort_function=sort_function, - sort_function_kwargs=sort_function_kwargs, - ) - - if ignore_index: - return df.reset_index(drop=True) - return df - - @_dask_cudf_performance_tracking - def to_parquet(self, path, *args, **kwargs): - """Calls dask.dataframe.io.to_parquet with CudfEngine backend""" - from dask_cudf.io import to_parquet - - return to_parquet(self, path, *args, **kwargs) - - @_dask_cudf_performance_tracking - def to_orc(self, path, **kwargs): - """Calls dask_cudf.io.to_orc""" - from dask_cudf.io import to_orc - - return to_orc(self, path, **kwargs) - - @derived_from(pd.DataFrame) - @_dask_cudf_performance_tracking - def var( - self, - axis=None, - skipna=True, - ddof=1, - split_every=False, - dtype=None, - out=None, - naive=False, - numeric_only=False, - ): - axis = self._validate_axis(axis) - meta = self._meta_nonempty.var( - axis=axis, skipna=skipna, numeric_only=numeric_only - ) - if axis == 1: - result = map_partitions( - M.var, - self, - meta=meta, - token=self._token_prefix + "var", - axis=axis, - skipna=skipna, - ddof=ddof, - numeric_only=numeric_only, - ) - return handle_out(out, result) - elif naive: - return _naive_var(self, meta, skipna, ddof, split_every, out) - else: - return _parallel_var(self, meta, skipna, split_every, out) - - @_deprecate_shuffle_kwarg - @_dask_cudf_performance_tracking - def shuffle(self, *args, shuffle_method=None, **kwargs): - """Wraps dask.dataframe DataFrame.shuffle method""" - return super().shuffle( - *args, shuffle_method=_get_shuffle_method(shuffle_method), **kwargs - ) - - @_dask_cudf_performance_tracking - def groupby(self, by=None, **kwargs): - from .groupby import CudfDataFrameGroupBy - - return CudfDataFrameGroupBy(self, by=by, **kwargs) - - -@_dask_cudf_performance_tracking -def sum_of_squares(x): - x = x.astype("f8")._column - outcol = libcudf.reduce.reduce("sum_of_squares", x) - return cudf.Series._from_column(outcol) - - -@_dask_cudf_performance_tracking -def var_aggregate(x2, x, n, ddof): - try: - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - result = (x2 / n) - (x / n) ** 2 - if ddof != 0: - result = result * n / (n - ddof) - return result - except ZeroDivisionError: - return np.float64(np.nan) - - -@_dask_cudf_performance_tracking -def nlargest_agg(x, **kwargs): - return cudf.concat(x).nlargest(**kwargs) - - -@_dask_cudf_performance_tracking -def nsmallest_agg(x, **kwargs): - return cudf.concat(x).nsmallest(**kwargs) - - -class Series(_Frame, dd.core.Series): - _partition_type = cudf.Series - - @_dask_cudf_performance_tracking - def count(self, split_every=False): - return reduction( - [self], - chunk=M.count, - aggregate=np.sum, - split_every=split_every, - meta="i8", - ) - - @_dask_cudf_performance_tracking - def mean(self, split_every=False): - sum = self.sum(split_every=split_every) - n = self.count(split_every=split_every) - return sum / n - - @derived_from(pd.DataFrame) - @_dask_cudf_performance_tracking - def var( - self, - axis=None, - skipna=True, - ddof=1, - split_every=False, - dtype=None, - out=None, - naive=False, - ): - axis = self._validate_axis(axis) - meta = self._meta_nonempty.var(axis=axis, skipna=skipna) - if axis == 1: - result = map_partitions( - M.var, - self, - meta=meta, - token=self._token_prefix + "var", - axis=axis, - skipna=skipna, - ddof=ddof, - ) - return handle_out(out, result) - elif naive: - return _naive_var(self, meta, skipna, ddof, split_every, out) - else: - return _parallel_var(self, meta, skipna, split_every, out) - - @_dask_cudf_performance_tracking - def groupby(self, *args, **kwargs): - from .groupby import CudfSeriesGroupBy - - return CudfSeriesGroupBy(self, *args, **kwargs) - - @property # type: ignore - @_dask_cudf_performance_tracking - def list(self): - return ListMethods(self) - - @property # type: ignore - @_dask_cudf_performance_tracking - def struct(self): - return StructMethods(self) - - -class Index(Series, dd.core.Index): - _partition_type = cudf.Index # type: ignore - - -@_dask_cudf_performance_tracking -def _naive_var(ddf, meta, skipna, ddof, split_every, out): - num = ddf._get_numeric_data() - x = 1.0 * num.sum(skipna=skipna, split_every=split_every) - x2 = 1.0 * (num**2).sum(skipna=skipna, split_every=split_every) - n = num.count(split_every=split_every) - name = ddf._token_prefix + "var" - result = map_partitions( - var_aggregate, x2, x, n, token=name, meta=meta, ddof=ddof - ) - if isinstance(ddf, DataFrame): - result.divisions = (min(ddf.columns), max(ddf.columns)) - return handle_out(out, result) - - -@_dask_cudf_performance_tracking -def _parallel_var(ddf, meta, skipna, split_every, out): - def _local_var(x, skipna): - if skipna: - n = x.count() - avg = x.mean(skipna=skipna) - else: - # Not skipping nulls, so might as well - # avoid the full `count` operation - n = len(x) - avg = x.sum(skipna=skipna) / n - m2 = ((x - avg) ** 2).sum(skipna=skipna) - return n, avg, m2 - - def _aggregate_var(parts): - n, avg, m2 = parts[0] - for i in range(1, len(parts)): - n_a, avg_a, m2_a = n, avg, m2 - n_b, avg_b, m2_b = parts[i] - n = n_a + n_b - avg = (n_a * avg_a + n_b * avg_b) / n - delta = avg_b - avg_a - m2 = m2_a + m2_b + delta**2 * n_a * n_b / n - return n, avg, m2 - - def _finalize_var(vals): - n, _, m2 = vals - return m2 / (n - 1) - - # Build graph - nparts = ddf.npartitions - if not split_every: - split_every = nparts - name = "var-" + tokenize(skipna, split_every, out) - local_name = "local-" + name - num = ddf._get_numeric_data() - dsk = { - (local_name, n, 0): (_local_var, (num._name, n), skipna) - for n in range(nparts) - } - - # Use reduction tree - widths = [nparts] - while nparts > 1: - nparts = math.ceil(nparts / split_every) - widths.append(nparts) - height = len(widths) - for depth in range(1, height): - for group in range(widths[depth]): - p_max = widths[depth - 1] - lstart = split_every * group - lstop = min(lstart + split_every, p_max) - node_list = [ - (local_name, p, depth - 1) for p in range(lstart, lstop) - ] - dsk[(local_name, group, depth)] = (_aggregate_var, node_list) - if height == 1: - group = depth = 0 - dsk[(name, 0)] = (_finalize_var, (local_name, group, depth)) - - graph = HighLevelGraph.from_collections(name, dsk, dependencies=[num, ddf]) - result = dd.core.new_dd_object(graph, name, meta, (None, None)) - if isinstance(ddf, DataFrame): - result.divisions = (min(ddf.columns), max(ddf.columns)) - return handle_out(out, result) - - -@_dask_cudf_performance_tracking -def _extract_meta(x): - """ - Extract internal cache data (``_meta``) from dask_cudf objects - """ - if isinstance(x, (Scalar, _Frame)): - return x._meta - elif isinstance(x, list): - return [_extract_meta(_x) for _x in x] - elif isinstance(x, tuple): - return tuple(_extract_meta(_x) for _x in x) - elif isinstance(x, dict): - return {k: _extract_meta(v) for k, v in x.items()} - return x - - -@_dask_cudf_performance_tracking -def _emulate(func, *args, **kwargs): - """ - Apply a function using args / kwargs. If arguments contain dd.DataFrame / - dd.Series, using internal cache (``_meta``) for calculation - """ - with raise_on_meta_error(funcname(func)): - return func(*_extract_meta(args), **_extract_meta(kwargs)) - - -@_dask_cudf_performance_tracking -def align_partitions(args): - """Align partitions between dask_cudf objects. - - Note that if all divisions are unknown, but have equal npartitions, then - they will be passed through unchanged. - """ - dfs = [df for df in args if isinstance(df, _Frame)] - if not dfs: - return args - - divisions = dfs[0].divisions - if not all(df.divisions == divisions for df in dfs): - raise NotImplementedError("Aligning mismatched partitions") - return args - - -@_dask_cudf_performance_tracking -def reduction( - args, - chunk=None, - aggregate=None, - combine=None, - meta=None, - token=None, - chunk_kwargs=None, - aggregate_kwargs=None, - combine_kwargs=None, - split_every=None, - **kwargs, -): - """Generic tree reduction operation. - - Parameters - ---------- - args : - Positional arguments for the `chunk` function. All `dask.dataframe` - objects should be partitioned and indexed equivalently. - chunk : function [block-per-arg] -> block - Function to operate on each block of data - aggregate : function list-of-blocks -> block - Function to operate on the list of results of chunk - combine : function list-of-blocks -> block, optional - Function to operate on intermediate lists of results of chunk - in a tree-reduction. If not provided, defaults to aggregate. - $META - token : str, optional - The name to use for the output keys. - chunk_kwargs : dict, optional - Keywords for the chunk function only. - aggregate_kwargs : dict, optional - Keywords for the aggregate function only. - combine_kwargs : dict, optional - Keywords for the combine function only. - split_every : int, optional - Group partitions into groups of this size while performing a - tree-reduction. If set to False, no tree-reduction will be used, - and all intermediates will be concatenated and passed to ``aggregate``. - Default is 8. - kwargs : - All remaining keywords will be passed to ``chunk``, ``aggregate``, and - ``combine``. - """ - if chunk_kwargs is None: - chunk_kwargs = dict() - if aggregate_kwargs is None: - aggregate_kwargs = dict() - chunk_kwargs.update(kwargs) - aggregate_kwargs.update(kwargs) - - if combine is None: - if combine_kwargs: - raise ValueError("`combine_kwargs` provided with no `combine`") - combine = aggregate - combine_kwargs = aggregate_kwargs - else: - if combine_kwargs is None: - combine_kwargs = dict() - combine_kwargs.update(kwargs) - - if not isinstance(args, (tuple, list)): - args = [args] - - npartitions = {arg.npartitions for arg in args if isinstance(arg, _Frame)} - if len(npartitions) > 1: - raise ValueError("All arguments must have same number of partitions") - npartitions = npartitions.pop() - - if split_every is None: - split_every = 8 - elif split_every is False: - split_every = npartitions - elif split_every < 2 or not isinstance(split_every, int): - raise ValueError("split_every must be an integer >= 2") - - token_key = tokenize( - token or (chunk, aggregate), - meta, - args, - chunk_kwargs, - aggregate_kwargs, - combine_kwargs, - split_every, +# This module provides backward compatibility for legacy import patterns. +if dd.DASK_EXPR_ENABLED: + from dask_cudf._expr.collection import ( # noqa: E402 + DataFrame, + Index, + Series, ) +else: + from dask_cudf._legacy.core import DataFrame, Index, Series # noqa: F401 - # Chunk - a = f"{token or funcname(chunk)}-chunk-{token_key}" - if len(args) == 1 and isinstance(args[0], _Frame) and not chunk_kwargs: - dsk = { - (a, 0, i): (chunk, key) - for i, key in enumerate(args[0].__dask_keys__()) - } - else: - dsk = { - (a, 0, i): ( - apply, - chunk, - [(x._name, i) if isinstance(x, _Frame) else x for x in args], - chunk_kwargs, - ) - for i in range(args[0].npartitions) - } - # Combine - b = f"{token or funcname(combine)}-combine-{token_key}" - k = npartitions - depth = 0 - while k > split_every: - for part_i, inds in enumerate(partition_all(split_every, range(k))): - conc = (list, [(a, depth, i) for i in inds]) - dsk[(b, depth + 1, part_i)] = ( - (apply, combine, [conc], combine_kwargs) - if combine_kwargs - else (combine, conc) - ) - k = part_i + 1 - a = b - depth += 1 - - # Aggregate - b = f"{token or funcname(aggregate)}-agg-{token_key}" - conc = (list, [(a, depth, i) for i in range(k)]) - if aggregate_kwargs: - dsk[(b, 0)] = (apply, aggregate, [conc], aggregate_kwargs) - else: - dsk[(b, 0)] = (aggregate, conc) - - if meta is None: - meta_chunk = _emulate(apply, chunk, args, chunk_kwargs) - meta = _emulate(apply, aggregate, [[meta_chunk]], aggregate_kwargs) - meta = dask_make_meta(meta) - - graph = HighLevelGraph.from_collections(b, dsk, dependencies=args) - return dd.core.new_dd_object(graph, b, meta, (None, None)) +concat = dd.concat # noqa: F401 @_dask_cudf_performance_tracking @@ -744,59 +64,3 @@ def from_cudf(data, npartitions=None, chunksize=None, sort=True, name=None): # since dask-expr does not provide a docstring for from_pandas. + textwrap.dedent(dd.from_pandas.__doc__ or "") ) - - -@_dask_cudf_performance_tracking -def from_dask_dataframe(df): - """ - Convert a Dask :class:`dask.dataframe.DataFrame` to a Dask-cuDF - one. - - WARNING: This API is deprecated, and may not work properly - when query-planning is active. Please use `*.to_backend("cudf")` - to convert the underlying data to cudf. - - Parameters - ---------- - df : dask.dataframe.DataFrame - The Dask dataframe to convert - - Returns - ------- - dask_cudf.DataFrame : A new Dask collection backed by cuDF objects - """ - - warnings.warn( - "The `from_dask_dataframe` API is now deprecated. " - "Please use `*.to_backend('cudf')` instead.", - FutureWarning, - ) - - return df.to_backend("cudf") - - -for name in ( - "add", - "sub", - "mul", - "truediv", - "floordiv", - "mod", - "pow", - "radd", - "rsub", - "rmul", - "rtruediv", - "rfloordiv", - "rmod", - "rpow", -): - meth = getattr(cudf.DataFrame, name) - DataFrame._bind_operator_method(name, meth, original=cudf.Series) - - meth = getattr(cudf.Series, name) - Series._bind_operator_method(name, meth, original=cudf.Series) - -for name in ("lt", "gt", "le", "ge", "ne", "eq"): - meth = getattr(cudf.Series, name) - Series._bind_comparison_method(name, meth, original=cudf.Series) diff --git a/python/dask_cudf/dask_cudf/expr/__init__.py b/python/dask_cudf/dask_cudf/expr/__init__.py deleted file mode 100644 index 6dadadd5263..00000000000 --- a/python/dask_cudf/dask_cudf/expr/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. - -from dask import config - -# Check if dask-dataframe is using dask-expr. -# For dask>=2024.3.0, a null value will default to True -QUERY_PLANNING_ON = config.get("dataframe.query-planning", None) is not False - -# Register custom expressions and collections -if QUERY_PLANNING_ON: - # Broadly avoid "p2p" and "disk" defaults for now - config.set({"dataframe.shuffle.method": "tasks"}) - - try: - import dask_cudf.expr._collection # noqa: F401 - import dask_cudf.expr._expr # noqa: F401 - - except ImportError as err: - # Dask *should* raise an error before this. - # However, we can still raise here to be certain. - raise RuntimeError( - "Failed to register the 'cudf' backend for dask-expr." - " Please make sure you have dask-expr installed.\n" - f"Error Message: {err}" - ) diff --git a/python/dask_cudf/dask_cudf/expr/_expr.py b/python/dask_cudf/dask_cudf/expr/_expr.py deleted file mode 100644 index c7cf66fbffd..00000000000 --- a/python/dask_cudf/dask_cudf/expr/_expr.py +++ /dev/null @@ -1,511 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. -import functools - -import dask_expr._shuffle as _shuffle_module -import pandas as pd -from dask_expr import new_collection -from dask_expr._cumulative import CumulativeBlockwise -from dask_expr._expr import Elemwise, Expr, RenameAxis, VarColumns -from dask_expr._groupby import ( - DecomposableGroupbyAggregation, - GroupbyAggregation, -) -from dask_expr._reductions import Reduction, Var -from dask_expr.io.io import FusedParquetIO -from dask_expr.io.parquet import FragmentWrapper, ReadParquetPyarrowFS - -from dask.dataframe.core import ( - _concat, - is_dataframe_like, - make_meta, - meta_nonempty, -) -from dask.dataframe.dispatch import is_categorical_dtype -from dask.typing import no_default - -import cudf - -## -## Custom expressions -## - - -def _get_spec_info(gb): - if isinstance(gb.arg, (dict, list)): - aggs = gb.arg.copy() - else: - aggs = gb.arg - - if gb._slice and not isinstance(aggs, dict): - aggs = {gb._slice: aggs} - - gb_cols = gb._by_columns - if isinstance(gb_cols, str): - gb_cols = [gb_cols] - columns = [c for c in gb.frame.columns if c not in gb_cols] - if not isinstance(aggs, dict): - aggs = {col: aggs for col in columns} - - # Assert if our output will have a MultiIndex; this will be the case if - # any value in the `aggs` dict is not a string (i.e. multiple/named - # aggregations per column) - str_cols_out = True - aggs_renames = {} - for col in aggs: - if isinstance(aggs[col], str) or callable(aggs[col]): - aggs[col] = [aggs[col]] - elif isinstance(aggs[col], dict): - str_cols_out = False - col_aggs = [] - for k, v in aggs[col].items(): - aggs_renames[col, v] = k - col_aggs.append(v) - aggs[col] = col_aggs - else: - str_cols_out = False - if col in gb_cols: - columns.append(col) - - return { - "aggs": aggs, - "columns": columns, - "str_cols_out": str_cols_out, - "aggs_renames": aggs_renames, - } - - -def _get_meta(gb): - spec_info = gb.spec_info - gb_cols = gb._by_columns - aggs = spec_info["aggs"].copy() - aggs_renames = spec_info["aggs_renames"] - if spec_info["str_cols_out"]: - # Metadata should use `str` for dict values if that is - # what the user originally specified (column names will - # be str, rather than tuples). - for col in aggs: - aggs[col] = aggs[col][0] - _meta = gb.frame._meta.groupby(gb_cols).agg(aggs) - if aggs_renames: - col_array = [] - agg_array = [] - for col, agg in _meta.columns: - col_array.append(col) - agg_array.append(aggs_renames.get((col, agg), agg)) - _meta.columns = pd.MultiIndex.from_arrays([col_array, agg_array]) - return _meta - - -class DecomposableCudfGroupbyAgg(DecomposableGroupbyAggregation): - sep = "___" - - @functools.cached_property - def spec_info(self): - return _get_spec_info(self) - - @functools.cached_property - def _meta(self): - return _get_meta(self) - - @property - def shuffle_by_index(self): - return False # We always group by column(s) - - @classmethod - def chunk(cls, df, *by, **kwargs): - from dask_cudf.groupby import _groupby_partition_agg - - return _groupby_partition_agg(df, **kwargs) - - @classmethod - def combine(cls, inputs, **kwargs): - from dask_cudf.groupby import _tree_node_agg - - return _tree_node_agg(_concat(inputs), **kwargs) - - @classmethod - def aggregate(cls, inputs, **kwargs): - from dask_cudf.groupby import _finalize_gb_agg - - return _finalize_gb_agg(_concat(inputs), **kwargs) - - @property - def chunk_kwargs(self) -> dict: - dropna = True if self.dropna is None else self.dropna - return { - "gb_cols": self._by_columns, - "aggs": self.spec_info["aggs"], - "columns": self.spec_info["columns"], - "dropna": dropna, - "sort": self.sort, - "sep": self.sep, - } - - @property - def combine_kwargs(self) -> dict: - dropna = True if self.dropna is None else self.dropna - return { - "gb_cols": self._by_columns, - "dropna": dropna, - "sort": self.sort, - "sep": self.sep, - } - - @property - def aggregate_kwargs(self) -> dict: - dropna = True if self.dropna is None else self.dropna - final_columns = self._slice or self._meta.columns - return { - "gb_cols": self._by_columns, - "aggs": self.spec_info["aggs"], - "columns": self.spec_info["columns"], - "final_columns": final_columns, - "as_index": True, - "dropna": dropna, - "sort": self.sort, - "sep": self.sep, - "str_cols_out": self.spec_info["str_cols_out"], - "aggs_renames": self.spec_info["aggs_renames"], - } - - -class CudfGroupbyAgg(GroupbyAggregation): - @functools.cached_property - def spec_info(self): - return _get_spec_info(self) - - @functools.cached_property - def _meta(self): - return _get_meta(self) - - def _lower(self): - return DecomposableCudfGroupbyAgg( - self.frame, - self.arg, - self.observed, - self.dropna, - self.split_every, - self.split_out, - self.sort, - self.shuffle_method, - self._slice, - *self.by, - ) - - -def _maybe_get_custom_expr( - gb, - aggs, - split_every=None, - split_out=None, - shuffle_method=None, - **kwargs, -): - from dask_cudf.groupby import ( - OPTIMIZED_AGGS, - _aggs_optimized, - _redirect_aggs, - ) - - if kwargs: - # Unsupported key-word arguments - return None - - if not hasattr(gb.obj._meta, "to_pandas"): - # Not cuDF-backed data - return None - - _aggs = _redirect_aggs(aggs) - if not _aggs_optimized(_aggs, OPTIMIZED_AGGS): - # One or more aggregations are unsupported - return None - - return CudfGroupbyAgg( - gb.obj.expr, - _aggs, - gb.observed, - gb.dropna, - split_every, - split_out, - gb.sort, - shuffle_method, - gb._slice, - *gb.by, - ) - - -class CudfFusedParquetIO(FusedParquetIO): - @staticmethod - def _load_multiple_files( - frag_filters, - columns, - schema, - *to_pandas_args, - ): - import pyarrow as pa - - from dask.base import apply, tokenize - from dask.threaded import get - - token = tokenize(frag_filters, columns, schema) - name = f"pq-file-{token}" - dsk = { - (name, i): ( - CudfReadParquetPyarrowFS._fragment_to_table, - frag, - filter, - columns, - schema, - ) - for i, (frag, filter) in enumerate(frag_filters) - } - dsk[name] = ( - apply, - pa.concat_tables, - [list(dsk.keys())], - {"promote_options": "permissive"}, - ) - return CudfReadParquetPyarrowFS._table_to_pandas( - get(dsk, name), - *to_pandas_args, - ) - - -class CudfReadParquetPyarrowFS(ReadParquetPyarrowFS): - @functools.cached_property - def _dataset_info(self): - from dask_cudf.io.parquet import set_object_dtypes_from_pa_schema - - dataset_info = super()._dataset_info - meta_pd = dataset_info["base_meta"] - if isinstance(meta_pd, cudf.DataFrame): - return dataset_info - - # Convert to cudf - # (drop unsupported timezone information) - for k, v in meta_pd.dtypes.items(): - if isinstance(v, pd.DatetimeTZDtype) and v.tz is not None: - meta_pd[k] = meta_pd[k].dt.tz_localize(None) - meta_cudf = cudf.from_pandas(meta_pd) - - # Re-set "object" dtypes to align with pa schema - kwargs = dataset_info.get("kwargs", {}) - set_object_dtypes_from_pa_schema( - meta_cudf, - kwargs.get("schema", None), - ) - - dataset_info["base_meta"] = meta_cudf - self.operands[type(self)._parameters.index("_dataset_info_cache")] = ( - dataset_info - ) - return dataset_info - - @staticmethod - def _table_to_pandas(table, index_name): - df = cudf.DataFrame.from_arrow(table) - if index_name is not None: - df = df.set_index(index_name) - return df - - def _filtered_task(self, index: int): - columns = self.columns.copy() - index_name = self.index.name - if self.index is not None: - index_name = self.index.name - schema = self._dataset_info["schema"].remove_metadata() - if index_name: - if columns is None: - columns = list(schema.names) - columns.append(index_name) - return ( - self._table_to_pandas, - ( - self._fragment_to_table, - FragmentWrapper(self.fragments[index], filesystem=self.fs), - self.filters, - columns, - schema, - ), - index_name, - ) - - def _tune_up(self, parent): - if self._fusion_compression_factor >= 1: - return - if isinstance(parent, CudfFusedParquetIO): - return - return parent.substitute(self, CudfFusedParquetIO(self)) - - -class RenameAxisCudf(RenameAxis): - # TODO: Remove this after rename_axis is supported in cudf - # (See: https://github.com/rapidsai/cudf/issues/16895) - @staticmethod - def operation(df, index=no_default, **kwargs): - if index != no_default: - df.index.name = index - return df - raise NotImplementedError( - "Only `index` is supported for the cudf backend" - ) - - -class ToCudfBackend(Elemwise): - # TODO: Inherit from ToBackend when rapids-dask-dependency - # is pinned to dask>=2024.8.1 - _parameters = ["frame", "options"] - _projection_passthrough = True - _filter_passthrough = True - _preserves_partitioning_information = True - - @staticmethod - def operation(df, options): - from dask_cudf.backends import to_cudf_dispatch - - return to_cudf_dispatch(df, **options) - - def _simplify_down(self): - if isinstance( - self.frame._meta, (cudf.DataFrame, cudf.Series, cudf.Index) - ): - # We already have cudf data - return self.frame - - -## -## Custom expression patching -## - - -# This can be removed after cudf#15176 is addressed. -# See: https://github.com/rapidsai/cudf/issues/15176 -class PatchCumulativeBlockwise(CumulativeBlockwise): - @property - def _args(self) -> list: - return self.operands[:1] - - @property - def _kwargs(self) -> dict: - # Must pass axis and skipna as kwargs in cudf - return {"axis": self.axis, "skipna": self.skipna} - - -CumulativeBlockwise._args = PatchCumulativeBlockwise._args -CumulativeBlockwise._kwargs = PatchCumulativeBlockwise._kwargs - - -# The upstream Var code uses `Series.values`, and relies on numpy -# for most of the logic. Unfortunately, cudf -> cupy conversion -# is not supported for data containing null values. Therefore, -# we must implement our own version of Var for now. This logic -# is mostly copied from dask-cudf. - - -class VarCudf(Reduction): - # Uses the parallel version of Welford's online algorithm (Chan '79) - # (http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf) - _parameters = ["frame", "skipna", "ddof", "numeric_only", "split_every"] - _defaults = { - "skipna": True, - "ddof": 1, - "numeric_only": False, - "split_every": False, - } - - @functools.cached_property - def _meta(self): - return make_meta( - meta_nonempty(self.frame._meta).var( - skipna=self.skipna, numeric_only=self.numeric_only - ) - ) - - @property - def chunk_kwargs(self): - return dict(skipna=self.skipna, numeric_only=self.numeric_only) - - @property - def combine_kwargs(self): - return {} - - @property - def aggregate_kwargs(self): - return dict(ddof=self.ddof) - - @classmethod - def reduction_chunk(cls, x, skipna=True, numeric_only=False): - kwargs = {"numeric_only": numeric_only} if is_dataframe_like(x) else {} - if skipna or numeric_only: - n = x.count(**kwargs) - kwargs["skipna"] = skipna - avg = x.mean(**kwargs) - else: - # Not skipping nulls, so might as well - # avoid the full `count` operation - n = len(x) - kwargs["skipna"] = skipna - avg = x.sum(**kwargs) / n - if numeric_only: - # Workaround for cudf bug - # (see: https://github.com/rapidsai/cudf/issues/13731) - x = x[n.index] - m2 = ((x - avg) ** 2).sum(**kwargs) - return n, avg, m2 - - @classmethod - def reduction_combine(cls, parts): - n, avg, m2 = parts[0] - for i in range(1, len(parts)): - n_a, avg_a, m2_a = n, avg, m2 - n_b, avg_b, m2_b = parts[i] - n = n_a + n_b - avg = (n_a * avg_a + n_b * avg_b) / n - delta = avg_b - avg_a - m2 = m2_a + m2_b + delta**2 * n_a * n_b / n - return n, avg, m2 - - @classmethod - def reduction_aggregate(cls, vals, ddof=1): - vals = cls.reduction_combine(vals) - n, _, m2 = vals - return m2 / (n - ddof) - - -def _patched_var( - self, axis=0, skipna=True, ddof=1, numeric_only=False, split_every=False -): - if axis == 0: - if hasattr(self._meta, "to_pandas"): - return VarCudf(self, skipna, ddof, numeric_only, split_every) - else: - return Var(self, skipna, ddof, numeric_only, split_every) - elif axis == 1: - return VarColumns(self, skipna, ddof, numeric_only) - else: - raise ValueError(f"axis={axis} not supported. Please specify 0 or 1") - - -Expr.var = _patched_var - - -# Temporary work-around for missing cudf + categorical support -# See: https://github.com/rapidsai/cudf/issues/11795 -# TODO: Fix RepartitionQuantiles and remove this in cudf>24.06 - -_original_get_divisions = _shuffle_module._get_divisions - - -def _patched_get_divisions(frame, other, *args, **kwargs): - # NOTE: The following two lines contains the "patch" - # (we simply convert the partitioning column to pandas) - if is_categorical_dtype(other._meta.dtype) and hasattr( - other.frame._meta, "to_pandas" - ): - other = new_collection(other).to_backend("pandas")._expr - - # Call "original" function - return _original_get_divisions(frame, other, *args, **kwargs) - - -_shuffle_module._get_divisions = _patched_get_divisions diff --git a/python/dask_cudf/dask_cudf/expr/_groupby.py b/python/dask_cudf/dask_cudf/expr/_groupby.py deleted file mode 100644 index 8a16fe7615d..00000000000 --- a/python/dask_cudf/dask_cudf/expr/_groupby.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. - -from dask_expr._collection import new_collection -from dask_expr._groupby import ( - GroupBy as DXGroupBy, - SeriesGroupBy as DXSeriesGroupBy, - SingleAggregation, -) -from dask_expr._util import is_scalar - -from dask.dataframe.groupby import Aggregation - -from cudf.core.groupby.groupby import _deprecate_collect - -from dask_cudf.expr._expr import _maybe_get_custom_expr - -## -## Custom groupby classes -## - - -class ListAgg(SingleAggregation): - @staticmethod - def groupby_chunk(arg): - return arg.agg(list) - - @staticmethod - def groupby_aggregate(arg): - gb = arg.agg(list) - if gb.ndim > 1: - for col in gb.columns: - gb[col] = gb[col].list.concat() - return gb - else: - return gb.list.concat() - - -list_aggregation = Aggregation( - name="list", - chunk=ListAgg.groupby_chunk, - agg=ListAgg.groupby_aggregate, -) - - -def _translate_arg(arg): - # Helper function to translate args so that - # they can be processed correctly by upstream - # dask & dask-expr. Right now, the only necessary - # translation is list aggregations. - if isinstance(arg, dict): - return {k: _translate_arg(v) for k, v in arg.items()} - elif isinstance(arg, list): - return [_translate_arg(x) for x in arg] - elif arg in ("collect", "list", list): - return list_aggregation - else: - return arg - - -# We define our own GroupBy classes in Dask cuDF for -# the following reasons: -# (1) We want to use a custom `aggregate` algorithm -# that performs multiple aggregations on the -# same dataframe partition at once. The upstream -# algorithm breaks distinct aggregations into -# separate tasks. -# (2) We need to work around missing `observed=False` -# support: -# https://github.com/rapidsai/cudf/issues/15173 - - -class GroupBy(DXGroupBy): - def __init__(self, *args, observed=None, **kwargs): - observed = observed if observed is not None else True - super().__init__(*args, observed=observed, **kwargs) - - def __getitem__(self, key): - if is_scalar(key): - return SeriesGroupBy( - self.obj, - by=self.by, - slice=key, - sort=self.sort, - dropna=self.dropna, - observed=self.observed, - ) - g = GroupBy( - self.obj, - by=self.by, - slice=key, - sort=self.sort, - dropna=self.dropna, - observed=self.observed, - group_keys=self.group_keys, - ) - return g - - def collect(self, **kwargs): - _deprecate_collect() - return self._single_agg(ListAgg, **kwargs) - - def aggregate(self, arg, fused=True, **kwargs): - if ( - fused - and (expr := _maybe_get_custom_expr(self, arg, **kwargs)) - is not None - ): - return new_collection(expr) - else: - return super().aggregate(_translate_arg(arg), **kwargs) - - -class SeriesGroupBy(DXSeriesGroupBy): - def __init__(self, *args, observed=None, **kwargs): - observed = observed if observed is not None else True - super().__init__(*args, observed=observed, **kwargs) - - def collect(self, **kwargs): - _deprecate_collect() - return self._single_agg(ListAgg, **kwargs) - - def aggregate(self, arg, **kwargs): - return super().aggregate(_translate_arg(arg), **kwargs) diff --git a/python/dask_cudf/dask_cudf/io/__init__.py b/python/dask_cudf/dask_cudf/io/__init__.py index 0421bd755f4..1e0f24d78ce 100644 --- a/python/dask_cudf/dask_cudf/io/__init__.py +++ b/python/dask_cudf/dask_cudf/io/__init__.py @@ -1,11 +1,32 @@ -# Copyright (c) 2018-2024, NVIDIA CORPORATION. +# Copyright (c) 2024, NVIDIA CORPORATION. -from .csv import read_csv # noqa: F401 -from .json import read_json # noqa: F401 -from .orc import read_orc, to_orc # noqa: F401 -from .text import read_text # noqa: F401 +from dask_cudf import _deprecated_api -try: - from .parquet import read_parquet, to_parquet # noqa: F401 -except ImportError: - pass +from . import csv, orc, json, parquet, text # noqa: F401 + + +read_csv = _deprecated_api( + "dask_cudf.io.read_csv", new_api="dask_cudf.read_csv" +) +read_json = _deprecated_api( + "dask_cudf.io.read_json", new_api="dask_cudf.read_json" +) +read_orc = _deprecated_api( + "dask_cudf.io.read_orc", new_api="dask_cudf.read_orc" +) +to_orc = _deprecated_api( + "dask_cudf.io.to_orc", + new_api="dask_cudf._legacy.io.to_orc", + rec="Please use the DataFrame.to_orc method instead.", +) +read_text = _deprecated_api( + "dask_cudf.io.read_text", new_api="dask_cudf.read_text" +) +read_parquet = _deprecated_api( + "dask_cudf.io.read_parquet", new_api="dask_cudf.read_parquet" +) +to_parquet = _deprecated_api( + "dask_cudf.io.to_parquet", + new_api="dask_cudf._legacy.io.parquet.to_parquet", + rec="Please use the DataFrame.to_parquet method instead.", +) diff --git a/python/dask_cudf/dask_cudf/io/csv.py b/python/dask_cudf/dask_cudf/io/csv.py index fa5400344f9..b22b31a591f 100644 --- a/python/dask_cudf/dask_cudf/io/csv.py +++ b/python/dask_cudf/dask_cudf/io/csv.py @@ -1,222 +1,8 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2024, NVIDIA CORPORATION. -import os -from glob import glob -from warnings import warn +from dask_cudf import _deprecated_api -from fsspec.utils import infer_compression - -from dask import dataframe as dd -from dask.base import tokenize -from dask.dataframe.io.csv import make_reader -from dask.utils import apply, parse_bytes - -import cudf - - -def read_csv(path, blocksize="default", **kwargs): - """ - Read CSV files into a :class:`.DataFrame`. - - This API parallelizes the :func:`cudf:cudf.read_csv` function in - the following ways: - - It supports loading many files at once using globstrings: - - >>> import dask_cudf - >>> df = dask_cudf.read_csv("myfiles.*.csv") - - In some cases it can break up large files: - - >>> df = dask_cudf.read_csv("largefile.csv", blocksize="256 MiB") - - It can read CSV files from external resources (e.g. S3, HTTP, FTP) - - >>> df = dask_cudf.read_csv("s3://bucket/myfiles.*.csv") - >>> df = dask_cudf.read_csv("https://www.mycloud.com/sample.csv") - - Internally ``read_csv`` uses :func:`cudf:cudf.read_csv` and - supports many of the same keyword arguments with the same - performance guarantees. See the docstring for - :func:`cudf:cudf.read_csv` for more information on available - keyword arguments. - - Parameters - ---------- - path : str, path object, or file-like object - Either a path to a file (a str, :py:class:`pathlib.Path`, or - py._path.local.LocalPath), URL (including http, ftp, and S3 - locations), or any object with a read() method (such as - builtin :py:func:`open` file handler function or - :py:class:`~io.StringIO`). - blocksize : int or str, default "256 MiB" - The target task partition size. If ``None``, a single block - is used for each file. - **kwargs : dict - Passthrough key-word arguments that are sent to - :func:`cudf:cudf.read_csv`. - - Notes - ----- - If any of `skipfooter`/`skiprows`/`nrows` are passed, - `blocksize` will default to None. - - Examples - -------- - >>> import dask_cudf - >>> ddf = dask_cudf.read_csv("sample.csv", usecols=["a", "b"]) - >>> ddf.compute() - a b - 0 1 hi - 1 2 hello - 2 3 ai - - """ - - # Handle `chunksize` deprecation - if "chunksize" in kwargs: - chunksize = kwargs.pop("chunksize", "default") - warn( - "`chunksize` is deprecated and will be removed in the future. " - "Please use `blocksize` instead.", - FutureWarning, - ) - if blocksize == "default": - blocksize = chunksize - - # Set default `blocksize` - if blocksize == "default": - if ( - kwargs.get("skipfooter", 0) != 0 - or kwargs.get("skiprows", 0) != 0 - or kwargs.get("nrows", None) is not None - ): - # Cannot read in blocks if skipfooter, - # skiprows or nrows is passed. - blocksize = None - else: - blocksize = "256 MiB" - - if "://" in str(path): - func = make_reader(cudf.read_csv, "read_csv", "CSV") - return func(path, blocksize=blocksize, **kwargs) - else: - return _internal_read_csv(path=path, blocksize=blocksize, **kwargs) - - -def _internal_read_csv(path, blocksize="256 MiB", **kwargs): - if isinstance(blocksize, str): - blocksize = parse_bytes(blocksize) - - if isinstance(path, list): - filenames = path - elif isinstance(path, str): - filenames = sorted(glob(path)) - elif hasattr(path, "__fspath__"): - filenames = sorted(glob(path.__fspath__())) - else: - raise TypeError(f"Path type not understood:{type(path)}") - - if not filenames: - msg = f"A file in: {filenames} does not exist." - raise FileNotFoundError(msg) - - name = "read-csv-" + tokenize( - path, tokenize, **kwargs - ) # TODO: get last modified time - - compression = kwargs.get("compression", "infer") - - if compression == "infer": - # Infer compression from first path by default - compression = infer_compression(filenames[0]) - - if compression and blocksize: - # compressed CSVs reading must read the entire file - kwargs.pop("byte_range", None) - warn( - "Warning %s compression does not support breaking apart files\n" - "Please ensure that each individual file can fit in memory and\n" - "use the keyword ``blocksize=None to remove this message``\n" - "Setting ``blocksize=(size of file)``" % compression - ) - blocksize = None - - if blocksize is None: - return read_csv_without_blocksize(path, **kwargs) - - # Let dask.dataframe generate meta - dask_reader = make_reader(cudf.read_csv, "read_csv", "CSV") - kwargs1 = kwargs.copy() - usecols = kwargs1.pop("usecols", None) - dtype = kwargs1.pop("dtype", None) - meta = dask_reader(filenames[0], **kwargs1)._meta - names = meta.columns - if usecols or dtype: - # Regenerate meta with original kwargs if - # `usecols` or `dtype` was specified - meta = dask_reader(filenames[0], **kwargs)._meta - - dsk = {} - i = 0 - dtypes = meta.dtypes.values - - for fn in filenames: - size = os.path.getsize(fn) - for start in range(0, size, blocksize): - kwargs2 = kwargs.copy() - kwargs2["byte_range"] = ( - start, - blocksize, - ) # specify which chunk of the file we care about - if start != 0: - kwargs2["names"] = names # no header in the middle of the file - kwargs2["header"] = None - dsk[(name, i)] = (apply, _read_csv, [fn, dtypes], kwargs2) - - i += 1 - - divisions = [None] * (len(dsk) + 1) - return dd.core.new_dd_object(dsk, name, meta, divisions) - - -def _read_csv(fn, dtypes=None, **kwargs): - return cudf.read_csv(fn, **kwargs) - - -def read_csv_without_blocksize(path, **kwargs): - """Read entire CSV with optional compression (gzip/zip) - - Parameters - ---------- - path : str - path to files (support for glob) - """ - if isinstance(path, list): - filenames = path - elif isinstance(path, str): - filenames = sorted(glob(path)) - elif hasattr(path, "__fspath__"): - filenames = sorted(glob(path.__fspath__())) - else: - raise TypeError(f"Path type not understood:{type(path)}") - - name = "read-csv-" + tokenize(path, **kwargs) - - meta_kwargs = kwargs.copy() - if "skipfooter" in meta_kwargs: - meta_kwargs.pop("skipfooter") - if "nrows" in meta_kwargs: - meta_kwargs.pop("nrows") - # Read "head" of first file (first 5 rows). - # Convert to empty df for metadata. - meta = cudf.read_csv(filenames[0], nrows=5, **meta_kwargs).iloc[:0] - - graph = { - (name, i): (apply, cudf.read_csv, [fn], kwargs) - for i, fn in enumerate(filenames) - } - - divisions = [None] * (len(filenames) + 1) - - return dd.core.new_dd_object(graph, name, meta, divisions) +read_csv = _deprecated_api( + "dask_cudf.io.csv.read_csv", + new_api="dask_cudf.read_csv", +) diff --git a/python/dask_cudf/dask_cudf/io/json.py b/python/dask_cudf/dask_cudf/io/json.py index 98c5ceedb76..8f85ea54c0a 100644 --- a/python/dask_cudf/dask_cudf/io/json.py +++ b/python/dask_cudf/dask_cudf/io/json.py @@ -1,209 +1,8 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2024, NVIDIA CORPORATION. -from functools import partial +from dask_cudf import _deprecated_api -import numpy as np -from fsspec.core import get_compression, get_fs_token_paths - -import dask -from dask.utils import parse_bytes - -import cudf -from cudf.core.column import as_column -from cudf.utils.ioutils import _is_local_filesystem - -from dask_cudf.backends import _default_backend - - -def _read_json_partition( - paths, - fs=None, - include_path_column=False, - path_converter=None, - **kwargs, -): - # Transfer all data up front for remote storage - sources = ( - paths - if fs is None - else fs.cat_ranges( - paths, - [0] * len(paths), - fs.sizes(paths), - ) - ) - - if include_path_column: - # Add "path" column. - # Must iterate over sources sequentially - if not isinstance(include_path_column, str): - include_path_column = "path" - converted_paths = ( - paths - if path_converter is None - else [path_converter(path) for path in paths] - ) - dfs = [] - for i, source in enumerate(sources): - df = cudf.read_json(source, **kwargs) - df[include_path_column] = as_column( - converted_paths[i], length=len(df) - ) - dfs.append(df) - return cudf.concat(dfs) - else: - # Pass sources directly to cudf - return cudf.read_json(sources, **kwargs) - - -def read_json( - url_path, - engine="auto", - blocksize=None, - orient="records", - lines=None, - compression="infer", - aggregate_files=True, - **kwargs, -): - """Read JSON data into a :class:`.DataFrame`. - - This function wraps :func:`dask.dataframe.read_json`, and passes - ``engine=partial(cudf.read_json, engine="auto")`` by default. - - Parameters - ---------- - url_path : str, list of str - Location to read from. If a string, can include a glob character to - find a set of file names. - Supports protocol specifications such as ``"s3://"``. - engine : str or Callable, default "auto" - - If str, this value will be used as the ``engine`` argument - when :func:`cudf.read_json` is used to create each partition. - If a :obj:`~collections.abc.Callable`, this value will be used as the - underlying function used to create each partition from JSON - data. The default value is "auto", so that - ``engine=partial(cudf.read_json, engine="auto")`` will be - passed to :func:`dask.dataframe.read_json` by default. - aggregate_files : bool or int - Whether to map multiple files to each output partition. If True, - the `blocksize` argument will be used to determine the number of - files in each partition. If any one file is larger than `blocksize`, - the `aggregate_files` argument will be ignored. If an integer value - is specified, the `blocksize` argument will be ignored, and that - number of files will be mapped to each partition. Default is True. - **kwargs : - Key-word arguments to pass through to :func:`dask.dataframe.read_json`. - - Returns - ------- - :class:`.DataFrame` - - Examples - -------- - Load single file - - >>> from dask_cudf import read_json - >>> read_json('myfile.json') # doctest: +SKIP - - Load large line-delimited JSON files using partitions of approx - 256MB size - - >>> read_json('data/file*.csv', blocksize=2**28) # doctest: +SKIP - - Load nested JSON data - - >>> read_json('myfile.json') # doctest: +SKIP - - See Also - -------- - dask.dataframe.read_json - - """ - - if lines is None: - lines = orient == "records" - if orient != "records" and lines: - raise ValueError( - 'Line-delimited JSON is only available with orient="records".' - ) - if blocksize and (orient != "records" or not lines): - raise ValueError( - "JSON file chunking only allowed for JSON-lines" - "input (orient='records', lines=True)." - ) - - inputs = [] - if aggregate_files and blocksize or int(aggregate_files) > 1: - # Attempt custom read if we are mapping multiple files - # to each output partition. Otherwise, upstream logic - # is sufficient. - - storage_options = kwargs.get("storage_options", {}) - fs, _, paths = get_fs_token_paths( - url_path, mode="rb", storage_options=storage_options - ) - if isinstance(aggregate_files, int) and aggregate_files > 1: - # Map a static file count to each partition - inputs = [ - paths[offset : offset + aggregate_files] - for offset in range(0, len(paths), aggregate_files) - ] - elif aggregate_files is True and blocksize: - # Map files dynamically (using blocksize) - file_sizes = fs.sizes(paths) # NOTE: This can be slow - blocksize = parse_bytes(blocksize) - if all([file_size <= blocksize for file_size in file_sizes]): - counts = np.unique( - np.floor(np.cumsum(file_sizes) / blocksize), - return_counts=True, - )[1] - offsets = np.concatenate([[0], counts.cumsum()]) - inputs = [ - paths[offsets[i] : offsets[i + 1]] - for i in range(len(offsets) - 1) - ] - - if inputs: - # Inputs were successfully populated. - # Use custom _read_json_partition function - # to generate each partition. - - compression = get_compression( - url_path[0] if isinstance(url_path, list) else url_path, - compression, - ) - _kwargs = dict( - orient=orient, - lines=lines, - compression=compression, - include_path_column=kwargs.get("include_path_column", False), - path_converter=kwargs.get("path_converter"), - ) - if not _is_local_filesystem(fs): - _kwargs["fs"] = fs - # TODO: Generate meta more efficiently - meta = _read_json_partition(inputs[0][:1], **_kwargs) - return dask.dataframe.from_map( - _read_json_partition, - inputs, - meta=meta, - **_kwargs, - ) - - # Fall back to dask.dataframe.read_json - return _default_backend( - dask.dataframe.read_json, - url_path, - engine=( - partial(cudf.read_json, engine=engine) - if isinstance(engine, str) - else engine - ), - blocksize=blocksize, - orient=orient, - lines=lines, - compression=compression, - **kwargs, - ) +read_json = _deprecated_api( + "dask_cudf.io.json.read_json", + new_api="dask_cudf.read_json", +) diff --git a/python/dask_cudf/dask_cudf/io/orc.py b/python/dask_cudf/dask_cudf/io/orc.py index bed69f038b0..5219cdacc31 100644 --- a/python/dask_cudf/dask_cudf/io/orc.py +++ b/python/dask_cudf/dask_cudf/io/orc.py @@ -1,199 +1,13 @@ -# Copyright (c) 2020-2024, NVIDIA CORPORATION. - -from io import BufferedWriter, IOBase - -from fsspec.core import get_fs_token_paths -from fsspec.utils import stringify_path -from pyarrow import orc as orc - -from dask import dataframe as dd -from dask.base import tokenize -from dask.dataframe.io.utils import _get_pyarrow_dtypes - -import cudf - - -def _read_orc_stripe(fs, path, stripe, columns, kwargs=None): - """Pull out specific columns from specific stripe""" - if kwargs is None: - kwargs = {} - with fs.open(path, "rb") as f: - df_stripe = cudf.read_orc( - f, stripes=[stripe], columns=columns, **kwargs - ) - return df_stripe - - -def read_orc(path, columns=None, filters=None, storage_options=None, **kwargs): - """Read ORC files into a :class:`.DataFrame`. - - Note that this function is mostly borrowed from upstream Dask. - - Parameters - ---------- - path : str or list[str] - Location of file(s), which can be a full URL with protocol specifier, - and may include glob character if a single string. - columns : None or list[str] - Columns to load. If None, loads all. - filters : None or list of tuple or list of lists of tuples - If not None, specifies a filter predicate used to filter out - row groups using statistics stored for each row group as - Parquet metadata. Row groups that do not match the given - filter predicate are not read. The predicate is expressed in - `disjunctive normal form (DNF) - `__ - like ``[[('x', '=', 0), ...], ...]``. DNF allows arbitrary - boolean logical combinations of single column predicates. The - innermost tuples each describe a single column predicate. The - list of inner predicates is interpreted as a conjunction - (AND), forming a more selective and multiple column predicate. - Finally, the outermost list combines these filters as a - disjunction (OR). Predicates may also be passed as a list of - tuples. This form is interpreted as a single conjunction. To - express OR in predicates, one must use the (preferred) - notation of list of lists of tuples. - storage_options : None or dict - Further parameters to pass to the bytes backend. - - See Also - -------- - dask.dataframe.read_orc - - Returns - ------- - dask_cudf.DataFrame - - """ - - storage_options = storage_options or {} - fs, fs_token, paths = get_fs_token_paths( - path, mode="rb", storage_options=storage_options - ) - schema = None - nstripes_per_file = [] - for path in paths: - with fs.open(path, "rb") as f: - o = orc.ORCFile(f) - if schema is None: - schema = o.schema - elif schema != o.schema: - raise ValueError( - "Incompatible schemas while parsing ORC files" - ) - nstripes_per_file.append(o.nstripes) - schema = _get_pyarrow_dtypes(schema, categories=None) - if columns is not None: - ex = set(columns) - set(schema) - if ex: - raise ValueError( - f"Requested columns ({ex}) not in schema ({set(schema)})" - ) - else: - columns = list(schema) - - with fs.open(paths[0], "rb") as f: - meta = cudf.read_orc( - f, - stripes=[0] if nstripes_per_file[0] else None, - columns=columns, - **kwargs, - ) - - name = "read-orc-" + tokenize(fs_token, path, columns, filters, **kwargs) - dsk = {} - N = 0 - for path, n in zip(paths, nstripes_per_file): - for stripe in ( - range(n) - if filters is None - else cudf.io.orc._filter_stripes(filters, path) - ): - dsk[(name, N)] = ( - _read_orc_stripe, - fs, - path, - stripe, - columns, - kwargs, - ) - N += 1 - - divisions = [None] * (len(dsk) + 1) - return dd.core.new_dd_object(dsk, name, meta, divisions) - - -def write_orc_partition(df, path, fs, filename, compression="snappy"): - full_path = fs.sep.join([path, filename]) - with fs.open(full_path, mode="wb") as out_file: - if not isinstance(out_file, IOBase): - out_file = BufferedWriter(out_file) - cudf.io.to_orc(df, out_file, compression=compression) - return full_path - - -def to_orc( - df, - path, - write_index=True, - storage_options=None, - compression="snappy", - compute=True, - **kwargs, -): - """ - Write a :class:`.DataFrame` to ORC file(s) (one file per partition). - - Parameters - ---------- - df : DataFrame - path : str or pathlib.Path - Destination directory for data. Prepend with protocol like ``s3://`` - or ``hdfs://`` for remote data. - write_index : boolean, optional - Whether or not to write the index. Defaults to True. - storage_options : None or dict - Further parameters to pass to the bytes backend. - compression : string or dict, optional - compute : bool, optional - If True (default) then the result is computed immediately. If - False then a :class:`~dask.delayed.Delayed` object is returned - for future computation. - - """ - - from dask import compute as dask_compute, delayed - - # TODO: Use upstream dask implementation once available - # (see: Dask Issue#5596) - - if hasattr(path, "name"): - path = stringify_path(path) - fs, _, _ = get_fs_token_paths( - path, mode="wb", storage_options=storage_options - ) - # Trim any protocol information from the path before forwarding - path = fs._strip_protocol(path) - - if write_index: - df = df.reset_index() - else: - # Not writing index - might as well drop it - df = df.reset_index(drop=True) - - fs.mkdirs(path, exist_ok=True) - - # Use i_offset and df.npartitions to define file-name list - filenames = ["part.%i.orc" % i for i in range(df.npartitions)] - - # write parts - dwrite = delayed(write_orc_partition) - parts = [ - dwrite(d, path, fs, filename, compression=compression) - for d, filename in zip(df.to_delayed(), filenames) - ] - - if compute: - return dask_compute(*parts) - - return delayed(list)(parts) +# Copyright (c) 2024, NVIDIA CORPORATION. + +from dask_cudf import _deprecated_api + +read_orc = _deprecated_api( + "dask_cudf.io.orc.read_orc", + new_api="dask_cudf.read_orc", +) +to_orc = _deprecated_api( + "dask_cudf.io.orc.to_orc", + new_api="dask_cudf._legacy.io.orc.to_orc", + rec="Please use the DataFrame.to_orc method instead.", +) diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py index 39ac6474958..48cea7266af 100644 --- a/python/dask_cudf/dask_cudf/io/parquet.py +++ b/python/dask_cudf/dask_cudf/io/parquet.py @@ -1,35 +1,66 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION. -import itertools -import warnings -from functools import partial -from io import BufferedWriter, BytesIO, IOBase +# Copyright (c) 2024, NVIDIA CORPORATION. +import functools -import numpy as np import pandas as pd -from pyarrow import dataset as pa_ds, parquet as pq +from dask_expr.io.io import FusedParquetIO +from dask_expr.io.parquet import FragmentWrapper, ReadParquetPyarrowFS -from dask import dataframe as dd -from dask.dataframe.io.parquet.arrow import ArrowDatasetEngine +import cudf -try: - from dask.dataframe.io.parquet import ( - create_metadata_file as create_metadata_file_dd, - ) -except ImportError: - create_metadata_file_dd = None +from dask_cudf import _deprecated_api + +# Dask-expr imports CudfEngine from this module +from dask_cudf._legacy.io.parquet import CudfEngine # noqa: F401 + + +class CudfFusedParquetIO(FusedParquetIO): + @staticmethod + def _load_multiple_files( + frag_filters, + columns, + schema, + *to_pandas_args, + ): + import pyarrow as pa + + from dask.base import apply, tokenize + from dask.threaded import get + + token = tokenize(frag_filters, columns, schema) + name = f"pq-file-{token}" + dsk = { + (name, i): ( + CudfReadParquetPyarrowFS._fragment_to_table, + frag, + filter, + columns, + schema, + ) + for i, (frag, filter) in enumerate(frag_filters) + } + dsk[name] = ( + apply, + pa.concat_tables, + [list(dsk.keys())], + {"promote_options": "permissive"}, + ) + return CudfReadParquetPyarrowFS._table_to_pandas( + get(dsk, name), + *to_pandas_args, + ) -import cudf -from cudf.core.column import CategoricalColumn, as_column -from cudf.io import write_to_dataset -from cudf.io.parquet import _apply_post_filters, _normalize_filters -from cudf.utils.dtypes import cudf_dtype_from_pa_type +class CudfReadParquetPyarrowFS(ReadParquetPyarrowFS): + @functools.cached_property + def _dataset_info(self): + from dask_cudf._legacy.io.parquet import ( + set_object_dtypes_from_pa_schema, + ) -class CudfEngine(ArrowDatasetEngine): - @classmethod - def _create_dd_meta(cls, dataset_info, **kwargs): - # Start with pandas-version of meta - meta_pd = super()._create_dd_meta(dataset_info, **kwargs) + dataset_info = super()._dataset_info + meta_pd = dataset_info["base_meta"] + if isinstance(meta_pd, cudf.DataFrame): + return dataset_info # Convert to cudf # (drop unsupported timezone information) @@ -45,469 +76,60 @@ def _create_dd_meta(cls, dataset_info, **kwargs): kwargs.get("schema", None), ) - return meta_cudf - - @classmethod - def multi_support(cls): - # Assert that this class is CudfEngine - # and that multi-part reading is supported - return cls == CudfEngine - - @classmethod - def _read_paths( - cls, - paths, - fs, - columns=None, - row_groups=None, - filters=None, - partitions=None, - partitioning=None, - partition_keys=None, - open_file_options=None, - dataset_kwargs=None, - **kwargs, - ): - # Simplify row_groups if all None - if row_groups == [None for path in paths]: - row_groups = None - - # Make sure we read in the columns needed for row-wise - # filtering after IO. This means that one or more columns - # will be dropped almost immediately after IO. However, - # we do NEED these columns for accurate filtering. - filters = _normalize_filters(filters) - projected_columns = None - if columns and filters: - projected_columns = [c for c in columns if c is not None] - columns = sorted( - set(v[0] for v in itertools.chain.from_iterable(filters)) - | set(projected_columns) - ) - - dataset_kwargs = dataset_kwargs or {} - dataset_kwargs["partitioning"] = partitioning or "hive" - - # Use cudf to read in data - try: - df = cudf.read_parquet( - paths, - engine="cudf", - columns=columns, - row_groups=row_groups if row_groups else None, - dataset_kwargs=dataset_kwargs, - categorical_partitions=False, - filesystem=fs, - **kwargs, - ) - except RuntimeError as err: - # TODO: Remove try/except after null-schema issue is resolved - # (See: https://github.com/rapidsai/cudf/issues/12702) - if len(paths) > 1: - df = cudf.concat( - [ - cudf.read_parquet( - path, - engine="cudf", - columns=columns, - row_groups=row_groups[i] if row_groups else None, - dataset_kwargs=dataset_kwargs, - categorical_partitions=False, - filesystem=fs, - **kwargs, - ) - for i, path in enumerate(paths) - ] - ) - else: - raise err - - # Apply filters (if any are defined) - df = _apply_post_filters(df, filters) - - if projected_columns: - # Elements of `projected_columns` may now be in the index. - # We must filter these names from our projection - projected_columns = [ - col for col in projected_columns if col in df._column_names - ] - df = df[projected_columns] - - if partitions and partition_keys is None: - # Use `HivePartitioning` by default - ds = pa_ds.dataset( - paths, - filesystem=fs, - **dataset_kwargs, - ) - frag = next(ds.get_fragments()) - if frag: - # Extract hive-partition keys, and make sure they - # are ordered the same as they are in `partitions` - raw_keys = pa_ds._get_partition_keys(frag.partition_expression) - partition_keys = [ - (hive_part.name, raw_keys[hive_part.name]) - for hive_part in partitions - ] - - if partition_keys: - if partitions is None: - raise ValueError("Must pass partition sets") - - for i, (name, index2) in enumerate(partition_keys): - if len(partitions[i].keys): - # Build a categorical column from `codes` directly - # (since the category is often a larger dtype) - codes = as_column( - partitions[i].keys.get_loc(index2), - length=len(df), - ) - df[name] = CategoricalColumn( - data=None, - size=codes.size, - dtype=cudf.CategoricalDtype( - categories=partitions[i].keys, ordered=False - ), - offset=codes.offset, - children=(codes,), - ) - elif name not in df.columns: - # Add non-categorical partition column - df[name] = as_column(index2, length=len(df)) - - return df - - @classmethod - def read_partition( - cls, - fs, - pieces, - columns, - index, - categories=(), - partitions=(), - filters=None, - partitioning=None, - schema=None, - open_file_options=None, - **kwargs, - ): - if columns is not None: - columns = [c for c in columns] - if isinstance(index, list): - columns += index - - dataset_kwargs = kwargs.get("dataset", {}) - partitioning = partitioning or dataset_kwargs.get("partitioning", None) - if isinstance(partitioning, dict): - partitioning = pa_ds.partitioning(**partitioning) - - # Check if we are actually selecting any columns - read_columns = columns - if schema and columns: - ignored = set(schema.names) - set(columns) - if not ignored: - read_columns = None - - if not isinstance(pieces, list): - pieces = [pieces] - - # Extract supported kwargs from `kwargs` - read_kwargs = kwargs.get("read", {}) - read_kwargs.update(open_file_options or {}) - check_file_size = read_kwargs.pop("check_file_size", None) - - # Wrap reading logic in a `try` block so that we can - # inform the user that the `read_parquet` partition - # size is too large for the available memory - try: - # Assume multi-piece read - paths = [] - rgs = [] - last_partition_keys = None - dfs = [] - - for i, piece in enumerate(pieces): - (path, row_group, partition_keys) = piece - row_group = None if row_group == [None] else row_group - - # File-size check to help "protect" users from change - # to up-stream `split_row_groups` default. We only - # check the file size if this partition corresponds - # to a full file, and `check_file_size` is defined - if check_file_size and len(pieces) == 1 and row_group is None: - file_size = fs.size(path) - if file_size > check_file_size: - warnings.warn( - f"A large parquet file ({file_size}B) is being " - f"used to create a DataFrame partition in " - f"read_parquet. This may cause out of memory " - f"exceptions in operations downstream. See the " - f"notes on split_row_groups in the read_parquet " - f"documentation. Setting split_row_groups " - f"explicitly will silence this warning." - ) - - if i > 0 and partition_keys != last_partition_keys: - dfs.append( - cls._read_paths( - paths, - fs, - columns=read_columns, - row_groups=rgs if rgs else None, - filters=filters, - partitions=partitions, - partitioning=partitioning, - partition_keys=last_partition_keys, - dataset_kwargs=dataset_kwargs, - **read_kwargs, - ) - ) - paths = [] - rgs = [] - last_partition_keys = None - paths.append(path) - rgs.append( - [row_group] - if not isinstance(row_group, list) - and row_group is not None - else row_group - ) - last_partition_keys = partition_keys - - dfs.append( - cls._read_paths( - paths, - fs, - columns=read_columns, - row_groups=rgs if rgs else None, - filters=filters, - partitions=partitions, - partitioning=partitioning, - partition_keys=last_partition_keys, - dataset_kwargs=dataset_kwargs, - **read_kwargs, - ) - ) - df = cudf.concat(dfs) if len(dfs) > 1 else dfs[0] - - # Re-set "object" dtypes align with pa schema - set_object_dtypes_from_pa_schema(df, schema) - - if index and (index[0] in df.columns): - df = df.set_index(index[0]) - elif index is False and df.index.names != [None]: - # If index=False, we shouldn't have a named index - df.reset_index(inplace=True) - - except MemoryError as err: - raise MemoryError( - "Parquet data was larger than the available GPU memory!\n\n" - "See the notes on split_row_groups in the read_parquet " - "documentation.\n\n" - "Original Error: " + str(err) - ) - raise err - - return df - - @staticmethod - def write_partition( - df, - path, - fs, - filename, - partition_on, - return_metadata, - fmd=None, - compression="snappy", - index_cols=None, - **kwargs, - ): - preserve_index = False - if len(index_cols) and set(index_cols).issubset(set(df.columns)): - df.set_index(index_cols, drop=True, inplace=True) - preserve_index = True - if partition_on: - md = write_to_dataset( - df=df, - root_path=path, - compression=compression, - filename=filename, - partition_cols=partition_on, - fs=fs, - preserve_index=preserve_index, - return_metadata=return_metadata, - statistics=kwargs.get("statistics", "ROWGROUP"), - int96_timestamps=kwargs.get("int96_timestamps", False), - row_group_size_bytes=kwargs.get("row_group_size_bytes", None), - row_group_size_rows=kwargs.get("row_group_size_rows", None), - max_page_size_bytes=kwargs.get("max_page_size_bytes", None), - max_page_size_rows=kwargs.get("max_page_size_rows", None), - storage_options=kwargs.get("storage_options", None), - ) - else: - with fs.open(fs.sep.join([path, filename]), mode="wb") as out_file: - if not isinstance(out_file, IOBase): - out_file = BufferedWriter(out_file) - md = df.to_parquet( - path=out_file, - engine=kwargs.get("engine", "cudf"), - index=kwargs.get("index", None), - partition_cols=kwargs.get("partition_cols", None), - partition_file_name=kwargs.get( - "partition_file_name", None - ), - partition_offsets=kwargs.get("partition_offsets", None), - statistics=kwargs.get("statistics", "ROWGROUP"), - int96_timestamps=kwargs.get("int96_timestamps", False), - row_group_size_bytes=kwargs.get( - "row_group_size_bytes", None - ), - row_group_size_rows=kwargs.get( - "row_group_size_rows", None - ), - storage_options=kwargs.get("storage_options", None), - metadata_file_path=filename if return_metadata else None, - ) - # Return the schema needed to write the metadata - if return_metadata: - return [{"meta": md}] - else: - return [] + dataset_info["base_meta"] = meta_cudf + self.operands[type(self)._parameters.index("_dataset_info_cache")] = ( + dataset_info + ) + return dataset_info @staticmethod - def write_metadata(parts, fmd, fs, path, append=False, **kwargs): - if parts: - # Aggregate metadata and write to _metadata file - metadata_path = fs.sep.join([path, "_metadata"]) - _meta = [] - if append and fmd is not None: - # Convert to bytes: - if isinstance(fmd, pq.FileMetaData): - with BytesIO() as myio: - fmd.write_metadata_file(myio) - myio.seek(0) - fmd = np.frombuffer(myio.read(), dtype="uint8") - _meta = [fmd] - _meta.extend([parts[i][0]["meta"] for i in range(len(parts))]) - _meta = ( - cudf.io.merge_parquet_filemetadata(_meta) - if len(_meta) > 1 - else _meta[0] - ) - with fs.open(metadata_path, "wb") as fil: - fil.write(memoryview(_meta)) - - @classmethod - def collect_file_metadata(cls, path, fs, file_path): - with fs.open(path, "rb") as f: - meta = pq.ParquetFile(f).metadata - if file_path: - meta.set_file_path(file_path) - with BytesIO() as myio: - meta.write_metadata_file(myio) - myio.seek(0) - meta = np.frombuffer(myio.read(), dtype="uint8") - return meta + def _table_to_pandas(table, index_name): + df = cudf.DataFrame.from_arrow(table) + if index_name is not None: + df = df.set_index(index_name) + return df - @classmethod - def aggregate_metadata(cls, meta_list, fs, out_path): - meta = ( - cudf.io.merge_parquet_filemetadata(meta_list) - if len(meta_list) > 1 - else meta_list[0] + def _filtered_task(self, index: int): + columns = self.columns.copy() + index_name = self.index.name + if self.index is not None: + index_name = self.index.name + schema = self._dataset_info["schema"].remove_metadata() + if index_name: + if columns is None: + columns = list(schema.names) + columns.append(index_name) + return ( + self._table_to_pandas, + ( + self._fragment_to_table, + FragmentWrapper(self.fragments[index], filesystem=self.fs), + self.filters, + columns, + schema, + ), + index_name, ) - if out_path: - metadata_path = fs.sep.join([out_path, "_metadata"]) - with fs.open(metadata_path, "wb") as fil: - fil.write(memoryview(meta)) - return None - else: - return meta - - -def set_object_dtypes_from_pa_schema(df, schema): - # Simple utility to modify cudf DataFrame - # "object" dtypes to agree with a specific - # pyarrow schema. - if schema: - for col_name, col in df._data.items(): - if col_name is None: - # Pyarrow cannot handle `None` as a field name. - # However, this should be a simple range index that - # we can ignore anyway - continue - typ = cudf_dtype_from_pa_type(schema.field(col_name).type) - if ( - col_name in schema.names - and not isinstance(typ, (cudf.ListDtype, cudf.StructDtype)) - and isinstance(col, cudf.core.column.StringColumn) - ): - df._data[col_name] = col.astype(typ) - - -def read_parquet(path, columns=None, **kwargs): - """ - Read parquet files into a :class:`.DataFrame`. - - Calls :func:`dask.dataframe.read_parquet` with ``engine=CudfEngine`` - to coordinate the execution of :func:`cudf.read_parquet`, and to - ultimately create a :class:`.DataFrame` collection. - - See the :func:`dask.dataframe.read_parquet` documentation for - all available options. - - Examples - -------- - >>> from dask_cudf import read_parquet - >>> df = read_parquet("/path/to/dataset/") # doctest: +SKIP - - When dealing with one or more large parquet files having an - in-memory footprint >15% device memory, the ``split_row_groups`` - argument should be used to map Parquet **row-groups** to DataFrame - partitions (instead of **files** to partitions). For example, the - following code will map each row-group to a distinct partition: - - >>> df = read_parquet(..., split_row_groups=True) # doctest: +SKIP - - To map **multiple** row-groups to each partition, an integer can be - passed to ``split_row_groups`` to specify the **maximum** number of - row-groups allowed in each output partition: - - >>> df = read_parquet(..., split_row_groups=10) # doctest: +SKIP - - See Also - -------- - cudf.read_parquet - dask.dataframe.read_parquet - """ - if isinstance(columns, str): - columns = [columns] - - # Set "check_file_size" option to determine whether we - # should check the parquet-file size. This check is meant - # to "protect" users from `split_row_groups` default changes - check_file_size = kwargs.pop("check_file_size", 500_000_000) - if ( - check_file_size - and ("split_row_groups" not in kwargs) - and ("chunksize" not in kwargs) - ): - # User is not specifying `split_row_groups` or `chunksize`, - # so we should warn them if/when a file is ~>0.5GB on disk. - # They can set `split_row_groups` explicitly to silence/skip - # this check - if "read" not in kwargs: - kwargs["read"] = {} - kwargs["read"]["check_file_size"] = check_file_size - - return dd.read_parquet(path, columns=columns, engine=CudfEngine, **kwargs) - - -to_parquet = partial(dd.to_parquet, engine=CudfEngine) -if create_metadata_file_dd is None: - create_metadata_file = create_metadata_file_dd -else: - create_metadata_file = partial(create_metadata_file_dd, engine=CudfEngine) + def _tune_up(self, parent): + if self._fusion_compression_factor >= 1: + return + if isinstance(parent, CudfFusedParquetIO): + return + return parent.substitute(self, CudfFusedParquetIO(self)) + + +read_parquet = _deprecated_api( + "dask_cudf.io.parquet.read_parquet", + new_api="dask_cudf.read_parquet", +) +to_parquet = _deprecated_api( + "dask_cudf.io.parquet.to_parquet", + new_api="dask_cudf._legacy.io.parquet.to_parquet", + rec="Please use the DataFrame.to_parquet method instead.", +) +create_metadata_file = _deprecated_api( + "dask_cudf.io.parquet.create_metadata_file", + new_api="dask_cudf._legacy.io.parquet.create_metadata_file", + rec="Please raise an issue if this feature is needed.", +) diff --git a/python/dask_cudf/dask_cudf/io/tests/test_csv.py b/python/dask_cudf/dask_cudf/io/tests/test_csv.py index a35a9f1be48..a0acb86f5a9 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_csv.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_csv.py @@ -264,3 +264,18 @@ def test_read_csv_nrows_error(csv_end_bad_lines): dask_cudf.read_csv( csv_end_bad_lines, nrows=2, blocksize="100 MiB" ).compute() + + +def test_deprecated_api_paths(tmp_path): + csv_path = str(tmp_path / "data-*.csv") + df = dask_cudf.DataFrame.from_dict({"a": range(100)}, npartitions=1) + df.to_csv(csv_path, index=False) + + # Encourage top-level read_csv import only + with pytest.warns(match="dask_cudf.io.read_csv is now deprecated"): + df2 = dask_cudf.io.read_csv(csv_path) + dd.assert_eq(df, df2, check_divisions=False) + + with pytest.warns(match="dask_cudf.io.csv.read_csv is now deprecated"): + df2 = dask_cudf.io.csv.read_csv(csv_path) + dd.assert_eq(df, df2, check_divisions=False) diff --git a/python/dask_cudf/dask_cudf/io/tests/test_json.py b/python/dask_cudf/dask_cudf/io/tests/test_json.py index abafbffd197..f5509cf91c3 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_json.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_json.py @@ -126,3 +126,18 @@ def test_read_json_aggregate_files(tmp_path): assert name in df2.columns assert len(df2[name].compute().unique()) == df1.npartitions dd.assert_eq(df1, df2.drop(columns=[name]), check_index=False) + + +def test_deprecated_api_paths(tmp_path): + path = str(tmp_path / "data-*.json") + df = dd.from_dict({"a": range(100)}, npartitions=1) + df.to_json(path) + + # Encourage top-level read_json import only + with pytest.warns(match="dask_cudf.io.read_json is now deprecated"): + df2 = dask_cudf.io.read_json(path) + dd.assert_eq(df, df2, check_divisions=False) + + with pytest.warns(match="dask_cudf.io.json.read_json is now deprecated"): + df2 = dask_cudf.io.json.read_json(path) + dd.assert_eq(df, df2, check_divisions=False) diff --git a/python/dask_cudf/dask_cudf/io/tests/test_orc.py b/python/dask_cudf/dask_cudf/io/tests/test_orc.py index 457e5546bd9..b6064d851ca 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_orc.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_orc.py @@ -145,3 +145,21 @@ def test_to_orc(tmpdir, dtypes, compression, compute): # the cudf dataframes (df and df_read) dd.assert_eq(df, ddf_read) dd.assert_eq(df_read, ddf_read) + + +def test_deprecated_api_paths(tmpdir): + df = dask_cudf.DataFrame.from_dict({"a": range(100)}, npartitions=1) + path = tmpdir.join("test.orc") + # Top-level to_orc function is deprecated + with pytest.warns(match="dask_cudf.to_orc is now deprecated"): + dask_cudf.to_orc(df, path, write_index=False) + + # Encourage top-level read_orc import only + paths = glob.glob(str(path) + "/*.orc") + with pytest.warns(match="dask_cudf.io.read_orc is now deprecated"): + df2 = dask_cudf.io.read_orc(paths) + dd.assert_eq(df, df2, check_divisions=False) + + with pytest.warns(match="dask_cudf.io.orc.read_orc is now deprecated"): + df2 = dask_cudf.io.orc.read_orc(paths) + dd.assert_eq(df, df2, check_divisions=False) diff --git a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py index a29cf9a342a..522a21e12a5 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py @@ -15,6 +15,7 @@ import cudf import dask_cudf +from dask_cudf._legacy.io.parquet import create_metadata_file from dask_cudf.tests.utils import ( require_dask_expr, skip_dask_expr, @@ -24,7 +25,7 @@ # Check if create_metadata_file is supported by # the current dask.dataframe version need_create_meta = pytest.mark.skipif( - dask_cudf.io.parquet.create_metadata_file is None, + create_metadata_file is None, reason="Need create_metadata_file support in dask.dataframe.", ) @@ -425,10 +426,14 @@ def test_create_metadata_file(tmpdir, partition_on): fns = glob.glob(os.path.join(tmpdir, partition_on + "=*/*.parquet")) else: fns = glob.glob(os.path.join(tmpdir, "*.parquet")) - dask_cudf.io.parquet.create_metadata_file( - fns, - split_every=3, # Force tree reduction - ) + + with pytest.warns( + match="dask_cudf.io.parquet.create_metadata_file is now deprecated" + ): + dask_cudf.io.parquet.create_metadata_file( + fns, + split_every=3, # Force tree reduction + ) # Check that we can now read the ddf # with the _metadata file present @@ -472,7 +477,7 @@ def test_create_metadata_file_inconsistent_schema(tmpdir): # Add global metadata file. # Dask-CuDF can do this without requiring schema # consistency. - dask_cudf.io.parquet.create_metadata_file([p0, p1]) + create_metadata_file([p0, p1]) # Check that we can still read the ddf # with the _metadata file present @@ -533,9 +538,9 @@ def test_check_file_size(tmpdir): fn = str(tmpdir.join("test.parquet")) cudf.DataFrame({"a": np.arange(1000)}).to_parquet(fn) with pytest.warns(match="large parquet file"): - # Need to use `dask_cudf.io` path + # Need to use `dask_cudf._legacy.io` path # TODO: Remove outdated `check_file_size` functionality - dask_cudf.io.read_parquet(fn, check_file_size=1).compute() + dask_cudf._legacy.io.read_parquet(fn, check_file_size=1).compute() @xfail_dask_expr("HivePartitioning cannot be hashed", lt_version="2024.3.0") @@ -664,3 +669,21 @@ def test_to_parquet_append(tmpdir, write_metadata_file): ) ddf2 = dask_cudf.read_parquet(tmpdir) dd.assert_eq(cudf.concat([df, df]), ddf2) + + +def test_deprecated_api_paths(tmpdir): + df = dask_cudf.DataFrame.from_dict({"a": range(100)}, npartitions=1) + # io.to_parquet function is deprecated + with pytest.warns(match="dask_cudf.io.to_parquet is now deprecated"): + dask_cudf.io.to_parquet(df, tmpdir) + + # Encourage top-level read_parquet import only + with pytest.warns(match="dask_cudf.io.read_parquet is now deprecated"): + df2 = dask_cudf.io.read_parquet(tmpdir) + dd.assert_eq(df, df2, check_divisions=False) + + with pytest.warns( + match="dask_cudf.io.parquet.read_parquet is now deprecated" + ): + df2 = dask_cudf.io.parquet.read_parquet(tmpdir) + dd.assert_eq(df, df2, check_divisions=False) diff --git a/python/dask_cudf/dask_cudf/io/tests/test_text.py b/python/dask_cudf/dask_cudf/io/tests/test_text.py index 8912b7d5da6..e35b6411a9d 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_text.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_text.py @@ -34,3 +34,15 @@ def test_read_text_byte_range(offset, size): text_file, chunksize=None, delimiter=".", byte_range=(offset, size) ) dd.assert_eq(df1, df2, check_index=False) + + +def test_deprecated_api_paths(): + # Encourage top-level read_text import only + df = cudf.read_text(text_file, delimiter=".") + with pytest.warns(match="dask_cudf.io.read_text is now deprecated"): + df2 = dask_cudf.io.read_text(text_file, delimiter=".") + dd.assert_eq(df, df2, check_divisions=False) + + with pytest.warns(match="dask_cudf.io.text.read_text is now deprecated"): + df2 = dask_cudf.io.text.read_text(text_file, delimiter=".") + dd.assert_eq(df, df2, check_divisions=False) diff --git a/python/dask_cudf/dask_cudf/io/text.py b/python/dask_cudf/dask_cudf/io/text.py index 9cdb7c5220b..1caf4e81d8e 100644 --- a/python/dask_cudf/dask_cudf/io/text.py +++ b/python/dask_cudf/dask_cudf/io/text.py @@ -1,54 +1,8 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2024, NVIDIA CORPORATION. -import os -from glob import glob +from dask_cudf import _deprecated_api -import dask.dataframe as dd -from dask.base import tokenize -from dask.utils import apply, parse_bytes - -import cudf - - -def read_text(path, chunksize="256 MiB", **kwargs): - if isinstance(chunksize, str): - chunksize = parse_bytes(chunksize) - - if isinstance(path, list): - filenames = path - elif isinstance(path, str): - filenames = sorted(glob(path)) - elif hasattr(path, "__fspath__"): - filenames = sorted(glob(path.__fspath__())) - else: - raise TypeError(f"Path type not understood:{type(path)}") - - if not filenames: - msg = f"A file in: {filenames} does not exist." - raise FileNotFoundError(msg) - - name = "read-text-" + tokenize(path, tokenize, **kwargs) - - if chunksize: - dsk = {} - i = 0 - for fn in filenames: - size = os.path.getsize(fn) - for start in range(0, size, chunksize): - kwargs1 = kwargs.copy() - kwargs1["byte_range"] = ( - start, - chunksize, - ) # specify which chunk of the file we care about - - dsk[(name, i)] = (apply, cudf.read_text, [fn], kwargs1) - i += 1 - else: - dsk = { - (name, i): (apply, cudf.read_text, [fn], kwargs) - for i, fn in enumerate(filenames) - } - - meta = cudf.Series([], dtype="O") - divisions = [None] * (len(dsk) + 1) - return dd.core.new_dd_object(dsk, name, meta, divisions) +read_text = _deprecated_api( + "dask_cudf.io.text.read_text", + new_api="dask_cudf.read_text", +) diff --git a/python/dask_cudf/dask_cudf/tests/test_core.py b/python/dask_cudf/dask_cudf/tests/test_core.py index 8e42c847ddf..5130b804179 100644 --- a/python/dask_cudf/dask_cudf/tests/test_core.py +++ b/python/dask_cudf/dask_cudf/tests/test_core.py @@ -39,30 +39,6 @@ def test_from_dict_backend_dispatch(): dd.assert_eq(expect, ddf) -def test_to_dask_dataframe_deprecated(): - gdf = cudf.DataFrame({"a": range(100)}) - ddf = dd.from_pandas(gdf, npartitions=2) - assert isinstance(ddf._meta, cudf.DataFrame) - - with pytest.warns(FutureWarning, match="API is now deprecated"): - assert isinstance( - ddf.to_dask_dataframe()._meta, - pd.DataFrame, - ) - - -def test_from_dask_dataframe_deprecated(): - gdf = pd.DataFrame({"a": range(100)}) - ddf = dd.from_pandas(gdf, npartitions=2) - assert isinstance(ddf._meta, pd.DataFrame) - - with pytest.warns(FutureWarning, match="API is now deprecated"): - assert isinstance( - dask_cudf.from_dask_dataframe(ddf)._meta, - cudf.DataFrame, - ) - - def test_to_backend(): rng = np.random.default_rng(seed=0) data = { diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 042e69d86f4..918290aa6fa 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -13,7 +13,7 @@ from cudf.testing._utils import expect_warning_if import dask_cudf -from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized +from dask_cudf._legacy.groupby import OPTIMIZED_AGGS, _aggs_optimized from dask_cudf.tests.utils import ( QUERY_PLANNING_ON, require_dask_expr, diff --git a/python/dask_cudf/dask_cudf/tests/utils.py b/python/dask_cudf/dask_cudf/tests/utils.py index 9aaf6dc8420..a9f61f75762 100644 --- a/python/dask_cudf/dask_cudf/tests/utils.py +++ b/python/dask_cudf/dask_cudf/tests/utils.py @@ -10,7 +10,7 @@ import cudf -from dask_cudf.expr import QUERY_PLANNING_ON +from dask_cudf import QUERY_PLANNING_ON if QUERY_PLANNING_ON: DASK_VERSION = Version(dask.__version__)