Skip to content

Commit

Permalink
Fix CustomDataset.load_from_disk with str/Path objects (#341)
Browse files: browse the repository at this point in the history
* Update docs and properly pass extra kwargs (unpack with `**kwargs` instead of `*kwargs`)

* Force conversion of the path argument to a `Path` object

* Add extra tests covering the different path types (str and `Path`) and extra keyword arguments a user may pass
  • Loading branch information
plaguss authored Feb 9, 2024
1 parent a701e83 commit 37d08fc
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 11 deletions.
5 changes: 3 additions & 2 deletions src/distilabel/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,13 @@ def load_from_disk(cls, dataset_path: os.PathLike, **kwargs: Any):
"""Load a CustomDataset from disk, also reading the task.
Args:
dataset_path: Path to the dataset, as you would do with a standard Dataset.
dataset_path (os.PathLike): Path to the dataset.
kwargs (Any): Keyword arguments passed to Dataset.load_from_disk.
Returns:
The loaded dataset.
"""
ds = super().load_from_disk(dataset_path, *kwargs)
ds = super().load_from_disk(dataset_path, **kwargs)
# Dynamically remaps the `datasets.Dataset` to be a `CustomDataset` instance
ds.__class__ = cls
task = load_task_from_disk(dataset_path)
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def load_from_dict(template: Dict[str, Any]) -> Generic[T]:
return instance


def load_task_from_disk(path: Path) -> "Task":
def load_task_from_disk(path: os.PathLike) -> "Task":
"""Loads a task from disk.
Args:
Expand All @@ -54,7 +54,7 @@ def load_task_from_disk(path: Path) -> "Task":
Returns:
Task: The task.
"""
task_path = path / TASK_FILE_NAME
task_path = Path(path) / TASK_FILE_NAME
if not task_path.exists():
raise FileNotFoundError(f"The task file does not exist: {task_path}")

Expand Down
49 changes: 42 additions & 7 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def metric_strategy():


@pytest.fixture
def custom_dataset():
def custom_dataset() -> CustomDataset:
ds = CustomDataset.from_dict(
{
"input": ["a", "b"],
Expand Down Expand Up @@ -195,20 +195,55 @@ def sample_preference_dataset():


@pytest.mark.usefixtures("custom_dataset")
def test_dataset_save_to_disk(custom_dataset):
@pytest.mark.parametrize(
"extra_kwargs",
[
{},
# Just any kwarg to test it's properly passed down to the next function
{"max_shard_size": None},
],
)
@pytest.mark.parametrize(
"path_transform",
[
lambda x: x,
lambda x: Path(x),
],
)
def test_dataset_save_to_disk(
custom_dataset: CustomDataset, path_transform: os.PathLike, extra_kwargs: Any
):
with tempfile.TemporaryDirectory() as tmpdir:
ds_name = Path(tmpdir) / "dataset_folder"
custom_dataset.save_to_disk(ds_name)
ds_name = path_transform(tmpdir)
custom_dataset.save_to_disk(ds_name, **extra_kwargs)
ds_name = Path(ds_name)
assert ds_name.is_dir()
assert (ds_name / TASK_FILE_NAME).is_file()


@pytest.mark.usefixtures("custom_dataset")
def test_dataset_load_disk(custom_dataset):
@pytest.mark.parametrize(
"extra_kwargs",
[
{},
# Just any kwarg to test it's properly passed down to the next function
{"keep_in_memory": True},
],
)
@pytest.mark.parametrize(
"path_transform",
[
lambda x: x,
lambda x: Path(x),
],
)
def test_dataset_load_disk(
custom_dataset: CustomDataset, path_transform: os.PathLike, extra_kwargs: Any
):
with tempfile.TemporaryDirectory() as tmpdir:
ds_name = Path(tmpdir) / "dataset_folder"
ds_name = path_transform(tmpdir)
custom_dataset.save_to_disk(ds_name)
ds_from_disk = CustomDataset.load_from_disk(ds_name)
ds_from_disk = CustomDataset.load_from_disk(ds_name, **extra_kwargs)
assert isinstance(ds_from_disk, CustomDataset)
assert isinstance(ds_from_disk.task, UltraFeedbackTask)

Expand Down

0 comments on commit 37d08fc

Please sign in to comment.