diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index 675e75d166..07c3481cc8 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -15,7 +15,7 @@ import logging import re from pathlib import Path -from typing import Optional +from typing import Any, Dict, Optional from datasets import load_dataset from huggingface_hub import DatasetCardData, HfApi @@ -91,13 +91,15 @@ def _generate_card(self, repo_id: str, token: Optional[str]) -> None: dataset[0] if not isinstance(dataset, dict) else dataset["train"][0] ) + metadata = self._extract_readme_metadata(repo_id, token) + card = DistilabelDatasetCard.from_template( card_data=DatasetCardData( - config_names=sorted(self.keys()), size_categories=size_categories_parser( max(len(dataset) for dataset in self.values()) ), tags=["synthetic", "distilabel", "rlaif"], + **metadata, ), repo_id=repo_id, sample_records=sample_records, @@ -117,6 +119,34 @@ def _generate_card(self, repo_id: str, token: Optional[str]) -> None: token=token, ) + def _extract_readme_metadata( + self, repo_id: str, token: Optional[str] + ) -> Dict[str, Any]: + """Extracts the metadata from the README.md file of the dataset repository. + + We have to download the previous README.md file in the repo, extract the metadata from it, + and generate a dict again to be passed thoruogh the `DatasetCardData` object. + + Args: + repo_id: The ID of the repository to push to, from the `push_to_hub` method. + token: The token to authenticate with the Hugging Face Hub, from the `push_to_hub` method. + + Returns: + The metadata extracted from the README.md file of the dataset repository as a dict. + """ + import re + + import yaml + from huggingface_hub.file_download import hf_hub_download + + readme_path = Path( + hf_hub_download(repo_id, "README.md", repo_type="dataset", token=token) + ) + # Remove the '---' from the metadata + metadata = re.findall(r"---\n(.*?)\n---", readme_path.read_text(), re.DOTALL)[0] + metadata = yaml.safe_load(metadata) + return metadata + def train_test_split( self, train_size: float,