Add new functionality to binarize preference datasets directly from distilabel #264

Merged · 10 commits · Jan 19, 2024
docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py
@@ -0,0 +1,14 @@
from datasets import load_dataset
from distilabel.tasks import JudgeLMTask
from distilabel.dataset import prepare_dataset

dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
# `prepare_dataset` expects the dataset to have its labelling task set.
dataset.task = JudgeLMTask()
dataset_binarized_random = prepare_dataset(dataset, strategy="random", keep_ties=True)
# >>> len(dataset)
# 12859
# >>> len(dataset_binarized_random)
# 12817
dataset_binarized_random = prepare_dataset(dataset, strategy="random", keep_ties=False)
# >>> len(dataset_binarized_random)
# 8850
docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py
@@ -0,0 +1,14 @@
from datasets import load_dataset
from distilabel.tasks import JudgeLMTask
from distilabel.dataset import prepare_dataset

dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
# `prepare_dataset` expects the dataset to have its labelling task set.
dataset.task = JudgeLMTask()
dataset_binarized_worst = prepare_dataset(dataset, strategy="worst", keep_ties=True)
# >>> len(dataset)
# 12859
# >>> len(dataset_binarized_worst)
# 12817
dataset_binarized_worst = prepare_dataset(dataset, strategy="worst", keep_ties=False)
# >>> len(dataset_binarized_worst)
# 8850
76 changes: 76 additions & 0 deletions docs/technical-reference/pipeline.md
@@ -172,3 +172,79 @@ The [CustomDataset][distilabel.dataset.CustomDataset] generated entirely by AI m
```python
--8<-- "docs/snippets/technical-reference/pipeline/argilla.py"
```

## Prepare datasets for fine-tuning

The preference datasets generated by distilabel out of the box contain all the raw information produced by the [`Pipeline`][distilabel.pipeline.Pipeline], but some processing is necessary before the dataset can be used for alignment or instruction fine-tuning, for example with [DPO](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format) (for now, *DPO* is the only case covered).

`distilabel` offers helper functions to prepare the [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current implementation works for datasets labelled using a `PreferenceTask`, and prepares them by *binarizing* the data. See the section below for an introduction to *dataset binarization*.

By default, *ties* (rows for which the chosen and rejected responses have the same rating) are removed from the dataset, as that is what fine-tuning expects, but they can be kept in case they need to be analysed. Take a look at [dataset.utils.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.

!!! note "Binarization"

=== "random"

```python
--8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py"
```

=== "worst"

```python
--8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py"
```

### What's binarization?

In the context of preference datasets (datasets for LLM instruction-tuning), one often finds datasets formatted following the [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) format (the same format obtained from a `Pipeline` that labels a dataset with a [`PreferenceTask`][distilabel.tasks.preference.base.PreferenceTask]): for a given instruction there are multiple completions, from one or more models, rated either by humans or by other LLMs.

From a labelling `Pipeline`, distilabel produces a dataset with the following format:

| input | generations | rating |
|:--------------------------------------------------------|:-----------------------------------|---------:|
| Generate an approximately fifteen-word sentence that... | [Midsummer House is a moderately..., Sure! Here's a sentence that...] | [9.0, 7.0] |

Each column represents the following:

- **input**: Input for the LLM to generate text.

- **generations**: List of generations from the LLM (possibly from an [LLMPool][distilabel.llm.base.LLMPool] with different models).

- **rating**: A list of ratings, one per generation, obtained from an LLM using one of the `PreferenceTasks`, like [JudgeLMTask][distilabel.tasks.preference.judgelm.JudgeLMTask] or [UltraFeedbackTask][distilabel.tasks.preference.ultrafeedback.UltraFeedbackTask].
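As an illustration, a single row of such a dataset could be represented as a plain dictionary (a hypothetical, shortened example, not taken verbatim from the real dataset):

```python
# A hypothetical row of a distilabel preference dataset after labelling
# (values shortened for readability).
row = {
    "input": "Generate an approximately fifteen-word sentence that...",
    "generations": [
        "Midsummer House is a moderately...",
        "Sure! Here's a sentence that...",
    ],
    "rating": [9.0, 7.0],
}

# The highest-rated generation is the natural candidate for the "chosen" response.
best_index = row["rating"].index(max(row["rating"]))
print(best_index)  # 0
```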

This format contains all the raw information, but the common fine-tuning frameworks usually expect a prompt plus a chosen and a rejected response, so the model can be aligned with those preferences.

We would want the following dataset format for fine-tuning:

| prompt | chosen | rejected |
|:--------------------------------------------------------|:-----------------------------------|---------:|
| Generate an approximately fifteen-word sentence that... | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': 'Midsummer House is a moderately...', 'role': 'assistant'}] | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': ' Sure! Here\'s a sentence that...', 'role': 'assistant'}] |

Take a look at this [explanation](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences#dataset-processing) for the binarization of *UltraFeedback* done to train [Notus-7B-v1](https://huggingface.co/argilla/notus-7b-v1).

Each column represents the following:

- **prompt**: Instruction given to the model.

- **chosen**: Response chosen following the OpenAI format.

- **rejected**: Response rejected following the OpenAI format.

We refer to the [OpenAI's chat format](https://platform.openai.com/docs/guides/text-generation) for more information on the chosen/rejected format.
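To make the transformation concrete, here is a minimal sketch of how a single row could be binarized into that chat format. This is a simplified, hypothetical implementation written for illustration; the actual logic lives in [`dataset.utils.prepare_dataset`][distilabel.utils.dataset.prepare_dataset]:

```python
import random


def binarize_row(row: dict, strategy: str = "random", seed: int = 42) -> dict:
    """Turn one (input, generations, rating) row into a DPO-style row.

    With strategy="worst" the lowest-rated generation becomes the rejected
    response; with strategy="random" it is sampled from the non-chosen ones.
    """
    ratings = row["rating"]
    chosen_idx = ratings.index(max(ratings))
    if strategy == "worst":
        rejected_idx = ratings.index(min(ratings))
    else:  # "random"
        candidates = [i for i in range(len(ratings)) if i != chosen_idx]
        rejected_idx = random.Random(seed).choice(candidates)

    def as_chat(response: str) -> list:
        # OpenAI chat format: a user turn with the prompt, then the response.
        return [
            {"role": "user", "content": row["input"]},
            {"role": "assistant", "content": response},
        ]

    return {
        "prompt": row["input"],
        "chosen": as_chat(row["generations"][chosen_idx]),
        "rejected": as_chat(row["generations"][rejected_idx]),
        "rating_chosen": ratings[chosen_idx],
        "rating_rejected": ratings[rejected_idx],
    }


row = {
    "input": "Generate an approximately fifteen-word sentence that...",
    "generations": [
        "Midsummer House is a moderately...",
        "Sure! Here's a sentence that...",
    ],
    "rating": [9.0, 7.0],
}
binarized = binarize_row(row, strategy="worst")
print(binarized["rating_chosen"], binarized["rating_rejected"])  # 9.0 7.0
```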

This dataset processing is called *binarization*. In the context of `distilabel`, this transformation (dataset preparation) is done by [`dataset.utils.prepare_dataset`][distilabel.utils.dataset.prepare_dataset], and, since the generated datasets contain additional information, the following extra columns may also appear:

| prompt | chosen | rejected | rating_chosen | rating_rejected | chosen_model | rejected_model |
|:--------------------------------------------------------|:-----------------------------------|:--------------------------------|----------------:|------------------:|:---------------|:-----------------|
| Generate an approximately fifteen-word sentence that... | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': 'Midsummer House is a moderately...', 'role': 'assistant'}] | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': ' Sure! Here\'s a sentence that...', 'role': 'assistant'}] | 9 | 7 | | |

- **rating_chosen**: Rating of the chosen response.

- **rating_rejected**: Rating of the rejected response.

- **chosen_model**: (*Optional*, only returned if the dataset contains it; otherwise an empty string, as in this example) The model used to generate the chosen response.

- **rejected_model**: (*Optional*, only returned if the dataset contains it; otherwise an empty string, as in this example) The model used to generate the rejected response.
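With the rating columns in place, removing ties (the `keep_ties=False` behaviour described above) amounts to a simple filter. A rough, hypothetical sketch, not the actual implementation:

```python
# Hypothetical binarized rows: a tie is a row where both ratings match.
rows = [
    {"prompt": "Generate a sentence...", "rating_chosen": 9.0, "rating_rejected": 7.0},
    {"prompt": "Summarise the text...", "rating_chosen": 8.0, "rating_rejected": 8.0},  # a tie
]

# Keep only rows where the chosen response strictly outranks the rejected one.
no_ties = [r for r in rows if r["rating_chosen"] != r["rating_rejected"]]
print(len(no_ties))  # 1
```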

Need more information? Take a look at [argilla/ultrafeedback-binarized-preferences](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences) to get an idea of how [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) can be binarized to prepare it for *DPO*.
4 changes: 4 additions & 0 deletions src/distilabel/utils/__init__.py
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from distilabel.utils.dataset import prepare_dataset

__all__ = ["prepare_dataset"]