
Commit

Add new functionality to binarize preference datasets directly from distilabel (#264)

* Add prepare dataset function to binarize preference datasets
plaguss authored Jan 19, 2024
1 parent 5bf0fe4 commit c7fd1ad
Showing 6 changed files with 608 additions and 2 deletions.
@@ -0,0 +1,14 @@
from datasets import load_dataset
from distilabel.tasks import JudgeLMTask
from distilabel.dataset import prepare_dataset

dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
dataset.task = JudgeLMTask()
dataset_binarized_random = prepare_dataset(dataset, strategy="random", keep_ties=True)
# >>> len(dataset)
# 12859
# >>> len(dataset_binarized_random)
# 12817
dataset_binarized_random = prepare_dataset(dataset, strategy="random", keep_ties=False)
# >>> len(dataset_binarized_random)
# 8850
@@ -0,0 +1,14 @@
from datasets import load_dataset
from distilabel.tasks import JudgeLMTask
from distilabel.dataset import prepare_dataset

dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
dataset.task = JudgeLMTask()
dataset_binarized_worst = prepare_dataset(dataset, strategy="worst", keep_ties=True)
# >>> len(dataset)
# 12859
# >>> len(dataset_binarized_worst)
# 12817
dataset_binarized_worst = prepare_dataset(dataset, strategy="worst", keep_ties=False)
# >>> len(dataset_binarized_worst)
# 8850
76 changes: 76 additions & 0 deletions docs/technical-reference/pipeline.md
@@ -172,3 +172,79 @@ The [CustomDataset][distilabel.dataset.CustomDataset] generated entirely by AI m
```python
--8<-- "docs/snippets/technical-reference/pipeline/argilla.py"
```

## Prepare datasets for fine-tuning

The preference datasets generated by distilabel out of the box contain all the raw information generated by the [`Pipeline`][distilabel.pipeline.Pipeline], but some processing is necessary to prepare the dataset for alignment or instruction fine-tuning, such as [DPO](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format) (for now, *DPO* is the only case covered).

`distilabel` offers helper functions to prepare a [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current implementation works for datasets labelled using a `PreferenceTask`, and prepares them by *binarizing* the data. See the following section for an introduction to *dataset binarization*.

By default, *ties* (rows for which the chosen and rejected responses have the same rating) are removed from the dataset, as that is what fine-tuning expects, but they can be kept in case you want to analyse them. Take a look at [dataset.utils.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.
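As a rough illustration of that rule (a minimal sketch, not distilabel's implementation), with two rated generations per row a *tie* is a row whose ratings are equal:

```python
# Minimal sketch of the tie-removal rule described above (not distilabel's
# code): with two rated generations per row, a tie means the chosen and
# rejected responses would receive the same score, so the row is dropped
# when keep_ties=False.
rows = [
    {"rating": [9.0, 7.0]},  # clear preference: kept
    {"rating": [8.0, 8.0]},  # tie: dropped
]
kept = [r for r in rows if r["rating"][0] != r["rating"][1]]
print(len(kept))  # 1
```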

!!! example "Binarization"

=== "random"

```python
--8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py"
```

=== "worst"

```python
--8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py"
```

### What's binarization?

In the context of preference datasets (datasets for LLM instruction-tuning), a common format is that of [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) (the same format obtained from a `Pipeline` that labels a dataset with a [`PreferenceTask`][distilabel.tasks.preference.base.PreferenceTask]): for a given instruction there are multiple completions, generated by one or more models and rated either by humans or by other LLMs.

From a labelling `Pipeline`, distilabel produces a dataset with the following format:

| input | generations | rating |
|:--------------------------------------------------------|:-----------------------------------|---------:|
| Generate an approximately fifteen-word sentence that... | [Midsummer House is a moderately..., Sure! Here's a sentence that...] | [9.0, 7.0] |

Each column represents the following:

- **input**: Input for the LLM to generate text.

- **generations**: List of generations from the LLM (maybe an [LLMPool][distilabel.llm.base.LLMPool] with different models).

- **rating**: A list of the ratings for each of the generations, obtained by an LLM using one of the `PreferenceTasks`, like [JudgeLMTask][distilabel.tasks.preference.judgelm.JudgeLMTask] or [UltraFeedbackTask][distilabel.tasks.preference.ultrafeedback.UltraFeedbackTask].

This dataset format contains all the raw information, but the common fine-tuning frameworks usually expect a prompt plus a chosen and a rejected response, so the model can be aligned with those preferences.

We would want the following dataset format for fine-tuning:

| prompt | chosen | rejected |
|:--------------------------------------------------------|:-----------------------------------|---------:|
| Generate an approximately fifteen-word sentence that... | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': 'Midsummer House is a moderately...', 'role': 'assistant'}] | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': ' Sure! Here\'s a sentence that...', 'role': 'assistant'}] |

Take a look at this [explanation](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences#dataset-processing) for the binarization of *UltraFeedback* done to train [Notus-7B-v1](https://huggingface.co/argilla/notus-7b-v1).

Each column represents the following:

- **prompt**: Instruction given to the model.

- **chosen**: Response chosen following the OpenAI format.

- **rejected**: Response rejected following the OpenAI format.

We refer to the [OpenAI's chat format](https://platform.openai.com/docs/guides/text-generation) for more information on the chosen/rejected format.
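The transformation can be sketched in a few lines of plain Python (an illustration of the idea, not distilabel's actual implementation; `binarize_row` and its arguments are made up for this example):

```python
import random


def binarize_row(row, strategy="random", seed=42):
    """Sketch of the binarization described above (not distilabel's code):
    the top-rated generation becomes `chosen`; `rejected` is either a random
    other generation ("random") or the lowest-rated one ("worst")."""
    ratings = row["rating"]
    best = max(range(len(ratings)), key=ratings.__getitem__)
    if strategy == "worst":
        other = min(range(len(ratings)), key=ratings.__getitem__)
    else:
        other = random.Random(seed).choice(
            [i for i in range(len(ratings)) if i != best]
        )

    def as_chat(completion):
        # OpenAI chat format: a user turn (the prompt) and an assistant turn.
        return [
            {"role": "user", "content": row["input"]},
            {"role": "assistant", "content": completion},
        ]

    return {
        "prompt": row["input"],
        "chosen": as_chat(row["generations"][best]),
        "rejected": as_chat(row["generations"][other]),
        "rating_chosen": ratings[best],
        "rating_rejected": ratings[other],
    }


row = {
    "input": "Generate an approximately fifteen-word sentence that...",
    "generations": [
        "Midsummer House is a moderately...",
        "Sure! Here's a sentence that...",
    ],
    "rating": [9.0, 7.0],
}
binarized = binarize_row(row, strategy="worst")
print(binarized["rating_chosen"], binarized["rating_rejected"])  # 9.0 7.0
```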

This dataset processing is called binarization. In the context of `distilabel`, this transformation (dataset preparation) is done by [`dataset.utils.prepare_dataset`][distilabel.utils.dataset.prepare_dataset], and given that the generated datasets contain additional information, the following extra columns may also appear:

| prompt | chosen | rejected | rating_chosen | rating_rejected | chosen_model | rejected_model |
|:--------------------------------------------------------|:-----------------------------------|:--------------------------------|----------------:|------------------:|:---------------|:-----------------|
| Generate an approximately fifteen-word sentence that... | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': 'Midsummer House is a moderately...', 'role': 'assistant'}] | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': ' Sure! Here\'s a sentence that...', 'role': 'assistant'}] | 9 | 7 | | |

- **rating_chosen**: Rating of the chosen response.

- **rating_rejected**: Rating of the rejected response.

- **chosen_model**: (*Optional*, only present if the dataset contains it; otherwise an empty string, as here) The model used to generate the chosen response.

- **rejected_model**: (*Optional*, only present if the dataset contains it; otherwise an empty string, as here) The model used to generate the rejected response.
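How the optional model columns could be carried over can be sketched as follows (a hedged illustration: the raw column name `generation_model` is hypothetical, not distilabel's actual schema):

```python
# Hedged sketch: if the raw dataset also stores which model produced each
# generation (the column name "generation_model" here is hypothetical),
# binarization can carry that provenance into chosen_model/rejected_model;
# when it is absent, the columns fall back to empty strings.
row = {
    "generations": ["Midsummer House is...", "Sure! Here's..."],
    "rating": [9.0, 7.0],
    "generation_model": ["model-a", "model-b"],  # hypothetical column
}
best = row["rating"].index(max(row["rating"]))   # index 0 -> chosen
worst = row["rating"].index(min(row["rating"]))  # index 1 -> rejected
models = row.get("generation_model", [""] * len(row["rating"]))
extra_cols = {"chosen_model": models[best], "rejected_model": models[worst]}
print(extra_cols)
```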

Need more information? Take a look at [argilla/ultrafeedback-binarized-preferences](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences) to get an idea of how [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) can be binarized to prepare it for *DPO*.
4 changes: 4 additions & 0 deletions src/distilabel/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from distilabel.utils.dataset import prepare_dataset

__all__ = ["prepare_dataset"]
