From 71543961ee3c983681d3b16bbedfb5bfe1b63c18 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome
Date: Wed, 29 Nov 2023 11:16:02 +0100
Subject: [PATCH] Improve `argilla` integration (#119)

* Set `0.1.0` version

* Delete unused `examples/label-dataset-using-judgelm.py`

* Upgrade `argilla` to 1.18.0

Cannot be lower, as the `metadata_properties` were introduced in v1.18.0

* Improve `argilla` integration

Now the methods to implement on each `Task` are `to_argilla_dataset` and
`to_argilla_record`, so that it's easier and more straightforward for users
willing to integrate Argilla within their tasks

* Update `README.md`

* Update `README.md`

* Fixed some typos thanks to @codespell-project

* Rename `responses_column` to `generations_column`

* Clean `_merge_rationales` to re-use `generations_column` too

* Update `README.md`

* Update `README.md`

* Update `README.md`

* Rename `responses_values->ratings_values`

* Apply suggestions from code review

Co-authored-by: Gabriel Martin

---------

Co-authored-by: Gabriel Martin
---
 README.md                                     | 186 +++---
 examples/label-dataset-using-judgelm.py       |  50 ----
 pyproject.toml                                |   2 +-
 src/distilabel/__init__.py                    |   2 +-
 src/distilabel/dataset.py                     |  17 +-
 src/distilabel/pipeline.py                    |   4 +-
 src/distilabel/tasks/argilla_utils.py         |  44 ++++
 src/distilabel/tasks/base.py                  | 119 ++-------
 src/distilabel/tasks/preference/base.py       | 233 ++++++++++--------
 .../tasks/preference/ultrafeedback.py         |  15 +-
 src/distilabel/tasks/preference/ultrajudge.py |  24 +-
 11 files changed, 281 insertions(+), 415 deletions(-)
 delete mode 100644 examples/label-dataset-using-judgelm.py
 create mode 100644 src/distilabel/tasks/argilla_utils.py

diff --git a/README.md b/README.md
index e8413cc700..06c244c95a 100644
--- a/README.md
+++ b/README.md
@@ -1,34 +1,54 @@
-<div align="center">
-  <h1>⚗️ distilabel</h1>
-  <h3>AI Feedback framework for building datasets and labelers with LLMs</h3>
-</div>
+<div align="center">
+  <h1>⚗️ distilabel</h1>
+  <h3>AI Feedback (AIF) framework for building datasets and labellers with LLMs</h3>
-## What's distilabel -distilabel is a framework for AI engineers to align LLM using RLHF-related methods (e.g., reward models, DPO). +![overview](https://github.com/argilla-io/distilabel/assets/36760800/360110da-809d-4e24-a29b-1a1a8bc4f9b7) + +> [!TIP] +> To discuss, get support, or give feedback [join Argilla's Slack Community](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g) and you will be able to engage with our amazing community and also with the core developers of `argilla` and `distilabel`. + +## What's `distilabel`? + +`distilabel` is a framework for AI engineers to align LLMs using RLHF-related methods (e.g. reward models, DPO). The initial focus is LLM fine-tuning and adaptation but we'll be extending it for predictive NLP use cases soon. Main use cases are: 1. As an AI engineer I want to **build domain-specific instruction datasets** to fine-tune OSS LLMs with increased accuracy. -2. As an AI engineer I want to **build domain-specific and diverse preference datasets** to use RLHF-related methods and align LLMs (e.g, increase the ability to follow instructions or give thruthful responses). +2. As an AI engineer I want to **build domain-specific and diverse preference datasets** to use RLHF-related methods and align LLMs (e.g, increase the ability to follow instructions or give truthful responses). -This readme might be outdated the best place to get started is the [documentation](http://distilabel.argilla.io/). +> [!WARNING] +> `distilabel` is currently under active development and we're iterating quickly, so take into account that we may introduce breaking changes in the releases during the upcoming weeks, and also the `README` might be outdated the best place to get started is the [documentation](http://distilabel.argilla.io/). -> [!TIP] -> To discuss, get support, give feedback [join Argilla's Slack Community](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g) +## Motivation -> [!TIP] -> To contribute check our [good first issues](https://github.com/argilla-io/distilabel/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) or [open a new one](https://github.com/argilla-io/distilabel/issues/new/choose). +🔥 Recent projects like [Zephyr](https://huggingface.co/collections/HuggingFaceH4/zephyr-7b-6538c6d6d5ddd1cbb1744a66) and [Tulu](https://huggingface.co/collections/allenai/tulu-v2-suite-6551b56e743e6349aab45101) have shown it's possible to **build powerful open-source models with DPO and AI Feedback** (AIF) datasets. + +👩‍🔬 There's a lot of exciting research in the AIF space, such as [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) (the dataset leveraged by Zephyr and Tulu), [JudgeLM](https://github.com/baaivision/JudgeLM), or [Prometheus](https://huggingface.co/kaist-ai/prometheus-13b-v1.0). + +🚀 However, going beyond research efforts and applying AIF at scale it's different. For enterprise and production use, we need framework that implements **key AIF methods on a robust, efficient and scalable way**. This framework should enable AI engineers to build custom datasets at scale for their own use cases. + +👩‍🎓 This, combined with humans-in-the-loop for improving dataset quality is the next big leap for OSS LLM models. + +⚗️ `distilabel` aims to bridge this gap. + +## Key features + +* 🤖 **Leverage OSS models and APIs**: 🤗 transformers, OpenAI, 🤗 Inference Endpoints, vLLM, llama.cpp, and more to come. 
+
+* 💻 **Scalable and extensible**: Scalable implementations of existing methods (e.g. UltraFeedback). Easily extensible to build and configure your own labellers.
+
+* 🧑‍🦱 **Human-in-the-loop**: One line of code integration with Argilla to improve and correct datasets.
 
 ## Quickstart
 
 ### Installation
 
 Install with `pip` (requires Python 3.8+):
-```sh
+
+```bash
 pip install distilabel[openai,argilla]
 ```
 
@@ -41,43 +61,41 @@ After installing, you can immediately start experimenting with `distilabel`:
 
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1rO1-OlLFPBC0KPuXQOeMpZOeajiwNoMy?usp=sharing)
 
-### Example: build a preference dataset for DPO/RLHF
+### Example: Build a preference dataset for DPO/RLHF
+
 ```python
 from datasets import load_dataset
 from distilabel.llm import OpenAILLM
 from distilabel.pipeline import pipeline
 from distilabel.tasks import TextGenerationTask
 
-# dataset with instructions
+# Load a dataset with instructions from the Hub
 dataset = (
     load_dataset("HuggingFaceH4/instruction-dataset", split="test[:5]")
     .remove_columns(["completion", "meta"])
     .rename_column("prompt", "input")
 )
 
-# use gpt3.5 turbo for generating responses
-task = TextGenerationTask()
-
+# Use `OpenAILLM` (running `gpt-3.5-turbo`) to generate responses for the given inputs
 generator = OpenAILLM(
-    task=task,
-    max_new_tokens=512
-    #openai_api_key="sk-.."
+    task=TextGenerationTask(),
+    max_new_tokens=512,
+    # openai_api_key="sk-...",
 )
 
-# build preference dataset comparing two responses
-# focusing on the instruction-following skill
-pipe = pipeline("preference", "instruction-following", generator=generator)
+pipeline = pipeline("preference", "instruction-following", generator=generator)
 
-dataset = pipe.generate(dataset, num_generations=2)
+# Build a preference dataset comparing two responses, focused on the instruction-following skill of the LLM
+dataset = pipeline.generate(dataset)
 ```
 
 The resulting dataset can already be used for preference tuning (ideally a larger version of it). But beware: these AIF datasets are imperfect. To get the most out of AIF, push to Argilla for human feedback:
 
 ```python
 import argilla as rg
 
 rg.init(
-    api_key="",
+    api_key="",
     api_url=""
 )
 
@@ -85,109 +103,27 @@ rg_dataset = dataset.to_argilla()
 rg_dataset.push_to_argilla(name="preference-dataset", workspace="admin")
 ```
 
-
-
 https://github.com/argilla-io/distilabel/assets/1107111/be34c95c-8be4-46ef-9437-cbd2a7687e30
 
+### More examples
 
-
-## Motivation
-🔥 Recent projects like [Zephyr](https://huggingface.co/collections/HuggingFaceH4/zephyr-7b-6538c6d6d5ddd1cbb1744a66) and [Tulu](https://huggingface.co/collections/allenai/tulu-v2-suite-6551b56e743e6349aab45101) have shown it's possible to **build powerful open-source models with DPO and AI Feedback** (AIF) datasets.
-
-👩‍🔬 There's a lot of exciting research in the AIF space, such as [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) (the dataset leveraged by Zephyr and Tulu), [JudgeLM](https://github.com/baaivision/JudgeLM), or [Prometheus](https://huggingface.co/kaist-ai/prometheus-13b-v1.0).
-
-🚀 However, going beyond research efforts and applying AIF at scale it's different.
For enterprise and production use, we need framework that implements **key AIF methods on a robust, efficient and scalable way**. This framework should enable AI engineers to build custom datasets at scale for their own use cases.
-
-👩‍🎓 This, combined with humans-in-the-loop for improving dataset quality is the next big leap for OSS LLM models.
-
-⚗️ `distilabel` aims to bridge this gap.
-
-## Key features
-
-* 🤖 **Leverage OSS models and APIs**: HF Transformers, OpenAI, HF Inference Endpoints, vLLM, LlamaCPP, and more to come.
-
-* 💻 **Scalable and extensible**: Scalable implementations of existing methods (e.g., UltraFeedback). Easily extensible to build and configure your own labelers.
-
-* 🧑‍🦱 **Human-in-the-loop**: One line of code integration with Argilla to improve and correct datasets.
-
-## Overview
-![distilabel_overview](https://github.com/argilla-io/distilabel/assets/1107111/182c871c-108f-441e-bb3e-f01b080f8631)
-
+Find more examples of different use cases of `distilabel` under [`examples/`](./examples/).
 
 ## Roadmap
 
-- Add Critique Models and support for Prometheus OSS
-- Add a generator with multiple models
-- Train OSS labelers to replace OpenAI labelers
-- Add labelers to evolve instructions generated with self-instruct
-- Add labelers for predictive NLP tasks: text classification, information extraction
-- Open an issue to suggest a feature!
+- [ ] Add Critique Models and support for Prometheus OSS
+- [ ] Add a generator with multiple models
+- [ ] Train OSS labellers to replace OpenAI labellers
+- [ ] Add labellers to evolve instructions generated with self-instruct
+- [ ] Add labellers for predictive NLP tasks: text classification, information extraction, etc.
+- [ ] Open an issue to suggest a feature!
 
-## How to generate instructions
-If you don't have an instruction or prompts dataset you can generate one with our `self-instruct` inspired generator:
+## Contribute
 
-```python
-import os
-from distilabel.tasks import SelfInstructTask
-from distilabel.pipeline import Pipeline
-from distilabel.llm import OpenAILLM
-from datasets import Dataset
-
-math_topics = [
-    "Algebraic Expressions",
-    "Linear Equations",
-    "Quadratic Equations",
-    "Polynomial Functions",
-    "Rational Expressions",
-    "Exponential Functions",
-    "Logarithmic Functions",
-    "Sequences and Series",
-    "Matrices",
-    "Determinants",
-    #...
-]
-
-dataset = Dataset.from_dict({
-    "input": math_topics
-})
-
-# it will steer the generator
-# to generate instructions for this specific app
-instruction_task = SelfInstructTask(
-    application_description= """
-    An AI assistant adept at answering a wide array of math, logic, and reasoning puzzles, trivia, and general questions.
-    """,
-    num_instructions=10 # 10 instructions per input
-)
-
-# default model is: gpt3.5-turbo
-# you can choose gpt-4 too
-instruction_generator = OpenAILLM(
-    task=instruction_task,
-    openai_api_key=os.getenv("OPENAI_API_KEY"),
-    num_threads=8,
-    max_new_tokens=1024
-)
+To contribute to `distilabel` directly, check our [good first issues](https://github.com/argilla-io/distilabel/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) or [open a new one](https://github.com/argilla-io/distilabel/issues/new/choose).
 
-pipeline = Pipeline(
-    generator=instruction_generator
-)
+## References
 
-# will generate
-distiset = pipeline.generate(
-    dataset=dataset,
-    # 10 instruction * 10 generations * 10 inputs = 1000 instructions
-    num_generations=10,
-    batch_size=4
-)
-# Output:
-# Number of generated instructions: 2044
-# 1. 
Provide an explanation for solving a quadratic equation step by step. -# 2. What is the process for simplifying an algebraic expression with exponents? -# 3. Detail how to factorize a polynomial equation. -# ... -# 10. How can one determine if a given graph represents a linear or quadratic equation? -# 1. How can I simplify the algebraic expression (x^2 + 3x + 2)(2x - 1)? -# 2. Provide step-by-step instructions on how to solve the equation 4(x + 2) - 3 = 7(2x - 1). -# ... -``` +* [UltraFeedback: Boosting Language Models with High-quality Feedback](https://arxiv.org/abs/2310.01377) +* [JudgeLM: Fine-tuned Large Language Models are Scalable Judges](https://arxiv.org/abs/2310.17631) +* [Self-Instruct: Aligning Language Models with Self-Generated Instructions](https://arxiv.org/abs/2212.10560) diff --git a/examples/label-dataset-using-judgelm.py b/examples/label-dataset-using-judgelm.py deleted file mode 100644 index c58a3abf71..0000000000 --- a/examples/label-dataset-using-judgelm.py +++ /dev/null @@ -1,50 +0,0 @@ -import os - -import argilla as rg -from datasets import load_dataset -from distilabel.llm import OpenAILLM -from distilabel.pipeline import Pipeline -from distilabel.tasks import JudgeLMTask - -os.environ["OPENAI_API_KEY"] = "" -rg.init(api_url="", api_key="") - -dataset = load_dataset("gabrielmbmb/ultrafeedback-prompts-judgelm-gpt35", split="train") - -dataset = dataset.remove_columns( # .shuffle() - [ - "generation_model", - "generation_prompt", - "raw_generation_responses", - "labelling_model", - "labelling_prompt", - "raw_labelling_response", - "ratings", - "rationale", - ] -).select( # type: ignore - range(1) -) - -labeller = OpenAILLM( - model="gpt-3.5-turbo", - task=JudgeLMTask(), - max_new_tokens=1024, - num_threads=16, - temperature=1.0, -) - -pipeline = Pipeline(labeller=labeller) - -labelled_dataset = pipeline.generate( - dataset, # type: ignore - num_generations=2, - batch_size=8, - enable_checkpoints=True, - display_progress_bar=True, -) - -rg_dataset = labelled_dataset.to_argilla() -rg_dataset.push_to_argilla( - name="distilabel-judgelm", workspace="" -) diff --git a/pyproject.toml b/pyproject.toml index 3bd24de489..966adacfff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ hf-inference-endpoints = ["huggingface_hub >= 1.19.0"] llama-cpp = ["llama-cpp >= 0.2.0"] openai = ["openai >= 1.0.0"] vllm = ["vllm >= 0.2.1"] -argilla = ["argilla >= 1.16.0"] +argilla = ["argilla >= 1.18.0"] tests = ["pytest >= 7.4.0"] docs = [ "mkdocs-material >= 9.4.10", diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py index 5616d7f662..2a2632685c 100644 --- a/src/distilabel/__init__.py +++ b/src/distilabel/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.1.0rc2" +__version__ = "0.1.0" diff --git a/src/distilabel/dataset.py b/src/distilabel/dataset.py index 9483e15102..d03d1d294c 100644 --- a/src/distilabel/dataset.py +++ b/src/distilabel/dataset.py @@ -18,9 +18,6 @@ from distilabel.utils.imports import _ARGILLA_AVAILABLE -if _ARGILLA_AVAILABLE: - import argilla as rg - if TYPE_CHECKING: from argilla import FeedbackDataset @@ -57,13 +54,13 @@ def to_argilla(self) -> "FeedbackDataset": "The task is not set. Please set it with `dataset.task = `." 
) - rg_dataset = rg.FeedbackDataset( - fields=self.task.to_argilla_fields(dataset_row=self[0]), - questions=self.task.to_argilla_questions(dataset_row=self[0]), - metadata_properties=self.task.to_argilla_metadata_properties( - dataset_row=self[0] - ), - ) + try: + rg_dataset = self.task.to_argilla_dataset(dataset_row=self[0]) # type: ignore + except Exception as e: + raise ValueError( + f"Error while converting the dataset to an Argilla `FeedbackDataset` instance: {e}" + ) from e + for dataset_row in self: if any( dataset_row[input_arg_name] is None # type: ignore diff --git a/src/distilabel/pipeline.py b/src/distilabel/pipeline.py index 7efbb5ffdd..8cd2819149 100644 --- a/src/distilabel/pipeline.py +++ b/src/distilabel/pipeline.py @@ -392,7 +392,7 @@ def _build_dataset( # noqa: C901 processed_labels.extend(future.result()) except Exception as e: logger.error( - f"An error ocurred when getting the result from the labeller: {e}" + f"An error occurred when getting the result from the labeller: {e}" ) processed_labels.append( [ @@ -498,7 +498,7 @@ def generate( # noqa: C901 warnings.warn( f"Provided `num_generations={num_generations}` which implies that the " "`generator` LLM will just run once, while the `labelling` LLM expects " - "to recieve a list of N inputs to label, where N is > 1. If this is not " + "to receive a list of N inputs to label, where N is > 1. If this is not " "intended, make sure to set `num_generations` to a value higher or " "equal to 2.", UserWarning, diff --git a/src/distilabel/tasks/argilla_utils.py b/src/distilabel/tasks/argilla_utils.py new file mode 100644 index 0000000000..fbe17c26ee --- /dev/null +++ b/src/distilabel/tasks/argilla_utils.py @@ -0,0 +1,44 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, Any, Dict, List + +from distilabel.utils.imports import _ARGILLA_AVAILABLE + +if _ARGILLA_AVAILABLE: + import argilla as rg + +if TYPE_CHECKING: + from argilla.client.feedback.schemas.types import AllowedFieldTypes + + +def infer_fields_from_dataset_row( + field_names: List[str], dataset_row: Dict[str, Any] +) -> List["AllowedFieldTypes"]: + if not _ARGILLA_AVAILABLE: + raise ImportError( + "In order to use any of the functions defined within `argilla_utils` you must install `argilla`" + ) + processed_items = [] + for arg_name in field_names: + if arg_name not in dataset_row: + continue + if isinstance(dataset_row[arg_name], list): + for idx in range(1, len(dataset_row[arg_name]) + 1): + processed_items.append( + rg.TextField(name=f"{arg_name}-{idx}", title=f"{arg_name}-{idx}") + ) # type: ignore + elif isinstance(dataset_row[arg_name], str): + processed_items.append(rg.TextField(name=arg_name, title=arg_name)) # type: ignore + return processed_items diff --git a/src/distilabel/tasks/base.py b/src/distilabel/tasks/base.py index 778bb93ed2..f0a3f6ac81 100644 --- a/src/distilabel/tasks/base.py +++ b/src/distilabel/tasks/base.py @@ -14,29 +14,15 @@ import importlib.resources as importlib_resources from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Union from jinja2 import Template from distilabel.tasks.prompt import Prompt -try: - import argilla as rg -except ImportError: - pass - if TYPE_CHECKING: - from argilla.client.feedback.schemas.fields import TextField - from argilla.client.feedback.schemas.metadata import ( - FloatMetadataProperty, - IntegerMetadataProperty, - ) - from argilla.client.feedback.schemas.questions import RatingQuestion, TextQuestion + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.schemas.records import FeedbackRecord - from argilla.client.feedback.schemas.types import ( - AllowedFieldTypes, - AllowedQuestionTypes, - ) def get_template(template_name: str) -> str: @@ -45,90 +31,7 @@ def get_template(template_name: str) -> str: ) -class Argilla: - """Class to be used internally to define the methods required to export a dataset - as an Argilla `FeedbackDataset`. - """ - - def to_argilla_fields( - self, - dataset_row: Dict[str, Any], - *args: Any, - **kwargs: Any, - ) -> List["AllowedFieldTypes"]: - raise NotImplementedError( - "`to_argilla_fields` is not implemented, if you want to export your dataset as an Argilla dataset you will need to implement this method." - ) - - def to_argilla_questions( - self, - dataset_row: Dict[str, Any], - *args: Any, - **kwargs: Any, - ) -> List["AllowedQuestionTypes"]: - raise NotImplementedError( - "`to_argilla_questions` is not implemented, if you want to export your dataset as an Argilla dataset you will need to implement this method." - ) - - def to_argilla_record( - self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any - ) -> "FeedbackRecord": - raise NotImplementedError( - "`to_argilla_record` is not implemented, if you want to export your dataset as an Argilla dataset you will need to implement this method." 
- ) - - def _create_argilla_record( - self, - fields: Dict[str, Any], - suggestions: List[Dict[str, Any]], - metadata: Dict[str, Any], - ) -> "FeedbackRecord": - return rg.FeedbackRecord( - fields=fields, suggestions=suggestions, metadata=metadata - ) - - def _create_text_field(self, name: str) -> "TextField": - return rg.TextField(name=name) - - def _create_rating_question( - self, name: str, title: str, values: List[int] - ) -> "RatingQuestion": - return rg.RatingQuestion(name=name, title=title, values=values) - - def _create_text_question(self, name: str, title: str) -> "TextQuestion": - return rg.TextQuestion(name=name, title=title) - - def _create_metadata_property( - self, name: str, property_type: str - ) -> Union["IntegerMetadataProperty", "FloatMetadataProperty"]: - if property_type == "integer": - return rg.IntegerMetadataProperty(name=name) - elif property_type == "float": - return rg.FloatMetadataProperty(name=name) - else: - raise ValueError(f"Invalid property type: {property_type}") - - def _create_fields_from_row( - self, dataset_row: Dict[str, Any], process_function: Callable - ) -> List["AllowedFieldTypes"]: - processed_items = [] - for arg_name in self.input_args_names: - self._check_argument_exists(dataset_row, arg_name) - if isinstance(dataset_row[arg_name], list): - for idx in range(1, len(dataset_row[arg_name]) + 1): - processed_items.append(process_function(f"{arg_name}-{idx}")) - elif isinstance(dataset_row[arg_name], str): - processed_items.append(process_function(arg_name)) - return processed_items - - def _check_argument_exists(self, dataset_row, arg_name): - if arg_name not in dataset_row: - raise ValueError( - f"Dataset row does not contain the required field '{arg_name}'." - ) - - -class Task(ABC, Argilla): +class Task(ABC): """Abstract class used to define the methods required to create a `Task`, to be used within an `LLM`. @@ -193,3 +96,19 @@ def validate_dataset(self, columns_in_dataset: List[str]) -> None: f"LLM expects a column named '{input_arg_name}' in the provided" " dataset, but it was not found." ) + + def to_argilla_dataset( + self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any + ) -> "FeedbackDataset": + raise NotImplementedError( + "`to_argilla_dataset` is not implemented, if you want to export your dataset as an Argilla" + " `FeedbackDataset` you will need to implement this method first." + ) + + def to_argilla_record( + self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any + ) -> "FeedbackRecord": + raise NotImplementedError( + "`to_argilla_record` is not implemented, if you want to export your dataset as an Argilla" + " `FeedbackDataset` you will need to implement this method first." + ) diff --git a/src/distilabel/tasks/preference/base.py b/src/distilabel/tasks/preference/base.py index ddaa7f8371..fc4753784a 100644 --- a/src/distilabel/tasks/preference/base.py +++ b/src/distilabel/tasks/preference/base.py @@ -12,18 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
+from distilabel.tasks.argilla_utils import infer_fields_from_dataset_row
 from distilabel.tasks.base import Task
+from distilabel.utils.imports import _ARGILLA_AVAILABLE
+
+if _ARGILLA_AVAILABLE:
+    import argilla as rg
 
 if TYPE_CHECKING:
+    from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
     from argilla.client.feedback.schemas.records import FeedbackRecord
-    from argilla.client.feedback.schemas.types import (
-        AllowedFieldTypes,
-        AllowedMetadataPropertyTypes,
-        AllowedQuestionTypes,
-    )
+
 
 @dataclass
 class PreferenceTask(Task):
@@ -44,129 +47,161 @@ def output_args_names(self) -> List[str]:
         """Returns the names of the output arguments of the task."""
         return ["rating", "rationale"]
 
-    def to_argilla_fields(
-        self, dataset_row: Dict[str, Any]
-    ) -> List["AllowedFieldTypes"]:
-        """Converts a dataset row to a list of Argilla `AllowedFieldTypes`."""
-        return self._create_fields_from_row(dataset_row, self._create_text_field)
-
-    def to_argilla_questions(
-        self, dataset_row: Dict[str, Any]
-    ) -> List["AllowedQuestionTypes"]:
-        """Converts a dataset row to a list of Argilla `AllowedQuestionTypes`."""
+    def to_argilla_dataset(
+        self,
+        dataset_row: Dict[str, Any],
+        generations_column: Optional[str] = "generations",
+        ratings_column: Optional[str] = "rating",
+        ratings_values: Optional[List[int]] = None,
+        rationale_column: Optional[str] = "rationale",
+    ) -> "FeedbackDataset":
+        # First we infer the fields from the `input_args_names`, but we could also
+        # create those manually via `rg.TextField(...)` instead.
+        fields = infer_fields_from_dataset_row(
+            field_names=self.input_args_names, dataset_row=dataset_row
+        )
+        # Then we add the questions, which cannot be easily inferred in this case,
+        # because they depend not on the inputs or the outputs alone, but on a combination
+        # of both: the questions are formulated using the inputs, but assigned to the
+        # outputs.
+        if generations_column is None or generations_column not in dataset_row:
+            raise ValueError(
+                f"The `generations_column='{generations_column}'` is not present in the dataset"
+                f" row. Please provide any of {list(dataset_row.keys())}.",
+            )
+        if ratings_column is None or ratings_column not in dataset_row:
+            raise ValueError(
+                f"The `ratings_column='{ratings_column}'` is not present in the dataset row. Please"
+                f" provide any of {list(dataset_row.keys())}.",
+            )
+        if rationale_column is None or rationale_column not in dataset_row:
+            raise ValueError(
+                f"The `rationale_column='{rationale_column}'` is not present in the dataset row. Please"
+                f" provide any of {list(dataset_row.keys())}.",
+            )
         questions = []
-        arg_name = "generations"
-        self._check_argument_exists(dataset_row, arg_name)
-        if isinstance(dataset_row[arg_name], list):
-            for idx in range(1, len(dataset_row[arg_name]) + 1):
-                question_name = f"{arg_name}-{idx}-rating"
-                title = f"What's the rating for {arg_name}-{idx}?"
- questions.append( - self._create_rating_question( - question_name, title, list(range(1, 11)) - ) + for idx in range(1, len(dataset_row[generations_column]) + 1): + questions.append( + rg.RatingQuestion( # type: ignore + name=f"{generations_column}-{idx}-{ratings_column}", + title=f"What's the {ratings_column} for {generations_column}-{idx}?", + values=ratings_values or [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ) + ) questions.append( - self._create_text_question( - "ratings-rationale", "What's the rationale behind the ratings?" + rg.TextQuestion( # type: ignore + name=f"{ratings_column}-{rationale_column}", + title=f"What's the {rationale_column} behind each {ratings_column}?", ) ) - return questions - - def to_argilla_metadata_properties( - self, dataset_row: Dict[str, Any] - ) -> List["AllowedMetadataPropertyTypes"]: - """Converts a dataset row to a list of Argilla `AllowedMetadataPropertyTypes`.""" + # Finally, we define some metadata properties that can be potentially used + # while exploring the dataset within Argilla to get more insights on the data. metadata_properties = [] for arg_name in self.input_args_names: - self._check_argument_exists(dataset_row, arg_name) if isinstance(dataset_row[arg_name], list): for idx in range(1, len(dataset_row[arg_name]) + 1): metadata_properties.append( - self._create_metadata_property( - f"length-{arg_name}-{idx}", "integer" - ) + rg.IntegerMetadataProperty(name=f"length-{arg_name}-{idx}") # type: ignore ) - metadata_properties.append( - self._create_metadata_property( - f"rating-{arg_name}-{idx}", "float" + if arg_name == generations_column: + metadata_properties.append( + rg.FloatMetadataProperty( + name=f"{ratings_column}-{arg_name}-{idx}" + ) # type: ignore ) - ) elif isinstance(dataset_row[arg_name], str): metadata_properties.append( - self._create_metadata_property(f"length-{arg_name}", "integer") + rg.IntegerMetadataProperty(name=f"length-{arg_name}") # type: ignore ) else: - raise ValueError( - f"Type {type(dataset_row[arg_name])} is not supported." + warnings.warn( + f"Unsupported input type ({type(dataset_row[arg_name])}), skipping...", + UserWarning, + stacklevel=2, ) - # add distance between best rating and the second best - if isinstance(dataset_row[arg_name], list): - metadata_properties.append( - self._create_metadata_property("distance-best-rated", "float") - ) - return metadata_properties + metadata_properties.append( + rg.FloatMetadataProperty(name=f"distance-best-{ratings_column}") # type: ignore + ) + # Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties + # defined above. 
+        return rg.FeedbackDataset(
+            fields=fields,
+            questions=questions,
+            metadata_properties=metadata_properties,  # Note that these are always optional
+        )
+
+    def _merge_rationales(
+        self, rationales: List[str], generations_column: str = "generations"
+    ) -> str:
+        return "".join(
+            f"{generations_column}-{idx}:\n{rationale}\n"
+            for idx, rationale in enumerate(rationales, start=1)
+        )
 
     def to_argilla_record(  # noqa: C901
         self,
         dataset_row: Dict[str, Any],
+        generations_column: Optional[str] = "generations",
+        ratings_column: Optional[str] = "rating",
+        rationale_column: Optional[str] = "rationale",
    ) -> "FeedbackRecord":
         """Converts a dataset row to an Argilla `FeedbackRecord`."""
-        fields = {}
-        metadata = {}
-
-        for input_arg_name in self.input_args_names:
-            arg_value = dataset_row[input_arg_name]
-
+        # We start off with the fields, which are the inputs of the LLM, and also
+        # build the metadata from them, matching the metadata properties previously
+        # defined in `to_argilla_dataset`.
+        fields, metadata = {}, {}
+        for arg_name in self.input_args_names:
+            arg_value = dataset_row[arg_name]
+            if isinstance(arg_value, list):
+                for idx, value in enumerate(arg_value, start=1):
-                    fields[f"{input_arg_name}-{idx}"] = value.strip()
-                    metadata[f"length-{input_arg_name}-{idx}"] = len(value.strip())
+                    fields[f"{arg_name}-{idx}"] = value.strip() if value else ""
+                    if value is not None:
+                        metadata[f"length-{arg_name}-{idx}"] = len(value.strip())
+            elif isinstance(arg_value, str):
+                fields[arg_name] = arg_value.strip() if arg_value else ""
+                if arg_value is not None:
+                    metadata[f"length-{arg_name}"] = len(arg_value.strip())
             else:
-                fields[input_arg_name] = arg_value.strip()
-                metadata[f"length-{input_arg_name}"] = len(arg_value.strip())
-
+                warnings.warn(
+                    f"Unsupported input type ({type(arg_value)}), skipping...",
+                    UserWarning,
+                    stacklevel=2,
+                )
+        # Then we include the suggestions, which are generated from the outputs
+        # of the LLM instead.
         suggestions = []
-
-        # add rationale
-        if self._to_argilla_rationale(dataset_row) is not None:
+        if rationale_column is None or rationale_column not in dataset_row:
+            raise ValueError(
+                f"The rationale column {rationale_column} is not present in the dataset row."
+            )
+        if dataset_row.get(rationale_column) is not None:
+            rationales = dataset_row.get(rationale_column)
             suggestions.append(
                 {
-                    "question_name": "ratings-rationale",
-                    "value": self._to_argilla_rationale(dataset_row),
+                    "question_name": f"{ratings_column}-{rationale_column}",
+                    "value": self._merge_rationales(rationales=rationales)
+                    if isinstance(rationales, list)
+                    else rationales,
                 }
             )
-        for output_arg_name in self.output_args_names:
-            if output_arg_name == "rating":
-                ratings = []
-                output_data = dataset_row.get(output_arg_name)
-                if output_data is not None:
-                    for idx, value in enumerate(output_data, start=1):
-                        ratings.append(value)
-                        if value <= 0:
-                            value = 1.0
-                        if value <= 10:
-                            # add suggestions
-                            suggestions.append(
-                                {
-                                    "question_name": f"generations-{idx}-rating",
-                                    "value": int(value),
-                                }
-                            )
-                            # update rating metadata
-                            metadata.update({f"rating-generations-{idx}": value})
-                    if len(ratings) >= 2:
-                        sorted_ratings = sorted(ratings, reverse=True)
-                        # update rating distance from best to second
-                        metadata.update(
-                            {"distance-best-rated": sorted_ratings[0] - sorted_ratings[1]}
-                        )
-        return self._create_argilla_record(
+        if ratings_column is None or ratings_column not in dataset_row:
+            raise ValueError(
+                f"The ratings column {ratings_column} is not present in the dataset row."
+            )
+        if dataset_row.get(ratings_column) is not None:
+            ratings = dataset_row.get(ratings_column)
+            for idx, value in enumerate(ratings, start=1):  # type: ignore
+                suggestions.append(
+                    {
+                        "question_name": f"{generations_column}-{idx}-{ratings_column}",
+                        "value": 1 if value < 1 else int(value) if value <= 10 else None,
+                    }
+                )
+                metadata[f"{ratings_column}-{generations_column}-{idx}"] = value
+            if len(ratings) >= 2:  # type: ignore
+                sorted_ratings = sorted(ratings, reverse=True)  # type: ignore
+                metadata[f"distance-best-{ratings_column}"] = (
+                    sorted_ratings[0] - sorted_ratings[1]
+                )
+        return rg.FeedbackRecord(
             fields=fields, suggestions=suggestions, metadata=metadata
         )
-
-    def _to_argilla_rationale(self, dataset_row: Dict[str, Any]) -> str:
-        """Gets the `rationale` column from a `datasets.Dataset` row and formats it
-        as expected by Argilla.
-        """
-        return dataset_row["rationale"]
diff --git a/src/distilabel/tasks/preference/ultrafeedback.py b/src/distilabel/tasks/preference/ultrafeedback.py
index 30da9d30e9..2d89d91171 100644
--- a/src/distilabel/tasks/preference/ultrafeedback.py
+++ b/src/distilabel/tasks/preference/ultrafeedback.py
@@ -14,7 +14,7 @@
 
 from dataclasses import dataclass, field
 from textwrap import dedent
-from typing import Any, ClassVar, Dict, List, Optional
+from typing import ClassVar, List, Optional
 
 from typing_extensions import TypedDict
 
@@ -109,19 +109,6 @@ def parse_output(self, output: str) -> List[UltraFeedbackOutput]:
         )
         return parsed_output
 
-    def _to_argilla_rationale(
-        self,
-        dataset_row: Dict[str, Any],
-    ) -> str:
-        """Converts the rationale to the format expected by Argilla."""
-        rationales = dataset_row.get("rationale")
-        if rationales is None:
-            return ""
-        sections = []
-        for idx, rationale in enumerate(dataset_row["rationale"], start=1):
-            sections.append(f"Rationale for generation-{idx}:\n{rationale}\n")
-        return "\n".join(sections)
-
 @classmethod
 def for_text_quality(
     cls,
diff --git a/src/distilabel/tasks/preference/ultrajudge.py b/src/distilabel/tasks/preference/ultrajudge.py
index c0b1d35ab9..421fa57333 100644
--- a/src/distilabel/tasks/preference/ultrajudge.py
+++ b/src/distilabel/tasks/preference/ultrajudge.py
@@ -39,7 +39,7 @@ class UltraJudgeOutput(TypedDict):
 @dataclass
 class UltraJudgeTask(PreferenceTask):
     """A `PreferenceTask` for the UltraJudge task. The `UltraJudge` task has been defined
-    at Argilla specically for a better evaluation using AI Feedback. The task is defined
+    at Argilla specifically for a better evaluation using AI Feedback. The task is defined
     based on both UltraFeedback and JudgeLM, but with several improvements / modifications.
 
     Args:
@@ -155,15 +155,12 @@ def parse_output(self, output: str) -> List[UltraJudgeOutput]:
 
         return outputs
 
-    def _to_argilla_rationale(
-        self,
-        dataset_row: Dict[str, Any],
+    def _merge_rationales(
+        self, rationales: List[Dict[str, Any]], generations_column: str = "generations"
     ) -> str:
-        """Gets the `rationale` column from a `datasets.Dataset` row and formats it
-        as expected by Argilla.
- """ + """Overwrite of the `_merge_rationales` as we need to process the areas before merging.""" - def format_area(area): + def format_area(area: Dict[str, Any]) -> str: sections = [] for title, ratings in area.items(): sections.append(title) @@ -171,8 +168,9 @@ def format_area(area): sections.append(f"{k}:{v}") return "\n".join(sections) - rationales = [] - for idx, area in enumerate(dataset_row["areas"], start=1): - formatted_area = format_area(area) - rationales.append(f"Rationale for generation-{idx}:\n{formatted_area}\n") - return "\n".join(rationales) + merged_rationales = [] + for idx, area in enumerate(rationales, start=1): + merged_rationales.append( + f"{generations_column}-{idx}:\n{format_area(area)}\n" + ) + return "\n".join(merged_rationales)