Skip to content

Commit

Permalink
Add {generation,labelling}_model column as metadata in Argilla (#175)
Browse files Browse the repository at this point in the history
* Add `type: ignore` in `self.generator.generate`

* Add `infer_metadata_properties`

* Import `FeedbackDataset` from `argilla`

* Extend `infer_model_metadata_properties` to return `FeedbackDataset`

* Add `model_metadata_from_dataset_row`

* Fix `infer_model_metadata_properties`
  • Loading branch information
alvarobartt authored Dec 20, 2023
1 parent 3abf5e3 commit fae24ec
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 10 deletions.
14 changes: 13 additions & 1 deletion src/distilabel/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from datasets import Dataset

from distilabel.utils.argilla import infer_model_metadata_properties
from distilabel.utils.imports import _ARGILLA_AVAILABLE

if TYPE_CHECKING:
Expand Down Expand Up @@ -62,6 +63,17 @@ def to_argilla(self) -> "FeedbackDataset":
f"Error while converting the dataset to an Argilla `FeedbackDataset` instance: {e}"
) from e

try:
rg_dataset = infer_model_metadata_properties(
hf_dataset=self, rg_dataset=rg_dataset
)
except Exception as e:
warnings.warn(
f"Error while adding the model metadata properties: {e}",
UserWarning,
stacklevel=2,
)

for dataset_row in self:
if any(
dataset_row[input_arg_name] is None # type: ignore
Expand All @@ -71,7 +83,7 @@ def to_argilla(self) -> "FeedbackDataset":
try:
rg_dataset.add_records(
self.task.to_argilla_record(dataset_row=dataset_row) # type: ignore
)
) # type: ignore
except Exception as e:
warnings.warn(
f"Error while converting a row into an Argilla `FeedbackRecord` instance: {e}",
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _get_batch_generations(
Returns:
List[Dict[str, Any]]: the processed batch generations.
"""
outputs = self.generator.generate(
outputs = self.generator.generate( # type: ignore
inputs=inputs,
num_generations=num_generations,
progress_callback_func=progress_callback_func,
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from distilabel.tasks.prompt import Prompt

if TYPE_CHECKING:
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
from argilla import FeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord


Expand Down
11 changes: 9 additions & 2 deletions src/distilabel/tasks/preference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from distilabel.tasks.base import Task
from distilabel.utils.argilla import infer_fields_from_dataset_row
from distilabel.utils.argilla import (
infer_fields_from_dataset_row,
model_metadata_from_dataset_row,
)
from distilabel.utils.imports import _ARGILLA_AVAILABLE

if _ARGILLA_AVAILABLE:
import argilla as rg

if TYPE_CHECKING:
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
from argilla import FeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord


Expand Down Expand Up @@ -202,6 +205,10 @@ def to_argilla_record( # noqa: C901
metadata[f"distance-best-{ratings_column}"] = (
sorted_ratings[0] - sorted_ratings[1]
)
# Then we add the model metadata from the `generation_model` and `labelling_model`
# columns of the dataset, if they exist.
metadata.update(model_metadata_from_dataset_row(dataset_row=dataset_row))
# Finally, we return the `FeedbackRecord` with the fields and the metadata
return rg.FeedbackRecord(
fields=fields, suggestions=suggestions, metadata=metadata
)
2 changes: 1 addition & 1 deletion src/distilabel/tasks/preference/ultrafeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from distilabel.tasks.prompt import Prompt

if TYPE_CHECKING:
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
from argilla import FeedbackDataset

_ULTRAFEEDBACK_TEMPLATE = get_template("ultrafeedback.jinja2")

Expand Down
11 changes: 9 additions & 2 deletions src/distilabel/tasks/text_generation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
from distilabel.tasks.base import Task
from distilabel.tasks.prompt import Prompt
from distilabel.tasks.text_generation.principles import UltraFeedbackPrinciples
from distilabel.utils.argilla import infer_fields_from_dataset_row
from distilabel.utils.argilla import (
infer_fields_from_dataset_row,
model_metadata_from_dataset_row,
)
from distilabel.utils.imports import _ARGILLA_AVAILABLE

if _ARGILLA_AVAILABLE:
import argilla as rg

if TYPE_CHECKING:
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
from argilla import FeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord


Expand Down Expand Up @@ -235,4 +238,8 @@ def to_argilla_record(self, dataset_row: Dict[str, Any]) -> "FeedbackRecord":
UserWarning,
stacklevel=2,
)
# Then we add the model metadata from the `generation_model` and `labelling_model`
# columns of the dataset, if they exist.
metadata.update(model_metadata_from_dataset_row(dataset_row=dataset_row))
# Finally, we return the `FeedbackRecord` with the fields and the metadata
return rg.FeedbackRecord(fields=fields, metadata=metadata)
14 changes: 12 additions & 2 deletions src/distilabel/tasks/text_generation/self_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
from distilabel.tasks.base import get_template
from distilabel.tasks.prompt import Prompt
from distilabel.tasks.text_generation.base import TextGenerationTask
from distilabel.utils.argilla import infer_fields_from_dataset_row
from distilabel.utils.argilla import (
infer_fields_from_dataset_row,
model_metadata_from_dataset_row,
)
from distilabel.utils.imports import _ARGILLA_AVAILABLE

if _ARGILLA_AVAILABLE:
import argilla as rg

if TYPE_CHECKING:
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
from argilla import FeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord

_SELF_INSTRUCT_TEMPLATE = get_template("self-instruct.jinja2")
Expand Down Expand Up @@ -181,6 +184,13 @@ def to_argilla_record(
)
fields["instruction"] = instruction
metadata["length-instruction"] = len(instruction)

# Then we add the model metadata from the `generation_model` and `labelling_model`
# columns of the dataset, if they exist.
metadata.update(
model_metadata_from_dataset_row(dataset_row=dataset_row)
)
# Finally, we append the `FeedbackRecord` with the fields and the metadata
records.append(rg.FeedbackRecord(fields=fields, metadata=metadata))
if not records:
raise ValueError(
Expand Down
41 changes: 41 additions & 0 deletions src/distilabel/utils/argilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import argilla as rg

if TYPE_CHECKING:
from argilla import FeedbackDataset
from argilla.client.feedback.schemas.types import AllowedFieldTypes
from datasets import Dataset


def infer_fields_from_dataset_row(
Expand All @@ -42,3 +44,42 @@ def infer_fields_from_dataset_row(
elif isinstance(dataset_row[arg_name], str):
processed_items.append(rg.TextField(name=arg_name, title=arg_name)) # type: ignore
return processed_items


def infer_model_metadata_properties(
    hf_dataset: "Dataset", rg_dataset: "FeedbackDataset"
) -> "FeedbackDataset":
    """Add terms metadata properties for the model columns found in a dataset.

    For each of the `generation_model` and `labelling_model` columns that is present
    in `hf_dataset.column_names`, the distinct model names found in that column are
    collected and a `TermsMetadataProperty` is added to `rg_dataset`. The property name
    is the column name with `_` replaced by `-`.

    Args:
        hf_dataset: the dataset whose columns are inspected. Each cell of a model column
            may be either a string or a list; any other value is ignored.
        rg_dataset: the Argilla dataset the metadata properties are added to.

    Returns:
        The `rg_dataset` that was passed in, after the metadata properties were added.

    Raises:
        ImportError: if `argilla` is not installed.
    """
    if not _ARGILLA_AVAILABLE:
        raise ImportError(
            "In order to use any of the functions defined within `utils.argilla` you must install `argilla`"
        )
    metadata_properties = []
    for column_name in ("generation_model", "labelling_model"):
        if column_name not in hf_dataset.column_names:
            continue
        # A dict is used as an insertion-ordered set, so the allowed values are
        # de-duplicated in first-seen order and are stable across runs (a plain
        # `set` of strings is ordered differently from one interpreter run to another).
        models = {}
        for item in hf_dataset[column_name]:
            candidates = item if isinstance(item, list) else [item]
            for candidate in candidates:
                # Only strings are valid term values; skip `None` and anything else,
                # also when it comes from inside a list.
                if isinstance(candidate, str):
                    models[candidate] = None
        property_name = column_name.replace("_", "-")
        metadata_properties.append(
            rg.TermsMetadataProperty(  # type: ignore
                name=property_name, title=property_name, values=list(models)
            )  # type: ignore
        )
    # Properties are registered only after all of them were built, so a failure
    # while building one of them leaves `rg_dataset` untouched.
    for metadata_property in metadata_properties:
        rg_dataset.add_metadata_property(metadata_property)
    return rg_dataset


def model_metadata_from_dataset_row(dataset_row: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the model-related metadata of a dataset row.

    Args:
        dataset_row: a row of the dataset, as a mapping from column name to value.

    Returns:
        A dictionary with the value of the `generation_model` column under the
        `generation-model` key and the value of the `labelling_model` column under the
        `labelling-model` key. A key is only included when its column is present in
        `dataset_row`; the values are copied as they are.
    """
    # (column name in the dataset row, metadata key in the record)
    column_to_metadata_key = (
        ("generation_model", "generation-model"),
        ("labelling_model", "labelling-model"),
    )
    return {
        metadata_key: dataset_row[column]
        for column, metadata_key in column_to_metadata_key
        if column in dataset_row
    }

0 comments on commit fae24ec

Please sign in to comment.