diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index efba7b3938..77c8818442 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -29,7 +29,12 @@ FormatTextGenerationSFT, ) from distilabel.steps.generators.data import LoadDataFromDicts -from distilabel.steps.generators.huggingface import LoadHubDataset +from distilabel.steps.generators.huggingface import ( + LoadDataFromDisk, + LoadDataFromFileSystem, + LoadDataFromHub, + LoadHubDataset, +) from distilabel.steps.globals.huggingface import PushToHub from distilabel.steps.keep import KeepColumns from distilabel.steps.typing import GeneratorStepOutput, StepOutput @@ -49,6 +54,9 @@ "GlobalStep", "KeepColumns", "LoadDataFromDicts", + "LoadDataFromDisk", + "LoadDataFromFileSystem", + "LoadDataFromHub", "LoadHubDataset", "PushToHub", "Step", diff --git a/src/distilabel/steps/generators/huggingface.py b/src/distilabel/steps/generators/huggingface.py index da4a6d7f52..96bbbb4882 100644 --- a/src/distilabel/steps/generators/huggingface.py +++ b/src/distilabel/steps/generators/huggingface.py @@ -12,15 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union - -import requests -from datasets import DatasetInfo, IterableDataset, load_dataset +import warnings +from collections import defaultdict +from functools import cached_property +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +from datasets import ( + Dataset, + DatasetInfo, + IterableDataset, + get_dataset_infos, + load_dataset, + load_from_disk, +) from pydantic import Field, PrivateAttr -from requests.exceptions import ConnectionError +from upath import UPath +from distilabel.distiset import Distiset from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.base import GeneratorStep @@ -28,7 +47,7 @@ from distilabel.steps.typing import GeneratorStepOutput -class LoadHubDataset(GeneratorStep): +class LoadDataFromHub(GeneratorStep): """Loads a dataset from the Hugging Face Hub. `GeneratorStep` that loads a dataset from the Hugging Face Hub using the `datasets` @@ -50,6 +69,8 @@ class LoadHubDataset(GeneratorStep): `False`. - `num_examples`: The number of examples to load from the dataset. By default will load all examples. + - `storage_options`: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. Output columns: - dynamic (`all`): The columns that will be generated by this step, based on the @@ -80,8 +101,12 @@ class LoadHubDataset(GeneratorStep): default=None, description="The number of examples to load from the dataset. By default will load all examples.", ) + storage_options: Optional[Dict[str, Any]] = Field( + default=None, + description="The storage options to use when loading the dataset.", + ) - _dataset: Union[IterableDataset, None] = PrivateAttr(...) + _dataset: Union[IterableDataset, Dataset, None] = PrivateAttr(...) def load(self) -> None: """Load the dataset from the Hugging Face Hub""" @@ -155,11 +180,11 @@ def _get_dataset_num_examples(self) -> int: Returns: The number of examples in the dataset. """ - dataset_info = self._get_dataset_info() - split = self.split - if self.config: - return dataset_info["splits"][split]["num_examples"] - return dataset_info["default"]["splits"][split]["num_examples"] + return ( + self._dataset_info[self.config if self.config else "default"] + .splits[self.split] + .num_examples + ) def _get_dataset_columns(self) -> List[str]: """Get the columns of the dataset, based on the `config` runtime parameter provided. @@ -167,18 +192,14 @@ def _get_dataset_columns(self) -> List[str]: Returns: The columns of the dataset. """ - dataset_info = self._get_dataset_info() - - if isinstance(dataset_info, DatasetInfo): - if self.config: - return list(self._dataset[self.config].info.features.keys()) - return list(self._dataset.info.features.keys()) - - if self.config: - return list(dataset_info["features"].keys()) - return list(dataset_info["default"]["features"].keys()) + return list( + self._dataset_info[ + self.config if self.config else "default" + ].features.keys() + ) - def _get_dataset_info(self) -> Dict[str, Any]: + @cached_property + def _dataset_info(self) -> Dict[str, DatasetInfo]: """Calls the Datasets Server API from Hugging Face to obtain the dataset information. Returns: @@ -188,47 +209,256 @@ def _get_dataset_info(self) -> Dict[str, Any]: config = self.config try: - return _get_hf_dataset_info(repo_id, config) - except ConnectionError: + return get_dataset_infos(repo_id) + except Exception as e: # The previous could fail in case of a internet connection issues. # Assuming the dataset is already loaded and we can get the info from the loaded dataset, otherwise it will fail anyway. - self.load() + self._logger.warning( + f"Failed to get dataset info from Hugging Face Hub, trying to get it loading the dataset. Error: {e}" + ) + ds = load_dataset(repo_id, config=self.config, split=self.split) if config: - return self._dataset[config].info - return self._dataset.info + return ds[config].info + return ds.info + + +class LoadHubDataset(LoadDataFromHub): + def __init__(self, **data: Any) -> None: + warnings.warn( + "`LoadHubDataset` is deprecated and will be removed in version 1.3.0, use `LoadFromHub` instead.", + DeprecationWarning, + stacklevel=2, + ) + return super().__init__(**data) + + +class LoadDataFromFileSystem(LoadDataFromHub): + """Loads a dataset from a file in your filesystem. + `GeneratorStep` that creates a dataset from a file in the filesystem, uses Hugging Face `datasets` + library. Take a look at [Hugging Face Datasets](https://huggingface.co/docs/datasets/loading) + for more information of the supported file types. -@lru_cache -def _get_hf_dataset_info( - repo_id: str, config: Union[str, None] = None -) -> Dict[str, Any]: - """Calls the Datasets Server API from Hugging Face to obtain the dataset information. - The results are cached to avoid making multiple requests to the server. + Attributes: + data_files: The path to the file, or directory containing the files that conform + the dataset. + split: The split of the dataset to load (typically will be `train`, `test` or `validation`). - Args: - repo_id: The Hugging Face Hub repository ID of the dataset. - config: The configuration of the dataset. This is optional and only needed if the - dataset has multiple configurations. + Runtime parameters: + - `batch_size`: The batch size to use when processing the data. + - `data_files`: The path to the file, or directory containing the files that conform + the dataset. + - `split`: The split of the dataset to load. Defaults to 'train'. + - `streaming`: Whether to load the dataset in streaming mode or not. Defaults to + `False`. + - `num_examples`: The number of examples to load from the dataset. + By default will load all examples. + - `storage_options`: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. + - `filetype`: The expected filetype. If not provided, it will be inferred from the file extension. + For more than one file, it will be inferred from the first file. + + Output columns: + - dynamic (`all`): The columns that will be generated by this step, based on the + datasets loaded from the Hugging Face Hub. - Returns: - The dataset information. + Categories: + - load """ - params = {"dataset": repo_id} - if config is not None: - params["config"] = config + data_files: RuntimeParameter[Union[str, Path]] = Field( + default=None, + description="The data files, or directory containing the data files, to generate the dataset from.", + ) + filetype: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The expected filetype. If not provided, it will be inferred from the file extension.", + ) + + def load(self) -> None: + """Load the dataset from the file/s in disk.""" + super(GeneratorStep, self).load() + + data_path = UPath(self.data_files, storage_options=self.storage_options) + + (data_files, self.filetype) = self._prepare_data_files(data_path) + + self._dataset = load_dataset( + self.filetype, + data_files=data_files, + split=self.split, + streaming=self.streaming, + storage_options=self.storage_options, + ) + + if not self.streaming and self.num_examples: + self._dataset = self._dataset.select(range(self.num_examples)) + if not self.num_examples: + if self.streaming: + # There's no better way to get the number of examples in a streaming dataset, + # load it again for the moment. + self.num_examples = len( + load_dataset( + self.filetype, data_files=self.data_files, split=self.split + ) + ) + else: + self.num_examples = len(self._dataset) + + @staticmethod + def _prepare_data_files( + data_path: UPath, + ) -> Tuple[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]], str]: + """Prepare the loading process by setting the `data_files` attribute. - if "HF_TOKEN" in os.environ: - headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"} - else: - headers = None + Args: + data_path: The path to the data files, or directory containing the data files. + + Returns: + Tuple with the data files and the filetype. + """ - response = requests.get( - "https://datasets-server.huggingface.co/info", params=params, headers=headers + def get_filetype(data_path: UPath) -> str: + filetype = data_path.suffix.lstrip(".") + if filetype == "jsonl": + filetype = "json" + return filetype + + if data_path.is_file(): + filetype = get_filetype(data_path) + data_files = str(data_path) + elif data_path.is_dir(): + file_sequence = [] + file_map = defaultdict(list) + for file_or_folder in data_path.iterdir(): + if file_or_folder.is_file(): + file_sequence.append(str(file_or_folder)) + elif file_or_folder.is_dir(): + for file in file_or_folder.iterdir(): + file_sequence.append(str(file)) + file_map[str(file_or_folder)].append(str(file)) + + data_files = file_sequence or file_map + # Try to obtain the filetype from any of the files, assuming all files have the same type. + if file_sequence: + filetype = get_filetype(UPath(file_sequence[0])) + else: + filetype = get_filetype(UPath(file_map[list(file_map.keys())[0]][0])) + return data_files, filetype + + @property + def outputs(self) -> List[str]: + """The columns that will be generated by this step, based on the datasets from a file + in disk. + + Returns: + The columns that will be generated by this step. + """ + # We assume there are Dataset/IterableDataset, not it's ...Dict counterparts + if self._dataset is Ellipsis: + raise ValueError( + "Dataset not loaded yet, you must call `load` method first." + ) + + return self._dataset.column_names + + +class LoadDataFromDisk(LoadDataFromHub): + """Load a dataset that was previously saved to disk. + + If you previously saved your dataset using the `save_to_disk` method, or + `Distiset.save_to_disk` you can load it again to build a new pipeline using this class. + + Attributes: + dataset_path: The path to the dataset or distiset. + split: The split of the dataset to load (typically will be `train`, `test` or `validation`). + config: The configuration of the dataset to load. This is optional and only needed + if the dataset has multiple configurations. + + Runtime parameters: + - `batch_size`: The batch size to use when processing the data. + - `dataset_path`: The path to the dataset or distiset. + - `is_distiset`: Whether the dataset to load is a `Distiset` or not. Defaults to False. + - `split`: The split of the dataset to load. Defaults to 'train'. + - `config`: The configuration of the dataset to load. This is optional and only + needed if the dataset has multiple configurations. + - `num_examples`: The number of examples to load from the dataset. + By default will load all examples. + - `storage_options`: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. + + Output columns: + - dynamic (`all`): The columns that will be generated by this step, based on the + datasets loaded from the Hugging Face Hub. + + Categories: + - load + """ + + dataset_path: RuntimeParameter[Union[str, Path]] = Field( + default=None, + description="_summary_", + ) + config: RuntimeParameter[str] = Field( + default=None, + description="The configuration of the dataset to load. This is optional and only" + " needed if the dataset has multiple configurations.", + ) + is_distiset: Optional[RuntimeParameter[bool]] = Field( + default=False, + description="Whether the dataset to load is a `Distiset` or not. Defaults to False.", + ) + keep_in_memory: Optional[RuntimeParameter[bool]] = Field( + default=None, + description="Whether to copy the dataset in-memory, see `datasets.Dataset.load_from_disk` " + " for more information. Defaults to `None`.", + ) + split: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The split of the dataset to load. By default will load the whole Dataset/Distiset.", ) - assert ( - response.status_code == 200 - ), f"Failed to get '{repo_id}' dataset info. Make sure you have set the HF_TOKEN environment variable if it is a private dataset." + def load(self) -> None: + """Load the dataset from the file/s in disk.""" + super(GeneratorStep, self).load() + if self.is_distiset: + ds = Distiset.load_from_disk( + self.dataset_path, + keep_in_memory=self.keep_in_memory, + storage_options=self.storage_options, + ) + if self.config: + ds = ds[self.config] + + else: + ds = load_from_disk( + self.dataset_path, + keep_in_memory=self.keep_in_memory, + storage_options=self.storage_options, + ) + + if self.split: + ds = ds[self.split] + + self._dataset = ds + + if self.num_examples: + self._dataset = self._dataset.select(range(self.num_examples)) + else: + self.num_examples = len(self._dataset) + + @property + def outputs(self) -> List[str]: + """The columns that will be generated by this step, based on the datasets from a file + in disk. + + Returns: + The columns that will be generated by this step. + """ + # We assume there are Dataset/IterableDataset, not it's ...Dict counterparts + if self._dataset is Ellipsis: + raise ValueError( + "Dataset not loaded yet, you must call `load` method first." + ) - return response.json()["dataset_info"] + return self._dataset.column_names diff --git a/tests/unit/steps/generators/sample_functions.jsonl b/tests/unit/steps/generators/sample_functions.jsonl new file mode 100644 index 0000000000..700d21ad5b --- /dev/null +++ b/tests/unit/steps/generators/sample_functions.jsonl @@ -0,0 +1,11 @@ +{"type": "function", "function": {"name": "code_interpreter", "description": "Execute the provided Python code string on the terminal using exec.\n\n The string should contain valid, executable and pure Python code in markdown syntax.\n Code should also import any required Python packages.\n\n Args:\n code_markdown (str): The Python code with markdown syntax to be executed.\n For example: ```python\n\n```\n\n Returns:\n dict | str: A dictionary containing variables declared and values returned by function calls,\n or an error message if an exception occurred.\n\n Note:\n Use this function with caution, as executing arbitrary code can pose security risks.", "parameters": {"type": "object", "properties": {"code_markdown": {"type": "string"}}, "required": ["code_markdown"]}}} +{"type": "function", "function": {"name": "google_search_and_scrape", "description": "Performs a Google search for the given query, retrieves the top search result URLs,\nand scrapes the text content and table data from those pages in parallel.\n\nArgs:\n query (str): The search query.\nReturns:\n list: A list of dictionaries containing the URL, text content, and table data for each scraped page.", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "get_current_stock_price", "description": "Get the current stock price for a given symbol.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n float: The current stock price, or None if an error occurs.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_company_news", "description": "Get company news and press releases for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing company news and press releases.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_company_profile", "description": "Get company profile and overview for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing company profile and overview.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_stock_fundamentals", "description": "Get fundamental data for a given stock symbol using yfinance API.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n dict: A dictionary containing fundamental data.\n Keys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_financial_statements", "description": "Get financial statements for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing financial statements (income statement, balance sheet, cash flow statement).", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_key_financial_ratios", "description": "Get key financial ratios for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing key financial ratios.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_analyst_recommendations", "description": "Get analyst recommendations for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing analyst recommendations.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_dividend_data", "description": "Get dividend data for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing dividend data.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_technical_indicators", "description": "Get technical indicators for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing technical indicators.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} diff --git a/tests/unit/steps/generators/test_data.py b/tests/unit/steps/generators/test_data.py index 9817837e20..c35b9db86d 100644 --- a/tests/unit/steps/generators/test_data.py +++ b/tests/unit/steps/generators/test_data.py @@ -17,7 +17,7 @@ from pydantic import ValidationError -class TestLoadDataFromDictsTask: +class TestLoadDataFromDicts: data = [{"instruction": "test"}] * 10 def test_init(self) -> None: diff --git a/tests/unit/steps/generators/test_huggingface.py b/tests/unit/steps/generators/test_huggingface.py index 34b44f4fc5..e72a70acb2 100644 --- a/tests/unit/steps/generators/test_huggingface.py +++ b/tests/unit/steps/generators/test_huggingface.py @@ -13,19 +13,27 @@ # limitations under the License. import os +import tempfile +from pathlib import Path from typing import Generator, Union import pytest from datasets import Dataset, IterableDataset +from distilabel.distiset import Distiset from distilabel.pipeline import Pipeline -from distilabel.steps.generators.huggingface import LoadHubDataset +from distilabel.steps.generators.huggingface import ( + LoadDataFromDisk, + LoadDataFromFileSystem, + LoadDataFromHub, + LoadHubDataset, +) DISTILABEL_RUN_SLOW_TESTS = os.getenv("DISTILABEL_RUN_SLOW_TESTS", False) @pytest.fixture(scope="module") def dataset_loader() -> Generator[Union[Dataset, IterableDataset], None, None]: - load_hub_dataset = LoadHubDataset( + load_hub_dataset = LoadDataFromHub( name="load_dataset", repo_id="distilabel-internal-testing/instruction-dataset-mini", split="test", @@ -39,12 +47,12 @@ def dataset_loader() -> Generator[Union[Dataset, IterableDataset], None, None]: not DISTILABEL_RUN_SLOW_TESTS, reason="These tests depend on internet connection, are slow and depend mainly on HF API, we don't need to test them often.", ) -class TestLoadHubDataset: +class TestLoadDataFromHub: @pytest.mark.parametrize( "streaming, ds_type", [(True, IterableDataset), (False, Dataset)] ) def test_runtime_parameters(self, streaming: bool, ds_type) -> None: - load_hub_dataset = LoadHubDataset( + load_hub_dataset = LoadDataFromHub( name="load_dataset", repo_id="distilabel-internal-testing/instruction-dataset-mini", split="test", @@ -60,6 +68,131 @@ def test_runtime_parameters(self, streaming: bool, ds_type) -> None: assert isinstance(generator_step_output[1], bool) assert len(generator_step_output[0]) == 2 - def test_dataset_outputs(self, dataset_loader: LoadHubDataset) -> None: + def test_dataset_outputs(self, dataset_loader: LoadDataFromHub) -> None: # TODO: This test can be run with/without internet connection, we should emulate it here with a mock. assert dataset_loader.outputs == ["prompt", "completion", "meta"] + + +class TestLoadDataFromFileSystem: + @pytest.mark.parametrize("filetype", ["json", None]) + @pytest.mark.parametrize("streaming", [True, False]) + def test_read_from_jsonl(self, streaming: bool, filetype: Union[str, None]) -> None: + loader = LoadDataFromFileSystem( + filetype=filetype, + data_files=str(Path(__file__).parent / "sample_functions.jsonl"), + streaming=streaming, + ) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 11 + + @pytest.mark.parametrize("filetype", ["json", None]) + def test_read_from_jsonl_with_folder(self, filetype: Union[str, None]) -> None: + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + filename = "sample_functions.jsonl" + sample_file = Path(__file__).parent / filename + for i in range(3): + Path(tmpdir).mkdir(parents=True, exist_ok=True) + (Path(tmpdir) / f"sample_functions_{i}.jsonl").write_text( + sample_file.read_text(), encoding="utf-8" + ) + + loader = LoadDataFromFileSystem( + filetype=filetype, + data_files=tmpdir, + ) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 33 + + @pytest.mark.parametrize("filetype", ["json", None]) + def test_read_from_jsonl_with_nested_folder( + self, filetype: Union[str, None] + ) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + filename = "sample_functions.jsonl" + sample_file = Path(__file__).parent / filename + for folder in ["train", "validation"]: + (Path(tmpdir) / folder).mkdir(parents=True, exist_ok=True) + (Path(tmpdir) / folder / filename).write_text( + sample_file.read_text(), encoding="utf-8" + ) + + loader = LoadDataFromFileSystem( + filetype=filetype, + data_files=tmpdir, + ) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 22 + + @pytest.mark.parametrize("load", [True, False]) + def test_outputs(self, load: bool) -> None: + loader = LoadDataFromFileSystem( + filetype="json", + data_files=str(Path(__file__).parent / "sample_functions.jsonl"), + ) + if load: + loader.load() + assert loader.outputs == ["type", "function"] + else: + with pytest.raises(ValueError): + loader.outputs # noqa: B018 + + +class TestLoadDataFromDisk: + def test_load_dataset_from_disk(self) -> None: + dataset = Dataset.from_dict({"a": [1, 2, 3]}) + with tempfile.TemporaryDirectory() as tmpdir: + dataset_path = str(Path(tmpdir) / "dataset_path") + dataset.save_to_disk(dataset_path) + + loader = LoadDataFromDisk(dataset_path=dataset_path) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 3 + + def test_load_distiset_from_disk(self) -> None: + distiset = Distiset( + { + "leaf_step_1": Dataset.from_dict({"a": [1, 2, 3]}), + "leaf_step_2": Dataset.from_dict( + {"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]} + ), + } + ) + with tempfile.TemporaryDirectory() as tmpdir: + dataset_path = str(Path(tmpdir) / "dataset_path") + distiset.save_to_disk(dataset_path) + + loader = LoadDataFromDisk( + dataset_path=dataset_path, is_distiset=True, config="leaf_step_1" + ) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 3 + + +def test_LoadHubDataset_deprecation_warning(): + with pytest.deprecated_call(): + LoadHubDataset( + repo_id="distilabel-internal-testing/instruction-dataset-mini", + split="test", + batch_size=2, + ) + import distilabel + from packaging.version import Version + + assert Version(distilabel.__version__) <= Version("1.3.0") diff --git a/tests/unit/test_imports.py b/tests/unit/test_imports.py index 9c3ff44d57..07eb3e5b63 100644 --- a/tests/unit/test_imports.py +++ b/tests/unit/test_imports.py @@ -51,6 +51,8 @@ def test_imports() -> None: GeneratorStepOutput, KeepColumns, LoadDataFromDicts, + LoadDataFromHub, + LoadDataFromDisk, LoadHubDataset, PushToHub, Step,