diff --git a/src/distilabel/dataset.py b/src/distilabel/dataset.py index f9120fbd56..4d5c5209d7 100644 --- a/src/distilabel/dataset.py +++ b/src/distilabel/dataset.py @@ -17,7 +17,6 @@ from datasets import Dataset -from distilabel.utils.argilla import infer_model_metadata_properties from distilabel.utils.imports import _ARGILLA_AVAILABLE if TYPE_CHECKING: @@ -63,16 +62,16 @@ def to_argilla(self) -> "FeedbackDataset": f"Error while converting the dataset to an Argilla `FeedbackDataset` instance: {e}" ) from e - try: - rg_dataset = infer_model_metadata_properties( - hf_dataset=self, rg_dataset=rg_dataset - ) - except Exception as e: - warnings.warn( - f"Error while adding the model metadata properties: {e}", - UserWarning, - stacklevel=2, - ) + # try: + # rg_dataset = infer_model_metadata_properties( + # hf_dataset=self, rg_dataset=rg_dataset + # ) + # except Exception as e: + # warnings.warn( + # f"Error while adding the model metadata properties: {e}", + # UserWarning, + # stacklevel=2, + # ) for dataset_row in self: if any( @@ -82,7 +81,7 @@ def to_argilla(self) -> "FeedbackDataset": continue try: rg_dataset.add_records( - self.task.to_argilla_record(dataset_row=dataset_row) # type: ignore + self.task._to_argilla_record(dataset_row=dataset_row) # type: ignore ) # type: ignore except Exception as e: warnings.warn( diff --git a/src/distilabel/pipeline.py b/src/distilabel/pipeline.py index 73f11b7f56..4adb7e9b60 100644 --- a/src/distilabel/pipeline.py +++ b/src/distilabel/pipeline.py @@ -504,8 +504,12 @@ def _build_dataset( # noqa: C901 # Dynamically remaps the `datasets.Dataset` to be a `CustomDataset` instance _dataset.__class__ = CustomDataset if self.generator is not None and self.labeller is None: + if self.generator.task.__type__ != "generation": # type: ignore + self.generator.task.__type__ = "generation" # type: ignore _dataset.task = self.generator.task # type: ignore elif self.labeller is not None: + if self.labeller.task.__type__ != "labelling": # type: ignore + self.labeller.task.__type__ = "labelling" # type: ignore _dataset.task = self.labeller.task # type: ignore return _dataset # type: ignore diff --git a/src/distilabel/tasks/base.py b/src/distilabel/tasks/base.py index 64b8a6e647..48f1c7f9d1 100644 --- a/src/distilabel/tasks/base.py +++ b/src/distilabel/tasks/base.py @@ -20,7 +20,7 @@ import importlib.resources as importlib_resources from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union from jinja2 import Template @@ -53,6 +53,7 @@ class Task(ABC): task_description: Union[str, None] = None __jinja2_template__: Union[str, None] = None + __type__: Union[Literal["generation", "labelling"], None] = None def __rich_repr__(self) -> Generator[Any, None, None]: yield "system_prompt", self.system_prompt @@ -118,3 +119,75 @@ def to_argilla_record( "`to_argilla_record` is not implemented, if you want to export your dataset as an Argilla" " `FeedbackDataset` you will need to implement this method first." ) + + # Renamed to _to_argilla_record instead of renaming `to_argilla_record` to protected, as that would + # imply more breaking changes. + def _to_argilla_record( # noqa: C901 + self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any + ) -> Union["FeedbackRecord", List["FeedbackRecord"]]: + column_names = list(dataset_row.keys()) + if self.__type__ is None or self.__type__ == "generation": + required_column_names = self.input_args_names + self.output_args_names + elif self.__type__ == "labelling": + required_column_names = self.output_args_names + else: + raise ValueError("The task type is not supported.") + + dataset_rows = [dataset_row] + if "generation_model" in dataset_row and isinstance( + dataset_row["generation_model"], list + ): + generation_columns = column_names[ + column_names.index("generation_model") : column_names.index( + "labelling_model" + ) + if "labelling_model" in column_names + else None + ] + if any( + generation_column in required_column_names + for generation_column in generation_columns + ): + unwrapped_dataset_rows = [] + for row in dataset_rows: + for idx in range(len(dataset_row["generation_model"])): + unwrapped_dataset_row = {} + for key, value in row.items(): + if key in generation_columns: + unwrapped_dataset_row[key] = value[idx] + else: + unwrapped_dataset_row[key] = value + unwrapped_dataset_rows.append(unwrapped_dataset_row) + dataset_rows = unwrapped_dataset_rows + + if "labelling_model" in dataset_row and isinstance( + dataset_row["labelling_model"], list + ): + labelling_columns = column_names[column_names.index("labelling_model") :] + if any( + labelling_column in required_column_names + for labelling_column in labelling_columns + ): + unwrapped_dataset_rows = [] + for row in dataset_rows: + for idx in range(len(dataset_row["labelling_model"])): + unwrapped_dataset_row = {} + for key, value in row.items(): + if key in labelling_columns: + unwrapped_dataset_row[key] = value[idx] + else: + unwrapped_dataset_row[key] = value + unwrapped_dataset_rows.append(unwrapped_dataset_row) + dataset_rows = unwrapped_dataset_rows + + if len(dataset_rows) == 1: + return self.to_argilla_record(dataset_rows[0], *args, **kwargs) + + records = [] + for dataset_row in dataset_rows: + generated_records = self.to_argilla_record(dataset_row, *args, **kwargs) + if isinstance(generated_records, list): + records.extend(generated_records) + else: + records.append(generated_records) + return records diff --git a/src/distilabel/tasks/critique/base.py b/src/distilabel/tasks/critique/base.py index 7caabc73ba..cf3bd016b7 100644 --- a/src/distilabel/tasks/critique/base.py +++ b/src/distilabel/tasks/critique/base.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import List +from typing import List, Literal from typing_extensions import TypedDict @@ -29,6 +29,8 @@ class CritiqueTask(Task): task_description (Union[str, None], optional): the description of the task. Defaults to `None`. """ + __type__: Literal["labelling"] = "labelling" + @property def input_args_names(self) -> List[str]: """Returns the names of the input arguments of the task.""" diff --git a/src/distilabel/tasks/preference/base.py b/src/distilabel/tasks/preference/base.py index 76ff436036..68b1a4c95d 100644 --- a/src/distilabel/tasks/preference/base.py +++ b/src/distilabel/tasks/preference/base.py @@ -136,7 +136,7 @@ def to_argilla_dataset( def _merge_rationales( self, rationales: List[str], generations_column: str = "generations" ) -> str: - return "".join( + return "\n".join( f"{generations_column}-{idx}:\n{rationale}\n" for idx, rationale in enumerate(rationales, start=1) ) diff --git a/src/distilabel/tasks/preference/judgelm.py b/src/distilabel/tasks/preference/judgelm.py index baa66fd256..a170c1286a 100644 --- a/src/distilabel/tasks/preference/judgelm.py +++ b/src/distilabel/tasks/preference/judgelm.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import ClassVar, List, TypedDict +from typing import ClassVar, List, Literal, TypedDict from distilabel.tasks.base import get_template from distilabel.tasks.preference.base import PreferenceTask @@ -38,8 +38,6 @@ class JudgeLMTask(PreferenceTask): task_description (Union[str, None], optional): the description of the task. Defaults to `None`. """ - __jinja2_template__: ClassVar[str] = _JUDGELM_TEMPLATE - task_description: str = ( "We would like to request your feedback on the performance of {num_responses} AI assistants in response to the" " user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details" @@ -52,6 +50,9 @@ class JudgeLMTask(PreferenceTask): ) system_prompt: str = "You are a helpful and precise assistant for checking the quality of the answer." + __jinja2_template__: ClassVar[str] = _JUDGELM_TEMPLATE + __type__: Literal["labelling"] = "labelling" + def generate_prompt(self, input: str, generations: List[str]) -> Prompt: """Generates a prompt following the JudgeLM specification. diff --git a/src/distilabel/tasks/preference/ultrafeedback.py b/src/distilabel/tasks/preference/ultrafeedback.py index 0bba9f2ef2..2cd7c485ea 100644 --- a/src/distilabel/tasks/preference/ultrafeedback.py +++ b/src/distilabel/tasks/preference/ultrafeedback.py @@ -14,7 +14,16 @@ from dataclasses import dataclass, field from textwrap import dedent -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, TypedDict +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Literal, + Optional, + TypedDict, +) from distilabel.tasks.base import get_template from distilabel.tasks.preference.base import PreferenceTask @@ -53,6 +62,10 @@ class UltraFeedbackTask(PreferenceTask): ratings: List[Rating] task_description: str + system_prompt: ( + str + ) = "Your role is to evaluate text quality based on given criteria." + __jinja2_template__: ClassVar[str] = field( default=_ULTRAFEEDBACK_TEMPLATE, init=False, repr=False ) @@ -63,10 +76,7 @@ class UltraFeedbackTask(PreferenceTask): "honesty", "instruction-following", ] - - system_prompt: ( - str - ) = "Your role is to evaluate text quality based on given criteria." + __type__: Literal["labelling"] = "labelling" def generate_prompt(self, input: str, generations: List[str]) -> Prompt: """Generates a prompt following the ULTRAFEEDBACK specification. diff --git a/src/distilabel/tasks/preference/ultrajudge.py b/src/distilabel/tasks/preference/ultrajudge.py index 421fa57333..e2fdbf751f 100644 --- a/src/distilabel/tasks/preference/ultrajudge.py +++ b/src/distilabel/tasks/preference/ultrajudge.py @@ -14,7 +14,7 @@ import re from dataclasses import dataclass, field -from typing import Any, ClassVar, Dict, List, TypedDict +from typing import Any, ClassVar, Dict, List, Literal, TypedDict from distilabel.tasks.base import Prompt, get_template from distilabel.tasks.preference.base import PreferenceTask @@ -79,6 +79,7 @@ class UltraJudgeTask(PreferenceTask): __jinja2_template__: ClassVar[str] = field( default=_ULTRAJUDGE_TEMPLATE, init=False, repr=False ) + __type__: Literal["labelling"] = "labelling" @property def output_args_names(self) -> List[str]: diff --git a/src/distilabel/tasks/text_generation/base.py b/src/distilabel/tasks/text_generation/base.py index ea65eb960e..e07e22e2f0 100644 --- a/src/distilabel/tasks/text_generation/base.py +++ b/src/distilabel/tasks/text_generation/base.py @@ -70,6 +70,8 @@ class TextGenerationTask(Task): ) principles_distribution: Union[Dict[str, float], Literal["balanced"], None] = None + __type__: Literal["generation"] = "generation" + def __post_init__(self) -> None: """Validates the `principles_distribution` if it is a dict.