From 698b556d9e8876e7f9f673bf6ca7537bed4c1661 Mon Sep 17 00:00:00 2001
From: David Berenstein
Date: Wed, 24 Jan 2024 12:17:09 +0100
Subject: [PATCH] feat: 230 feature add sentence transformers support for the
 to argilla method (#262)

* chore: initial outline adding vectors
* chore: added sentence-transformers to extras
* chore: added `add_vectors_to_argilla_dataset` method to generation tasks
* chore: updated to_argilla methods with call to `Task.add_vectors_to_argilla_dataset`
* chore: updated warning
* chore: resolved linting issues
* Update pyproject.toml
* chore: move logic to `CustomDataset` class
* tests: added basic tests for `to_argilla` method
* docs: updated references to `vector_strategy`
* chore: reformatted tests
* chore: remove cache step
* chore: limit to using 5 fields and defaulting to the first 5
* tests: resolved failing tests
* tests: remove faulty type hint
* chore: configured argilla version
* chore: processed comments from code review
* chore: added an extra check for dataset_columns is False or []
---
 .../technical-reference/pipeline/argilla.py   |  7 ++
 docs/technical-reference/pipeline.md          | 11 ++--
 pyproject.toml                                |  2 +-
 src/distilabel/dataset.py                     | 64 ++++++++++++++++++-
 src/distilabel/tasks/base.py                  |  1 -
 src/distilabel/tasks/text_generation/base.py  |  1 +
 .../tasks/text_generation/self_instruct.py    |  1 +
 src/distilabel/utils/argilla.py               | 35 ++++++++++
 src/distilabel/utils/dataset.py               |  3 +-
 src/distilabel/utils/imports.py               |  4 +-
 tests/test_dataset.py                         | 61 +++++++++++++++++-
 11 files changed, 175 insertions(+), 15 deletions(-)

diff --git a/docs/snippets/technical-reference/pipeline/argilla.py b/docs/snippets/technical-reference/pipeline/argilla.py
index 347d40c03c..ab9df9b161 100644
--- a/docs/snippets/technical-reference/pipeline/argilla.py
+++ b/docs/snippets/technical-reference/pipeline/argilla.py
@@ -1,6 +1,13 @@
 import argilla as rg
+from argilla.client.feedback.integrations.sentencetransformers import (
+    SentenceTransformersExtractor,
+)
 
 rg.init(api_key="", api_url="")
 
 rg_dataset = pipe_dataset.to_argilla()
 rg_dataset.push_to_argilla(name="preference-dataset", workspace="admin")
+
+# with a custom `vector_strategy`
+vector_strategy = SentenceTransformersExtractor(model="TaylorAI/bge-micro-v2")
+rg_dataset = pipe_dataset.to_argilla(vector_strategy=vector_strategy)
diff --git a/docs/technical-reference/pipeline.md b/docs/technical-reference/pipeline.md
index ed0acf6728..7dc14a07b1 100644
--- a/docs/technical-reference/pipeline.md
+++ b/docs/technical-reference/pipeline.md
@@ -109,7 +109,7 @@ We will use this `LLMPool` as the generator for our pipeline and we will use GPT
 --8<-- "docs/snippets/technical-reference/pipeline/pipeline_llmpool_processllm_4.py"
 ```
 
-1. We also will execute the calls to OpenAI API in a different process using the `ProcessLLM`. This will allow to not block the main process GIL, and allowing the generator to continue with the next batch.
+1. We will also execute the calls to the OpenAI API in a different process using `ProcessLLM`. This avoids blocking the main process on the GIL and allows the generator to continue with the next batch.
 
 Then, we will load the dataset and call the `generate` method of the pipeline. For each input in the dataset, the `LLMPool` will randomly select two `LLM`s and will generate two generations for each of them. The generations will be labelled by GPT-4 using the `UltraFeedbackTask` for instruction-following.
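+
+A sketch of that call (a hedged example: `dataset` stands for any previously loaded `datasets.Dataset`, and the argument values are illustrative, not prescriptive):
+
+```python
+pipe_dataset = pipeline.generate(
+    dataset=dataset,  # the dataset loaded above
+    num_generations=2,  # illustrative: generations requested per input
+    batch_size=8,  # illustrative batch size
+)
+```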
Finally, we will push the generated dataset to Argilla, in order to review the generations and labels that were automatically generated, and to manually correct them if needed.

@@ -167,7 +167,10 @@ The API reference can be found here: [pipeline][distilabel.pipeline.pipeline]
 
 ## Argilla integration
 
-The [CustomDataset][distilabel.dataset.CustomDataset] generated entirely by AI models may require some additional human processing. To facilitate human feedback, the dataset can be uploaded to [`Argilla`](https://github.com/argilla-io/argilla). This process involves logging into an [`Argilla`](https://docs.argilla.io/en/latest/getting_started/cheatsheet.html#connect-to-argilla) instance, converting the dataset to the required format using `CustomDataset.to_argilla()`, and subsequently using `push_to_argilla` on the resulting dataset:
+The [CustomDataset][distilabel.dataset.CustomDataset] generated entirely by AI models may require some additional human processing. To facilitate human feedback, the dataset can be uploaded to [`Argilla`](https://github.com/argilla-io/argilla). This process involves logging into an [`Argilla`](https://docs.argilla.io/en/latest/getting_started/cheatsheet.html#connect-to-argilla) instance, converting the dataset to the required format using `CustomDataset.to_argilla()`, and subsequently using `push_to_argilla` on the resulting dataset. This conversion automatically adds out-of-the-box filtering and search capabilities through semantic search `vectors` and `MetadataProperties`. These can be used directly within the Argilla UI to help you find the most relevant examples. Let's briefly introduce the parameters involved:
+
+- `dataset_columns`: The names of the dataset columns to be used for vectors and metadata. By default, it is set to `None`, meaning the first 5 fields derived from the task's input and output columns will be taken.
+- `vector_strategy`: The strategy used to generate the semantic search vectors. By default, it is set to `True`, which initializes a standard `SentenceTransformersExtractor()` that computes vectors for the selected fields using the `TaylorAI/bge-micro-v2` model. Alternatively, you can pass a custom `SentenceTransformersExtractor` imported from `argilla.client.feedback.integrations.sentencetransformers`, or set it to `False` to skip adding vectors.
 
 ```python
 --8<-- "docs/snippets/technical-reference/pipeline/argilla.py"
 ```
 
@@ -177,9 +180,9 @@ The [CustomDataset][distilabel.dataset.CustomDataset] generated entirely by AI m
 
 The preference datasets generated by distilabel out of the box contain all the raw information generated by the [`Pipeline`][distilabel.pipeline.Pipeline], but some processing is necessary in order to prepare the dataset for alignment or instruction fine-tuning, like for [DPO](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format) (initially we only cover the case for *DPO*).
 
-`distilabel` offers helper functions to prepare the [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current definition works for datasets labelled using `PreferenceTask`, and prepares them by *binarizing* the data. Go to the following section for an introduction of *dataset binarization*.
+`distilabel` offers helper functions to prepare the [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current definition works for datasets labelled using `PreferenceTask`, and prepares them by *binarizing* the data. Go to the following section for an introduction to *dataset binarization*.
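+
+A minimal sketch of that preparation step (assuming `pipe_dataset` is a dataset labelled with a `PreferenceTask`; check the `prepare_dataset` reference linked below for the exact signature and options):
+
+```python
+from distilabel.utils.dataset import prepare_dataset
+
+# Binarize the preference dataset into chosen/rejected pairs for DPO.
+# Ties between chosen and rejected ratings are dropped by default,
+# as described in the note below.
+dpo_dataset = prepare_dataset(pipe_dataset)
+```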
-By default the *ties* (rows for which the rating of the chosen and rejected responses are the same) are removed from the dataset, as that's expected for fine-tuning, but those can be kept in case it want's to be analysed. Take a look at [dataset.utils.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.
+By default, the *ties* (rows for which the rating of the chosen and rejected responses are the same) are removed from the dataset, as that's expected for fine-tuning, but those can be kept in case you want to analyse them. Take a look at [dataset.utils.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.
 
 !!! Binarization
 
diff --git a/pyproject.toml b/pyproject.toml
index 5883b275da..981da89b00 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ openai = ["openai >= 1.0.0"]
 vllm = ["vllm >= 0.2.1"]
 vertexai = ["google-cloud-aiplatform >= 1.38.0"]
 together = ["together"]
-argilla = ["argilla >= 1.18.0"]
+argilla = ["argilla > 1.21.0", "sentence-transformers >= 2.0.0"]
 tests = ["pytest >= 7.4.0"]
 docs = [
   "mkdocs-material >= 9.5.0",
diff --git a/src/distilabel/dataset.py b/src/distilabel/dataset.py
index fd53d646df..5f1af481ed 100644
--- a/src/distilabel/dataset.py
+++ b/src/distilabel/dataset.py
@@ -16,15 +16,25 @@
 from dataclasses import dataclass, field
 from os import PathLike
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from datasets import Dataset
 
+from distilabel.utils.argilla import infer_field_from_dataset_columns
 from distilabel.utils.dataset import load_task_from_disk, save_task_to_disk
 from distilabel.utils.imports import _ARGILLA_AVAILABLE
 
+if _ARGILLA_AVAILABLE:
+    from argilla.client.feedback.integrations.sentencetransformers import (
+        SentenceTransformersExtractor,
+    )
+
+
 if TYPE_CHECKING:
-    from argilla import FeedbackDataset
+    from argilla import FeedbackDataset, FeedbackRecord
+    from argilla.client.feedback.integrations.sentencetransformers import (
+        SentenceTransformersExtractor,
+    )
 
     from distilabel.tasks.base import Task
 
@@ -37,10 +47,21 @@ class CustomDataset(Dataset):
 
     task: Union["Task", None] = None
 
-    def to_argilla(self) -> "FeedbackDataset":
+    def to_argilla(
+        self,
+        dataset_columns: List[str] = None,
+        vector_strategy: Union[bool, "SentenceTransformersExtractor"] = True,
+    ) -> "FeedbackDataset":
         """Converts the dataset to an Argilla `FeedbackDataset` instance,
         based on the task defined in the dataset as part of `Pipeline.generate`.
 
+        Args:
+            dataset_columns (List[str]): the dataset column names to be used for the vectors
+                and metadata of the Argilla `FeedbackDataset` instance. By default, the first
+                5 fields derived from the task's input and output columns will be used.
+            vector_strategy (Union[bool, SentenceTransformersExtractor]): the strategy to be used
+                for adding vectors to the dataset. If `True`, the default `SentenceTransformersExtractor`
+                will be used with the `TaylorAI/bge-micro-v2` model. If `False`, no vectors will be added
+                to the dataset.
+
         Raises:
             ImportError: if the argilla library is not installed.
             ValueError: if the task is not set.
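+
+        Example:
+            >>> # Usage sketch; the column names below are illustrative and
+            >>> # assume a task whose fields match them.
+            >>> rg_dataset = dataset.to_argilla(
+            ...     dataset_columns=["input", "generations"],
+            ...     vector_strategy=True,
+            ... )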
@@ -93,8 +114,45 @@ def to_argilla(self) -> "FeedbackDataset":
                 UserWarning,
                 stacklevel=2,
             )
+
+        selected_fields = infer_field_from_dataset_columns(
+            dataset_columns=dataset_columns, dataset=rg_dataset, task=self.task
+        )
+
+        rg_dataset = self.add_vectors_to_argilla_dataset(
+            dataset=rg_dataset, vector_strategy=vector_strategy, fields=selected_fields
+        )
+
         return rg_dataset
 
+    def add_vectors_to_argilla_dataset(
+        self,
+        dataset: Union["FeedbackRecord", List["FeedbackRecord"], "FeedbackDataset"],
+        vector_strategy: Union[bool, "SentenceTransformersExtractor"],
+        fields: List[str] = None,
+    ) -> Union["FeedbackRecord", List["FeedbackRecord"], "FeedbackDataset"]:
+        if _ARGILLA_AVAILABLE and vector_strategy:
+            try:
+                if isinstance(vector_strategy, SentenceTransformersExtractor):
+                    ste: SentenceTransformersExtractor = vector_strategy
+                else:
+                    ste = SentenceTransformersExtractor()
+                dataset = ste.update_dataset(dataset=dataset, fields=fields)
+            except Exception as e:
+                warnings.warn(
+                    f"An error occurred while adding vectors to the dataset: {e}",
+                    stacklevel=2,
+                )
+        elif not _ARGILLA_AVAILABLE and vector_strategy:
+            warnings.warn(
+                "Vectors could not be added to the dataset: "
+                "the `argilla`/`sentence-transformers` packages are not installed or the installed versions are not compatible with the"
+                " required versions. If you want to add vectors to your dataset, please run `pip install 'distilabel[argilla]'`.",
+                stacklevel=2,
+            )
+        return dataset
+
     def save_to_disk(self, dataset_path: PathLike, **kwargs: Any) -> None:
         """Saves the dataset to disk, also saving the task.
 
diff --git a/src/distilabel/tasks/base.py b/src/distilabel/tasks/base.py
index 80bc5c279c..1c102c6b00 100644
--- a/src/distilabel/tasks/base.py
+++ b/src/distilabel/tasks/base.py
@@ -18,7 +18,6 @@
     import importlib_resources
 else:
     import importlib.resources as importlib_resources
-
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
 
diff --git a/src/distilabel/tasks/text_generation/base.py b/src/distilabel/tasks/text_generation/base.py
index 3c94aa7724..9ad5595b50 100644
--- a/src/distilabel/tasks/text_generation/base.py
+++ b/src/distilabel/tasks/text_generation/base.py
@@ -29,6 +29,7 @@
 if _ARGILLA_AVAILABLE:
     import argilla as rg
 
+
 if TYPE_CHECKING:
     from argilla import FeedbackDataset, FeedbackRecord
 
diff --git a/src/distilabel/tasks/text_generation/self_instruct.py b/src/distilabel/tasks/text_generation/self_instruct.py
index 75f8b73d2e..fc9d10ebeb 100644
--- a/src/distilabel/tasks/text_generation/self_instruct.py
+++ b/src/distilabel/tasks/text_generation/self_instruct.py
@@ -161,6 +161,7 @@ def to_argilla_dataset(self, dataset_row: Dict[str, Any]) -> "FeedbackDataset":
         )  # type: ignore
         # Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties
         # defined above.
+
         return rg.FeedbackDataset(
             fields=fields,
             questions=questions,  # type: ignore
diff --git a/src/distilabel/utils/argilla.py b/src/distilabel/utils/argilla.py
index b6eafbb515..7d735b3e08 100644
--- a/src/distilabel/utils/argilla.py
+++ b/src/distilabel/utils/argilla.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import TYPE_CHECKING, Any, Dict, List
 
 from distilabel.utils.imports import _ARGILLA_AVAILABLE
@@ -24,6 +25,40 @@
     from argilla.client.feedback.schemas.types import AllowedFieldTypes
     from datasets import Dataset
 
+    from distilabel.tasks.base import Task
+
+
+def infer_field_from_dataset_columns(
+    task: "Task", dataset: "FeedbackDataset", dataset_columns: List[str] = None
+) -> List[str]:
+    # default to all of the task's input and output columns
+    if dataset_columns is None:
+        dataset_columns = getattr(task, "input_args_names", []) + getattr(
+            task, "output_args_names", []
+        )
+    elif dataset_columns is False or len(dataset_columns) == 0:
+        dataset_columns = []
+
+    # collect the dataset fields whose names contain one of the selected
+    # column names (field names may carry an index suffix, e.g. `{column_name}_idx`)
+    selected_fields = []
+    optional_fields = [field.name for field in dataset.fields]
+    for column in dataset_columns:
+        selected_fields += [field for field in optional_fields if column in field]
+
+    selected_fields = list(dict.fromkeys(selected_fields))
+    if len(selected_fields) > 5:
+        selected_fields = selected_fields[:5]
+        warnings.warn(
+            f"More than 5 fields found from {optional_fields}; only the first 5 will be used for vectors: {selected_fields}.",
+            stacklevel=2,
+        )
+    elif len(selected_fields) == 0:
+        raise ValueError(
+            f"No fields found from {optional_fields} for vectors, please check your dataset and task configuration."
+        )
+
+    return selected_fields
+
 
 def infer_fields_from_dataset_row(
     field_names: List[str], dataset_row: Dict[str, Any]
diff --git a/src/distilabel/utils/dataset.py b/src/distilabel/utils/dataset.py
index 63f2e397c6..2bc44c8813 100644
--- a/src/distilabel/utils/dataset.py
+++ b/src/distilabel/utils/dataset.py
@@ -26,7 +26,6 @@
 from distilabel.dataset import CustomDataset
 from distilabel.tasks.base import Task
 
-
 TASK_FILE_NAME = "task.pkl"
 
 logger = get_logger()
@@ -55,7 +54,7 @@ def load_task_from_disk(path: Path) -> "Task":
 
     Returns:
         Task: The task.
""" - task_path = path / "task.pkl" + task_path = path / TASK_FILE_NAME if not task_path.exists(): raise FileNotFoundError(f"The task file does not exist: {task_path}") with open(task_path, "rb") as f: diff --git a/src/distilabel/utils/imports.py b/src/distilabel/utils/imports.py index c044bf7e95..daa154d413 100644 --- a/src/distilabel/utils/imports.py +++ b/src/distilabel/utils/imports.py @@ -91,7 +91,9 @@ def _check_package_is_available( _ARGILLA_AVAILABLE = _check_package_is_available( - "argilla", min_version="1.16.0", greater_or_equal=True + "argilla", min_version="1.22.0", greater_or_equal=True +) and _check_package_is_available( + "sentence-transformers", min_version="2.0.0", greater_or_equal=True ) _OPENAI_AVAILABLE = _check_package_is_available( "openai", min_version="1.0.0", greater_or_equal=True diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 01cc164e93..c96df540fa 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -18,15 +18,41 @@ from typing import List import pytest +from argilla import FeedbackDataset from distilabel.dataset import CustomDataset, DatasetCheckpoint from distilabel.tasks import TextGenerationTask, UltraFeedbackTask from distilabel.tasks.text_generation.self_instruct import SelfInstructTask -from distilabel.utils.dataset import prepare_dataset +from distilabel.utils.dataset import TASK_FILE_NAME, prepare_dataset @pytest.fixture def custom_dataset(): - ds = CustomDataset.from_dict({"input": ["a", "b"], "generations": ["c", "d"]}) + ds = CustomDataset.from_dict( + { + "input": ["a", "b"], + "generations": ["c", "d"], + "rating": [1, 2], + "rationale": ["e", "f"], + } + ) + ds.task = UltraFeedbackTask.for_overall_quality() + return ds + + +@pytest.fixture +def large_custom_dataset(): + ds = CustomDataset.from_dict( + { + "input": ["a", "b"], + "generations": [["c"] * 10, ["d"] * 10], + "rating": [1, 2], + "rationale": ["e", "f"], + "input_2": ["a", "b"], + "generations_2": ["c", "d"], + "rating_2": [1, 2], + "rationale_2": ["e", "f"], + } + ) ds.task = UltraFeedbackTask.for_overall_quality() return ds @@ -121,14 +147,16 @@ def sample_preference_dataset(): return ds +@pytest.mark.usefixtures("custom_dataset") def test_dataset_save_to_disk(custom_dataset): with tempfile.TemporaryDirectory() as tmpdir: ds_name = Path(tmpdir) / "dataset_folder" custom_dataset.save_to_disk(ds_name) assert ds_name.is_dir() - assert (ds_name / "task.pkl").is_file() + assert (ds_name / TASK_FILE_NAME).is_file() +@pytest.mark.usefixtures("custom_dataset") def test_dataset_load_disk(custom_dataset): with tempfile.TemporaryDirectory() as tmpdir: ds_name = Path(tmpdir) / "dataset_folder" @@ -138,6 +166,7 @@ def test_dataset_load_disk(custom_dataset): assert isinstance(ds_from_disk.task, UltraFeedbackTask) +@pytest.mark.usefixtures("custom_dataset") @pytest.mark.parametrize( "save_frequency, dataset_len, batch_size, expected", [ @@ -170,6 +199,32 @@ def test_do_checkpoint( assert ctr == expected == chk._total_checks +@pytest.mark.usefixtures("custom_dataset") +def test_to_argilla(custom_dataset: CustomDataset): + rg_dataset = custom_dataset.to_argilla(vector_strategy=False) + assert isinstance(rg_dataset, FeedbackDataset) + assert not rg_dataset.vectors_settings + rg_dataset = custom_dataset.to_argilla() + assert rg_dataset.vectors_settings + + with pytest.raises(ValueError, match="No fields"): + custom_dataset.to_argilla(dataset_columns=["fake_column"]) + + +@pytest.mark.usefixtures("custom_dataset") +def 
+    with pytest.raises(ValueError, match="No fields"):
+        custom_dataset.to_argilla(dataset_columns=["fake_column"])
+
+
+@pytest.mark.usefixtures("large_custom_dataset")
+def test_to_argilla_with_too_many_fields(large_custom_dataset: CustomDataset):
+    with pytest.warns(UserWarning, match="More than 5 fields"):
+        large_custom_dataset.to_argilla(
+            dataset_columns=large_custom_dataset.column_names
+        )
+
+
 @pytest.mark.parametrize(
     "with_generation_model",
    [True],