diff --git a/src/distilabel/dataset.py b/src/distilabel/dataset.py
index a886455c67..ea21c9aa1f 100644
--- a/src/distilabel/dataset.py
+++ b/src/distilabel/dataset.py
@@ -197,12 +197,13 @@ def load_from_disk(cls, dataset_path: os.PathLike, **kwargs: Any):
         """Load a CustomDataset from disk, also reading the task.
 
         Args:
-            dataset_path: Path to the dataset, as you would do with a standard Dataset.
+            dataset_path (os.PathLike): Path to the dataset.
+            kwargs (Any): Keyword arguments passed to Dataset.load_from_disk.
 
         Returns:
             The loaded dataset.
         """
-        ds = super().load_from_disk(dataset_path, *kwargs)
+        ds = super().load_from_disk(dataset_path, **kwargs)
         # Dynamically remaps the `datasets.Dataset` to be a `CustomDataset` instance
         ds.__class__ = cls
         task = load_task_from_disk(dataset_path)
diff --git a/src/distilabel/utils/serialization.py b/src/distilabel/utils/serialization.py
index ec00fa5ed1..a690652975 100644
--- a/src/distilabel/utils/serialization.py
+++ b/src/distilabel/utils/serialization.py
@@ -45,7 +45,7 @@ def load_from_dict(template: Dict[str, Any]) -> Generic[T]:
     return instance
 
 
-def load_task_from_disk(path: Path) -> "Task":
+def load_task_from_disk(path: os.PathLike) -> "Task":
     """Loads a task from disk.
 
     Args:
@@ -54,7 +54,7 @@ def load_task_from_disk(path: Path) -> "Task":
     Returns:
         Task: The task.
     """
-    task_path = path / TASK_FILE_NAME
+    task_path = Path(path) / TASK_FILE_NAME
     if not task_path.exists():
         raise FileNotFoundError(f"The task file does not exist: {task_path}")
 
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index a084651460..ed06eb9afb 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -73,7 +73,7 @@ def metric_strategy():
 
 
 @pytest.fixture
-def custom_dataset():
+def custom_dataset() -> CustomDataset:
     ds = CustomDataset.from_dict(
         {
             "input": ["a", "b"],
@@ -195,20 +195,55 @@ def sample_preference_dataset():
 
 
 @pytest.mark.usefixtures("custom_dataset")
-def test_dataset_save_to_disk(custom_dataset):
+@pytest.mark.parametrize(
+    "extra_kwargs",
+    [
+        {},
+        # Just any kwarg to test it's properly passed down to the next function
+        {"max_shard_size": None},
+    ],
+)
+@pytest.mark.parametrize(
+    "path_transform",
+    [
+        lambda x: x,
+        lambda x: Path(x),
+    ],
+)
+def test_dataset_save_to_disk(
+    custom_dataset: CustomDataset, path_transform: os.PathLike, extra_kwargs: Any
+):
     with tempfile.TemporaryDirectory() as tmpdir:
-        ds_name = Path(tmpdir) / "dataset_folder"
-        custom_dataset.save_to_disk(ds_name)
+        ds_name = path_transform(tmpdir)
+        custom_dataset.save_to_disk(ds_name, **extra_kwargs)
+        ds_name = Path(ds_name)
         assert ds_name.is_dir()
         assert (ds_name / TASK_FILE_NAME).is_file()
 
 
 @pytest.mark.usefixtures("custom_dataset")
-def test_dataset_load_disk(custom_dataset):
+@pytest.mark.parametrize(
+    "extra_kwargs",
+    [
+        {},
+        # Just any kwarg to test it's properly passed down to the next function
+        {"keep_in_memory": True},
+    ],
+)
+@pytest.mark.parametrize(
+    "path_transform",
+    [
+        lambda x: x,
+        lambda x: Path(x),
+    ],
+)
+def test_dataset_load_disk(
+    custom_dataset: CustomDataset, path_transform: os.PathLike, extra_kwargs: Any
+):
     with tempfile.TemporaryDirectory() as tmpdir:
-        ds_name = Path(tmpdir) / "dataset_folder"
+        ds_name = path_transform(tmpdir)
         custom_dataset.save_to_disk(ds_name)
-        ds_from_disk = CustomDataset.load_from_disk(ds_name)
+        ds_from_disk = CustomDataset.load_from_disk(ds_name, **extra_kwargs)
         assert isinstance(ds_from_disk, CustomDataset)
         assert isinstance(ds_from_disk.task, UltraFeedbackTask)
 