Skip to content

Commit

Permalink
Fix exporting model name to Argilla with LLMPool (#177)
Browse files Browse the repository at this point in the history
* Add `_to_argilla_record` to handle outputs when `LLMPool`

* Add `__type__` in each `Task`

* Add missing `\n` join in `PreferenceTask`

* Skip `infer_model_metadata_properties` temporarily

Temporarily skips the `infer_model_metadata_properties` as we cannot push more than one term in the `TermsMetadataProperty` of Argilla, so that function will need a refinement to support the `LLMPool`

* Fix accessing `__type__`

---------

Co-authored-by: Alvaro Bartolome <[email protected]>
  • Loading branch information
gabrielmbmb and alvarobartt authored Dec 21, 2023
1 parent fae24ec commit 00de1f1
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 24 deletions.
23 changes: 11 additions & 12 deletions src/distilabel/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from datasets import Dataset

from distilabel.utils.argilla import infer_model_metadata_properties
from distilabel.utils.imports import _ARGILLA_AVAILABLE

if TYPE_CHECKING:
Expand Down Expand Up @@ -63,16 +62,16 @@ def to_argilla(self) -> "FeedbackDataset":
f"Error while converting the dataset to an Argilla `FeedbackDataset` instance: {e}"
) from e

try:
rg_dataset = infer_model_metadata_properties(
hf_dataset=self, rg_dataset=rg_dataset
)
except Exception as e:
warnings.warn(
f"Error while adding the model metadata properties: {e}",
UserWarning,
stacklevel=2,
)
# try:
# rg_dataset = infer_model_metadata_properties(
# hf_dataset=self, rg_dataset=rg_dataset
# )
# except Exception as e:
# warnings.warn(
# f"Error while adding the model metadata properties: {e}",
# UserWarning,
# stacklevel=2,
# )

for dataset_row in self:
if any(
Expand All @@ -82,7 +81,7 @@ def to_argilla(self) -> "FeedbackDataset":
continue
try:
rg_dataset.add_records(
self.task.to_argilla_record(dataset_row=dataset_row) # type: ignore
self.task._to_argilla_record(dataset_row=dataset_row) # type: ignore
) # type: ignore
except Exception as e:
warnings.warn(
Expand Down
4 changes: 4 additions & 0 deletions src/distilabel/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,12 @@ def _build_dataset( # noqa: C901
# Dynamically remaps the `datasets.Dataset` to be a `CustomDataset` instance
_dataset.__class__ = CustomDataset
if self.generator is not None and self.labeller is None:
if self.generator.task.__type__ != "generation": # type: ignore
self.generator.task.__type__ = "generation" # type: ignore
_dataset.task = self.generator.task # type: ignore
elif self.labeller is not None:
if self.labeller.task.__type__ != "labelling": # type: ignore
self.labeller.task.__type__ = "labelling" # type: ignore
_dataset.task = self.labeller.task # type: ignore
return _dataset # type: ignore

Expand Down
75 changes: 74 additions & 1 deletion src/distilabel/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import importlib.resources as importlib_resources

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union

from jinja2 import Template

Expand Down Expand Up @@ -53,6 +53,7 @@ class Task(ABC):
task_description: Union[str, None] = None

__jinja2_template__: Union[str, None] = None
__type__: Union[Literal["generation", "labelling"], None] = None

def __rich_repr__(self) -> Generator[Any, None, None]:
yield "system_prompt", self.system_prompt
Expand Down Expand Up @@ -118,3 +119,75 @@ def to_argilla_record(
"`to_argilla_record` is not implemented, if you want to export your dataset as an Argilla"
" `FeedbackDataset` you will need to implement this method first."
)

# Renamed to _to_argilla_record instead of renaming `to_argilla_record` to protected, as that would
# imply more breaking changes.
def _to_argilla_record( # noqa: C901
self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any
) -> Union["FeedbackRecord", List["FeedbackRecord"]]:
column_names = list(dataset_row.keys())
if self.__type__ is None or self.__type__ == "generation":
required_column_names = self.input_args_names + self.output_args_names
elif self.__type__ == "labelling":
required_column_names = self.output_args_names
else:
raise ValueError("The task type is not supported.")

dataset_rows = [dataset_row]
if "generation_model" in dataset_row and isinstance(
dataset_row["generation_model"], list
):
generation_columns = column_names[
column_names.index("generation_model") : column_names.index(
"labelling_model"
)
if "labelling_model" in column_names
else None
]
if any(
generation_column in required_column_names
for generation_column in generation_columns
):
unwrapped_dataset_rows = []
for row in dataset_rows:
for idx in range(len(dataset_row["generation_model"])):
unwrapped_dataset_row = {}
for key, value in row.items():
if key in generation_columns:
unwrapped_dataset_row[key] = value[idx]
else:
unwrapped_dataset_row[key] = value
unwrapped_dataset_rows.append(unwrapped_dataset_row)
dataset_rows = unwrapped_dataset_rows

if "labelling_model" in dataset_row and isinstance(
dataset_row["labelling_model"], list
):
labelling_columns = column_names[column_names.index("labelling_model") :]
if any(
labelling_column in required_column_names
for labelling_column in labelling_columns
):
unwrapped_dataset_rows = []
for row in dataset_rows:
for idx in range(len(dataset_row["labelling_model"])):
unwrapped_dataset_row = {}
for key, value in row.items():
if key in labelling_columns:
unwrapped_dataset_row[key] = value[idx]
else:
unwrapped_dataset_row[key] = value
unwrapped_dataset_rows.append(unwrapped_dataset_row)
dataset_rows = unwrapped_dataset_rows

if len(dataset_rows) == 1:
return self.to_argilla_record(dataset_rows[0], *args, **kwargs)

records = []
for dataset_row in dataset_rows:
generated_records = self.to_argilla_record(dataset_row, *args, **kwargs)
if isinstance(generated_records, list):
records.extend(generated_records)
else:
records.append(generated_records)
return records
4 changes: 3 additions & 1 deletion src/distilabel/tasks/critique/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import List
from typing import List, Literal

from typing_extensions import TypedDict

Expand All @@ -29,6 +29,8 @@ class CritiqueTask(Task):
task_description (Union[str, None], optional): the description of the task. Defaults to `None`.
"""

__type__: Literal["labelling"] = "labelling"

@property
def input_args_names(self) -> List[str]:
"""Returns the names of the input arguments of the task."""
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/tasks/preference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def to_argilla_dataset(
def _merge_rationales(
self, rationales: List[str], generations_column: str = "generations"
) -> str:
return "".join(
return "\n".join(
f"{generations_column}-{idx}:\n{rationale}\n"
for idx, rationale in enumerate(rationales, start=1)
)
Expand Down
7 changes: 4 additions & 3 deletions src/distilabel/tasks/preference/judgelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import ClassVar, List, TypedDict
from typing import ClassVar, List, Literal, TypedDict

from distilabel.tasks.base import get_template
from distilabel.tasks.preference.base import PreferenceTask
Expand All @@ -38,8 +38,6 @@ class JudgeLMTask(PreferenceTask):
task_description (Union[str, None], optional): the description of the task. Defaults to `None`.
"""

__jinja2_template__: ClassVar[str] = _JUDGELM_TEMPLATE

task_description: str = (
"We would like to request your feedback on the performance of {num_responses} AI assistants in response to the"
" user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details"
Expand All @@ -52,6 +50,9 @@ class JudgeLMTask(PreferenceTask):
)
system_prompt: str = "You are a helpful and precise assistant for checking the quality of the answer."

__jinja2_template__: ClassVar[str] = _JUDGELM_TEMPLATE
__type__: Literal["labelling"] = "labelling"

def generate_prompt(self, input: str, generations: List[str]) -> Prompt:
"""Generates a prompt following the JudgeLM specification.
Expand Down
20 changes: 15 additions & 5 deletions src/distilabel/tasks/preference/ultrafeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@

from dataclasses import dataclass, field
from textwrap import dedent
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, TypedDict
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Literal,
Optional,
TypedDict,
)

from distilabel.tasks.base import get_template
from distilabel.tasks.preference.base import PreferenceTask
Expand Down Expand Up @@ -53,6 +62,10 @@ class UltraFeedbackTask(PreferenceTask):
ratings: List[Rating]
task_description: str

system_prompt: (
str
) = "Your role is to evaluate text quality based on given criteria."

__jinja2_template__: ClassVar[str] = field(
default=_ULTRAFEEDBACK_TEMPLATE, init=False, repr=False
)
Expand All @@ -63,10 +76,7 @@ class UltraFeedbackTask(PreferenceTask):
"honesty",
"instruction-following",
]

system_prompt: (
str
) = "Your role is to evaluate text quality based on given criteria."
__type__: Literal["labelling"] = "labelling"

def generate_prompt(self, input: str, generations: List[str]) -> Prompt:
"""Generates a prompt following the ULTRAFEEDBACK specification.
Expand Down
3 changes: 2 additions & 1 deletion src/distilabel/tasks/preference/ultrajudge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import re
from dataclasses import dataclass, field
from typing import Any, ClassVar, Dict, List, TypedDict
from typing import Any, ClassVar, Dict, List, Literal, TypedDict

from distilabel.tasks.base import Prompt, get_template
from distilabel.tasks.preference.base import PreferenceTask
Expand Down Expand Up @@ -79,6 +79,7 @@ class UltraJudgeTask(PreferenceTask):
__jinja2_template__: ClassVar[str] = field(
default=_ULTRAJUDGE_TEMPLATE, init=False, repr=False
)
__type__: Literal["labelling"] = "labelling"

@property
def output_args_names(self) -> List[str]:
Expand Down
2 changes: 2 additions & 0 deletions src/distilabel/tasks/text_generation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class TextGenerationTask(Task):
)
principles_distribution: Union[Dict[str, float], Literal["balanced"], None] = None

__type__: Literal["generation"] = "generation"

def __post_init__(self) -> None:
"""Validates the `principles_distribution` if it is a dict.
Expand Down

0 comments on commit 00de1f1

Please sign in to comment.