Commit
Merge branch 'main' into use-to_backend
jperez999 authored Jul 26, 2024
2 parents b85336d + 2643923 commit 2053315
Showing 10 changed files with 88 additions and 659 deletions.
67 changes: 67 additions & 0 deletions merlin/config/__init__.py
@@ -0,0 +1,67 @@
#
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

_DASK_QUERY_PLANNING_ENABLED = False
try:
    # Disable query-planning and string conversion
    import dask

    dask.config.set(
        {
            "dataframe.query-planning": False,
            "dataframe.convert-string": False,
        }
    )
except ImportError:
    dask = None
else:
    import sys

    import dask.dataframe as dd
    from packaging.version import parse

    if parse(dask.__version__) > parse("2024.6.0"):
        # For newer versions of dask, we can just check
        # the official DASK_EXPR_ENABLED constant
        _DASK_QUERY_PLANNING_ENABLED = dd.DASK_EXPR_ENABLED
    else:
        # For older versions of dask, we must assume query
        # planning is enabled if dask_expr was imported
        # (because we can't know for sure)
        _DASK_QUERY_PLANNING_ENABLED = "dask_expr" in sys.modules


def validate_dask_configs():
    """Central check for problematic config options in Dask"""
    if _DASK_QUERY_PLANNING_ENABLED:
        raise NotImplementedError(
            "Merlin does not support the query-planning API in "
            "Dask Dataframe yet. Please make sure query-planning is "
            "disabled before dask.dataframe is imported.\n\ne.g."
            "dask.config.set({'dataframe.query-planning': False})"
            "\n\nOr set the environment variable: "
            "export DASK_DATAFRAME__QUERY_PLANNING=False"
        )

    if dask is not None and dask.config.get("dataframe.convert-string"):
        raise NotImplementedError(
            "Merlin does not support automatic string conversion in "
            "Dask Dataframe yet. Please make sure this option is "
            "disabled.\n\ne.g."
            "dask.config.set({'dataframe.convert-string': False})"
            "\n\nOr set the environment variable: "
            "export DASK_DATAFRAME__CONVERT_STRING=False"
        )
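
The upshot of this new module is that Dask options must be pinned before dask.dataframe is first imported anywhere in the process, since the query-planning backend is locked in at import time. A minimal sketch of what downstream user code might look like (illustrative only, not part of the commit; the config keys are the same ones the module above sets):

# Sketch: pin the Dask config before anything imports dask.dataframe.
# The environment-variable route (export DASK_DATAFRAME__QUERY_PLANNING=False)
# works equally well and avoids import-order pitfalls entirely.
import dask

dask.config.set(
    {
        "dataframe.query-planning": False,  # stay on the legacy DataFrame API
        "dataframe.convert-string": False,  # keep object-dtype strings
    }
)

from merlin.config import validate_dask_configs

validate_dask_configs()  # passes silently when both options are disabled
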
4 changes: 3 additions & 1 deletion merlin/core/__init__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2022, NVIDIA CORPORATION.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+from merlin.config import validate_dask_configs
 from merlin.core import _version
 
 __version__ = _version.get_versions()["version"]
+validate_dask_configs()
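
Because the check now runs when merlin.core is imported, a misconfigured process fails fast. A sketch of the failure mode, assuming a dask release newer than 2024.6.0 with the dask-expr backend available:

# Sketch: importing dask.dataframe with query planning enabled locks in
# the dask-expr backend, so the subsequent merlin import raises.
import dask

dask.config.set({"dataframe.query-planning": True})
import dask.dataframe  # noqa: F401

try:
    import merlin.core  # validate_dask_configs() runs at import time
except NotImplementedError as exc:
    print(exc)  # the message spells out the dask.config.set(...) remedy
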
6 changes: 5 additions & 1 deletion merlin/dag/__init__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2022, NVIDIA CORPORATION.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +15,10 @@
 #
 # flake8: noqa
 
+from merlin.config import validate_dask_configs
+
+validate_dask_configs()
+
 from merlin.dag.graph import Graph
 from merlin.dag.node import Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes
 from merlin.dag.operator import DataFormats, Operator, Supports
8 changes: 6 additions & 2 deletions merlin/io/__init__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2022, NVIDIA CORPORATION.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 # flake8: noqa
+
+from merlin.config import validate_dask_configs
+
+validate_dask_configs()
+
 from merlin.io import dataframe_iter, dataset, shuffle
 from merlin.io.dataframe_iter import DataFrameIter
 from merlin.io.dataset import MERLIN_METADATA_DIR_NAME, Dataset
99 changes: 3 additions & 96 deletions merlin/io/dataset.py
@@ -1130,103 +1130,10 @@ def npartitions(self):
         return self.to_ddf().npartitions
 
     def validate_dataset(self, **kwargs):
-        """Validate for efficient processing.
+        raise NotImplementedError(""" validate_dataset is not supported for merlin >23.08 """)
 
-        The purpose of this method is to validate that the Dataset object
-        meets the minimal requirements for efficient NVTabular processing.
-        For now, this criteria requires the data to be in parquet format.
-        Example Usage::
-            dataset = Dataset("/path/to/data_pq", engine="parquet")
-            assert validate_dataset(dataset)
-        Parameters
-        -----------
-        **kwargs :
-            Key-word arguments to pass down to the engine's validate_dataset
-            method. For the recommended parquet format, these arguments
-            include `add_metadata_file`, `row_group_max_size`, `file_min_size`,
-            and `require_metadata_file`. For more information, see
-            `ParquetDatasetEngine.validate_dataset`.
-        Returns
-        -------
-        valid : bool
-            `True` if the input dataset is valid for efficient NVTabular
-            processing.
-        """
-
-        # Check that the dataset format is Parquet
-        if not isinstance(self.engine, ParquetDatasetEngine):
-            msg = (
-                "NVTabular is optimized for the parquet format. Please use "
-                "the to_parquet method to convert your dataset."
-            )
-            warnings.warn(msg)
-            return False  # Early return
-
-        return self.engine.validate_dataset(**kwargs)
-
-    def regenerate_dataset(
-        self,
-        output_path,
-        columns=None,
-        output_format="parquet",
-        compute=True,
-        **kwargs,
-    ):
-        """EXPERIMENTAL:
-        Regenerate an NVTabular Dataset for efficient processing by writing
-        out new Parquet files. In contrast to default ``to_parquet`` behavior,
-        this method preserves the original ordering.
-        Example Usage::
-            dataset = Dataset("/path/to/data_pq", engine="parquet")
-            dataset.regenerate_dataset(
-                out_path, part_size="1MiB", file_size="10MiB"
-            )
-        Parameters
-        -----------
-        output_path : string
-            Root directory path to use for the new (regenerated) dataset.
-        columns : list(string), optional
-            Subset of columns to include in the regenerated dataset.
-        output_format : string, optional
-            Format to use for regenerated dataset. Only "parquet" (default)
-            is currently supported.
-        compute : bool, optional
-            Whether to compute the task graph or to return a Delayed object.
-            By default, the graph will be executed.
-        **kwargs :
-            Key-word arguments to pass down to the engine's regenerate_dataset
-            method. See `ParquetDatasetEngine.regenerate_dataset` for more
-            information.
-        Returns
-        -------
-        result : int or Delayed
-            If `compute=True` (default), the return value will be an integer
-            corresponding to the number of generated data files. If `False`,
-            the returned value will be a `Delayed` object.
-        """
-
-        # Check that the desired output format is Parquet
-        if output_format not in ["parquet"]:
-            msg = (
-                f"NVTabular is optimized for the parquet format. "
-                f"{output_format} is not yet a supported output format for "
-                f"regenerate_dataset."
-            )
-            raise ValueError(msg)
-
-        result = ParquetDatasetEngine.regenerate_dataset(self, output_path, columns=None, **kwargs)
-        if compute:
-            return result.compute()
-        else:
-            return result
+    def regenerate_dataset(self, *args, **kwargs):
+        raise NotImplementedError(""" regenerate_dataset is not supported for merlin >23.08 """)
 
     def infer_schema(self, n=1):
         """Create a schema containing the column names and inferred dtypes of the Dataset
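
Both methods are kept as stubs, so existing call sites fail with a clear message rather than an AttributeError. A hedged migration sketch for callers (the paths are hypothetical; to_parquet is the conversion route the removed docstring itself recommended):

from merlin.io import Dataset

dataset = Dataset("/path/to/data_pq", engine="parquet")  # hypothetical path

try:
    dataset.validate_dataset()
except NotImplementedError:
    # Gone in merlin >23.08: feed parquet data directly, or rewrite
    # other formats with to_parquet, as the old warning suggested.
    dataset.to_parquet("/path/to/output_pq")
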
7 changes: 0 additions & 7 deletions merlin/io/dataset_engine.py
@@ -51,13 +51,6 @@ def _path_partition_map(self):
     def num_rows(self):
         raise NotImplementedError(""" Returns the number of rows in the dataset """)
 
-    def validate_dataset(self, **kwargs):
-        raise NotImplementedError(""" Returns True if the raw data is efficient for NVTabular """)
-
-    @classmethod
-    def regenerate_dataset(cls, dataset, output_path, columns=None, **kwargs):
-        raise NotImplementedError(""" Regenerate a dataset with optimal properties """)
-
     def sample_data(self, n=1):
         """Return a sample of real data from the dataset
[Diffs for the remaining 4 changed files are not shown here.]