
Commit

feat: 230 feature add sentence transformers support for the to argilla method (#262)

* chore: initial outline adding vectors

* chore: added sentence-transformers to extras

* chore: added `add_vectors_to_argilla_dataset` method to generation tasks

* chore: updated to_argilla methods with call to `Task.add_vectors_to_argilla_dataset`

* chore: updated warning

* chore: resolved linting issues

* Update pyproject.toml

* chore: move logic to `CustomDataset` class

* tests: added basic tests for `to_argilla` method

* docs: updated references to `vector_strategy`

* chore: reformatted tests

* chore: remove cache step

* chore: limit to using 5 fields and defaulting to the first 5

* tests: resolved failing tests

* tests: remove faulty type hint

* chore: configured argilla version

* chore: processed code review comments

* chore: updated field selected based on column names

* chore: added an extra check for dataset_columns is False or []
davidberenstein1957 authored Jan 24, 2024
1 parent 0513bba commit 698b556
Showing 11 changed files with 175 additions and 15 deletions.
7 changes: 7 additions & 0 deletions docs/snippets/technical-reference/pipeline/argilla.py
@@ -1,6 +1,13 @@
import argilla as rg
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)

rg.init(api_key="<YOUR_ARGILLA_API_KEY>", api_url="<YOUR_ARGILLA_API_URL>")

rg_dataset = pipe_dataset.to_argilla()
rg_dataset.push_to_argilla(name="preference-dataset", workspace="admin")

# with a custom `vector_strategy`
vector_strategy = SentenceTransformersExtractor(model="TaylorAI/bge-micro-v2")
rg_dataset = pipe_dataset.to_argilla(vector_strategy=vector_strategy)
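# (Illustrative continuation, assuming the names above.) The vector-enriched
# dataset can be pushed just like the default one:
# rg_dataset.push_to_argilla(name="preference-dataset", workspace="admin")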
11 changes: 7 additions & 4 deletions docs/technical-reference/pipeline.md
@@ -109,7 +109,7 @@ We will use this `LLMPool` as the generator for our pipeline and we will use GPT
--8<-- "docs/snippets/technical-reference/pipeline/pipeline_llmpool_processllm_4.py"
```

1. We will also execute the calls to the OpenAI API in a different process using the `ProcessLLM`. This avoids blocking the main process on the GIL and allows the generator to continue with the next batch.

Then, we will load the dataset and call the `generate` method of the pipeline. For each input in the dataset, the `LLMPool` will randomly select two `LLM`s and will generate two generations for each of them. The generations will be labelled by GPT-4 using the `UltraFeedbackTask` for instruction-following. Finally, we will push the generated dataset to Argilla, in order to review the generations and labels that were automatically generated, and to manually correct them if needed.

@@ -167,7 +167,10 @@ The API reference can be found here: [pipeline][distilabel.pipeline.pipeline]

## Argilla integration

The [CustomDataset][distilabel.dataset.CustomDataset] generated entirely by AI models may require some additional human processing. To facilitate human feedback, the dataset can be uploaded to [`Argilla`](https://github.com/argilla-io/argilla). This process involves logging into an [`Argilla`](https://docs.argilla.io/en/latest/getting_started/cheatsheet.html#connect-to-argilla) instance, converting the dataset to the required format using `CustomDataset.to_argilla()`, and subsequently using `push_to_argilla` on the resulting dataset:
The [CustomDataset][distilabel.dataset.CustomDataset] generated entirely by AI models may require some additional human processing. To facilitate human feedback, the dataset can be uploaded to [`Argilla`](https://github.com/argilla-io/argilla). This process involves logging into an [`Argilla`](https://docs.argilla.io/en/latest/getting_started/cheatsheet.html#connect-to-argilla) instance, converting the dataset to the required format using `CustomDataset.to_argilla()`, and subsequently using `push_to_argilla` on the resulting dataset. This conversion automatically adds out-of-the-box filtering and search support through semantic search `vectors` and `MetadataProperties`, which can be used directly within the Argilla UI to help you find the most relevant examples. Let's briefly introduce the parameters we may find:

- `dataset_columns`: the names of the columns in the dataset to be used for vectors and metadata. By default, it is set to `None`, meaning the first 5 fields from the input and output columns will be used.
- `vector_strategy`: the strategy used to generate the semantic search vectors. By default, it is set to `True`, which initializes a standard `SentenceTransformersExtractor()` that computes vectors for all fields in the dataset using `TaylorAI/bge-micro-v2`. Alternatively, you can pass a custom `SentenceTransformersExtractor` by importing it from `argilla.client.feedback.integrations.sentencetransformers`, as in the sketch below.

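For example, a minimal sketch (assuming a `pipe_dataset` produced by `Pipeline.generate`; the column names are placeholders):

```python
import argilla as rg
from argilla.client.feedback.integrations.sentencetransformers import (
    SentenceTransformersExtractor,
)

rg.init(api_key="<YOUR_ARGILLA_API_KEY>", api_url="<YOUR_ARGILLA_API_URL>")

# Restrict vectors and metadata to fields matching these columns, and use a
# custom extractor instead of the default one.
vector_strategy = SentenceTransformersExtractor(model="TaylorAI/bge-micro-v2")
rg_dataset = pipe_dataset.to_argilla(
    dataset_columns=["input", "generations"],  # placeholder column names
    vector_strategy=vector_strategy,
)
rg_dataset.push_to_argilla(name="preference-dataset", workspace="admin")
```
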
```python
--8<-- "docs/snippets/technical-reference/pipeline/argilla.py"
@@ -177,9 +180,9 @@ The [CustomDataset][distilabel.dataset.CustomDataset] generated entirely by AI m

The preference datasets generated by distilabel out of the box contain all the raw information generated by the [`Pipeline`][distilabel.pipeline.Pipeline], but some processing is necessary in order to prepare the dataset for alignment or instruction fine-tuning, like for [DPO](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format) (initially we only cover the case for *DPO*).

`distilabel` offers helper functions to prepare the [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current definition works for datasets labelled using `PreferenceTask`, and prepares them by *binarizing* the data. Go to the following section for an introduction of *dataset binarization*.
`distilabel` offers helper functions to prepare the [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current definition works for datasets labelled using `PreferenceTask`, and prepares them by *binarizing* the data. Go to the following section for an introduction to *dataset binarization*.

By default the *ties* (rows for which the rating of the chosen and rejected responses are the same) are removed from the dataset, as that's expected for fine-tuning, but those can be kept in case it want's to be analysed. Take a look at [dataset.utils.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.
By default, the *ties* (rows for which the rating of the chosen and rejected responses are the same) are removed from the dataset, as that's expected for fine-tuning, but they can be kept in case you want to analyse them. Take a look at [dataset.utils.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.
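
A minimal sketch of that preparation step (assuming `dataset` is a `CustomDataset` labelled with a `PreferenceTask`; the `keep_ties` flag name is a hypothetical illustration, check the API reference above for the actual signature):

```python
from distilabel.utils.dataset import prepare_dataset

# Binarize the preference dataset for DPO; ties are dropped by default.
dpo_dataset = prepare_dataset(dataset)

# Hypothetical flag name for keeping ties; see the API reference above.
# dpo_dataset = prepare_dataset(dataset, keep_ties=True)
```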

!!! Binarization

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -40,7 +40,7 @@ openai = ["openai >= 1.0.0"]
vllm = ["vllm >= 0.2.1"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
together = ["together"]
argilla = ["argilla >= 1.18.0"]
argilla = ["argilla > 1.21.0", "sentence-transformers >= 2.0.0"]
tests = ["pytest >= 7.4.0"]
docs = [
"mkdocs-material >= 9.5.0",
64 changes: 61 additions & 3 deletions src/distilabel/dataset.py
@@ -16,15 +16,25 @@
from dataclasses import dataclass, field
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Union
from typing import TYPE_CHECKING, Any, Dict, List, Union

from datasets import Dataset

from distilabel.utils.argilla import infer_field_from_dataset_columns
from distilabel.utils.dataset import load_task_from_disk, save_task_to_disk
from distilabel.utils.imports import _ARGILLA_AVAILABLE

if _ARGILLA_AVAILABLE:
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)


if TYPE_CHECKING:
from argilla import FeedbackDataset
from argilla import FeedbackDataset, FeedbackRecord
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)

from distilabel.tasks.base import Task

@@ -37,10 +47,10 @@ class CustomDataset(Dataset):

task: Union["Task", None] = None

def to_argilla(self) -> "FeedbackDataset":
def to_argilla(
self,
dataset_columns: List[str] = None,
vector_strategy: Union[bool, "SentenceTransformersExtractor"] = True,
) -> "FeedbackDataset":
"""Converts the dataset to an Argilla `FeedbackDataset` instance, based on the
task defined in the dataset as part of `Pipeline.generate`.
Args:
dataset_columns (List[str]): the names of the dataset columns used to select the fields
for the Argilla `FeedbackDataset` instance. By default, the first 5 matching fields will be used.
vector_strategy (Union[bool, SentenceTransformersExtractor]): the strategy to be used for
adding vectors to the dataset. If `True`, the default `SentenceTransformersExtractor`
will be used with the `TaylorAI/bge-micro-v2` model. If `False`, no vectors will be added to the dataset.
Raises:
ImportError: if the argilla library is not installed.
ValueError: if the task is not set.
@@ -93,8 +114,45 @@ def to_argilla(self) -> "FeedbackDataset":
UserWarning,
stacklevel=2,
)

selected_fields = infer_field_from_dataset_columns(
dataset_columns=dataset_columns, dataset=rg_dataset, task=self.task
)

rg_dataset = self.add_vectors_to_argilla_dataset(
dataset=rg_dataset, vector_strategy=vector_strategy, fields=selected_fields
)

return rg_dataset

def add_vectors_to_argilla_dataset(
self,
dataset: Union["FeedbackRecord", List["FeedbackRecord"], "FeedbackDataset"],
vector_strategy: Union[bool, "SentenceTransformersExtractor"],
fields: List[str] = None,
) -> Union["FeedbackRecord", List["FeedbackRecord"], "FeedbackDataset"]:
if _ARGILLA_AVAILABLE and vector_strategy:
try:
if isinstance(vector_strategy, SentenceTransformersExtractor):
ste: SentenceTransformersExtractor = vector_strategy
elif vector_strategy:
ste = SentenceTransformersExtractor()
dataset = ste.update_dataset(dataset=dataset, fields=fields)
except Exception as e:
warnings.warn(
f"An error occurred while adding vectors to the dataset: {e}",
stacklevel=2,
)

elif not _ARGILLA_AVAILABLE and vector_strategy:
warnings.warn(
"An error occurred while adding vectors to the dataset: "
"The `argilla`/`sentence-transformers` packages are not installed or the installed version is not compatible with the"
" required version. If you want to add vectors to your dataset, please run `pip install 'distilabel[vectors]'`.",
stacklevel=2,
)
return dataset
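# Illustrative usage sketch (assumed names): the helper can also be applied to an
# already-converted dataset, e.g.
#     rg_dataset = ds.add_vectors_to_argilla_dataset(
#         dataset=rg_dataset, vector_strategy=True, fields=["input"]
#     )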

def save_to_disk(self, dataset_path: PathLike, **kwargs: Any) -> None:
"""Saves the datataset to disk, also saving the task.
1 change: 0 additions & 1 deletion src/distilabel/tasks/base.py
@@ -18,7 +18,6 @@
import importlib_resources
else:
import importlib.resources as importlib_resources

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union

1 change: 1 addition & 0 deletions src/distilabel/tasks/text_generation/base.py
@@ -29,6 +29,7 @@
if _ARGILLA_AVAILABLE:
import argilla as rg


if TYPE_CHECKING:
from argilla import FeedbackDataset, FeedbackRecord

1 change: 1 addition & 0 deletions src/distilabel/tasks/text_generation/self_instruct.py
@@ -161,6 +161,7 @@ def to_argilla_dataset(self, dataset_row: Dict[str, Any]) -> "FeedbackDataset":
) # type: ignore
# Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties
# defined above.

return rg.FeedbackDataset(
fields=fields,
questions=questions, # type: ignore
35 changes: 35 additions & 0 deletions src/distilabel/utils/argilla.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Any, Dict, List

from distilabel.utils.imports import _ARGILLA_AVAILABLE
@@ -24,6 +25,40 @@
from argilla.client.feedback.schemas.types import AllowedFieldTypes
from datasets import Dataset

from distilabel.tasks.base import Task


def infer_field_from_dataset_columns(
task: "Task", dataset: "FeedbackDataset", dataset_columns: List[str] = None
) -> List[str]:
# set columns to all input and output columns for the task
if dataset_columns is None:
dataset_columns = getattr(task, "input_args_names", []) + getattr(
task, "output_args_names", []
)
elif dataset_columns is False or len(dataset_columns) == 0:
dataset_columns = []

# get the first 5 fields that align with the column selection
selected_fields = []
optional_fields = [field.name for field in dataset.fields]
for column in dataset_columns:
selected_fields += [field for field in optional_fields if column in field]

selected_fields = list(dict.fromkeys(selected_fields))
if len(selected_fields) > 5:
selected_fields = selected_fields[:5]
warnings.warn(
f"More than 5 fields found from {optional_fields}, only the first 5 will be used: {selected_fields} for vectors.",
stacklevel=2,
)
elif len(selected_fields) == 0:
raise ValueError(
f"No fields found from {optional_fields} for vectors, please check your dataset and task configuration."
)

return selected_fields
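# Illustrative walk-through of the selection above (assumed values):
#     dataset_columns = ["input", "generations"]
#     optional_fields = ["input", "generations-1", "generations-2", "rating-1"]
# The substring match keeps "input", "generations-1" and "generations-2",
# deduplicates them while preserving order, and caps the result at 5 fields.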


def infer_fields_from_dataset_row(
field_names: List[str], dataset_row: Dict[str, Any]
3 changes: 1 addition & 2 deletions src/distilabel/utils/dataset.py
@@ -26,7 +26,6 @@
from distilabel.dataset import CustomDataset
from distilabel.tasks.base import Task


TASK_FILE_NAME = "task.pkl"

logger = get_logger()
@@ -55,7 +54,7 @@ def load_task_from_disk(path: Path) -> "Task":
Returns:
Task: The task.
"""
task_path = path / "task.pkl"
task_path = path / TASK_FILE_NAME
if not task_path.exists():
raise FileNotFoundError(f"The task file does not exist: {task_path}")
with open(task_path, "rb") as f:
4 changes: 3 additions & 1 deletion src/distilabel/utils/imports.py
@@ -91,7 +91,9 @@ def _check_package_is_available(


_ARGILLA_AVAILABLE = _check_package_is_available(
"argilla", min_version="1.16.0", greater_or_equal=True
"argilla", min_version="1.22.0", greater_or_equal=True
) and _check_package_is_available(
"sentence-transformers", min_version="2.0.0", greater_or_equal=True
)
_OPENAI_AVAILABLE = _check_package_is_available(
"openai", min_version="1.0.0", greater_or_equal=True
61 changes: 58 additions & 3 deletions tests/test_dataset.py
@@ -18,15 +18,41 @@
from typing import List

import pytest
from argilla import FeedbackDataset
from distilabel.dataset import CustomDataset, DatasetCheckpoint
from distilabel.tasks import TextGenerationTask, UltraFeedbackTask
from distilabel.tasks.text_generation.self_instruct import SelfInstructTask
from distilabel.utils.dataset import prepare_dataset
from distilabel.utils.dataset import TASK_FILE_NAME, prepare_dataset


@pytest.fixture
def custom_dataset():
ds = CustomDataset.from_dict({"input": ["a", "b"], "generations": ["c", "d"]})
ds = CustomDataset.from_dict(
{
"input": ["a", "b"],
"generations": ["c", "d"],
"rating": [1, 2],
"rationale": ["e", "f"],
}
)
ds.task = UltraFeedbackTask.for_overall_quality()
return ds


@pytest.fixture
def large_custom_dataset():
ds = CustomDataset.from_dict(
{
"input": ["a", "b"],
"generations": [["c"] * 10, ["d"] * 10],
"rating": [1, 2],
"rationale": ["e", "f"],
"input_2": ["a", "b"],
"generations_2": ["c", "d"],
"rating_2": [1, 2],
"rationale_2": ["e", "f"],
}
)
ds.task = UltraFeedbackTask.for_overall_quality()
return ds

@@ -121,14 +147,16 @@ def sample_preference_dataset():
return ds


@pytest.mark.usefixtures("custom_dataset")
def test_dataset_save_to_disk(custom_dataset):
with tempfile.TemporaryDirectory() as tmpdir:
ds_name = Path(tmpdir) / "dataset_folder"
custom_dataset.save_to_disk(ds_name)
assert ds_name.is_dir()
assert (ds_name / "task.pkl").is_file()
assert (ds_name / TASK_FILE_NAME).is_file()


@pytest.mark.usefixtures("custom_dataset")
def test_dataset_load_disk(custom_dataset):
with tempfile.TemporaryDirectory() as tmpdir:
ds_name = Path(tmpdir) / "dataset_folder"
@@ -138,6 +166,7 @@ def test_dataset_load_disk(custom_dataset):
assert isinstance(ds_from_disk.task, UltraFeedbackTask)


@pytest.mark.usefixtures("custom_dataset")
@pytest.mark.parametrize(
"save_frequency, dataset_len, batch_size, expected",
[
@@ -170,6 +199,32 @@
assert ctr == expected == chk._total_checks


@pytest.mark.usefixtures("custom_dataset")
def test_to_argilla(custom_dataset: CustomDataset):
rg_dataset = custom_dataset.to_argilla(vector_strategy=False)
assert isinstance(rg_dataset, FeedbackDataset)
assert not rg_dataset.vectors_settings
rg_dataset = custom_dataset.to_argilla()
assert rg_dataset.vectors_settings

with pytest.raises(ValueError, match="No fields"):
custom_dataset.to_argilla(dataset_columns=["fake_column"])


@pytest.mark.usefixtures("custom_dataset")
def test_to_argilla_with_wrong_dataset_columns(custom_dataset: CustomDataset):
with pytest.raises(ValueError, match="No fields"):
custom_dataset.to_argilla(dataset_columns=["fake_column"])


@pytest.mark.usefixtures("large_custom_dataset")
def test_to_argilla_with_too_many_fields(large_custom_dataset: CustomDataset):
with pytest.warns(UserWarning, match="More than 5 fields"):
large_custom_dataset.to_argilla(
dataset_columns=large_custom_dataset.column_names
)


@pytest.mark.parametrize(
"with_generation_model",
[True],
