Add new functionality to binarize preference datasets directly from distilabel #264

Merged
merged 10 commits into from
Jan 19, 2024
docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py
@@ -0,0 +1,14 @@
from datasets import load_dataset
from distilabel.tasks import JudgeLMTask
from distilabel.dataset import prepare_dataset

dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
dataset.task = JudgeLMTask()
dataset_binarized_random = prepare_dataset(dataset, strategy="random", keep_ties=True)
# >>> len(dataset)
# 12859
# >>> len(dataset_binarized_random)
# 12817
dataset_binarized_random = prepare_dataset(dataset, strategy="random", keep_ties=False)
# >>> len(dataset_binarized_random)
# 8850
docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py
@@ -0,0 +1,14 @@
from datasets import load_dataset
from distilabel.tasks import JudgeLMTask
from distilabel.dataset import prepare_dataset

dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
dataset.task = JudgeLMTask()
dataset_binarized_worst = prepare_dataset(dataset, strategy="worst", keep_ties=True)
# >>> len(dataset)
# 12859
# >>> len(dataset_binarized_worst)
# 12817
dataset_binarized_worst = prepare_dataset(dataset, strategy="worst", keep_ties=False)
# >>> len(dataset_binarized_worst)
# 8850
22 changes: 22 additions & 0 deletions docs/technical-reference/pipeline.md
@@ -141,6 +141,28 @@ The dataset can be regenerated from the checkpoint by simply calling the `Custom

And with the dataset regenerated we can easily call `push_to_argilla` on it to review it.

### Prepare datasets for fine-tuning

The preference datasets generated by distilabel contain all the raw information produced by the [`Pipeline`][distilabel.pipeline.Pipeline], but some processing is necessary to prepare them for alignment fine-tuning, such as [DPO](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format).

`distilabel` offers helper functions to prepare the [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current definition works for datasets labelled using `PreferenceTask`, and prepares them by *binarizing* the data. Take a look at [argilla/ultrafeedback-binarized-preferences](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences) to get an idea of how [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) can be binarized to prepare it for *DPO*.

By default, *ties* (rows in which the chosen and rejected responses have the same rating) are kept in the dataset, but they should be removed before fine-tuning. Take a look at [dataset.utils.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.
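To make the idea concrete: binarization reduces each row with several rated generations to a single *(chosen, rejected)* pair. A minimal sketch of the `worst` strategy on a toy row — plain dicts whose field names mirror a distilabel preference dataset, no library required:

```python
# Toy row: one prompt with three rated generations.
example = {
    "input": "What is the capital of France?",
    "generations": ["Paris.", "Paris, the capital and largest city.", "London."],
    "rating": [9.0, 9.5, 2.0],
}

# "worst" strategy: chosen = highest-rated response, rejected = lowest-rated one.
best_idx = example["rating"].index(max(example["rating"]))
worst_idx = example["rating"].index(min(example["rating"]))

binarized = {
    "prompt": example["input"],
    "chosen": example["generations"][best_idx],
    "rejected": example["generations"][worst_idx],
    "rating_chosen": example["rating"][best_idx],
    "rating_rejected": example["rating"][worst_idx],
}
print(binarized["chosen"])    # Paris, the capital and largest city.
print(binarized["rejected"])  # London.
```

The `random` strategy differs only in how the rejected response is picked: at random from the lower-rated candidates rather than always the lowest.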

!!! note "Binarization"

=== "random"

```python
--8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py"
```

=== "worst"

```python
--8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py"
```

## pipeline

Considering recurring patterns in dataset creation, we can facilitate the process by utilizing the [`Pipeline`][distilabel.pipeline.Pipeline]. This is made simpler through the [`pipeline`][distilabel.pipeline.pipeline] function, which provides the necessary parameters for creating a `Pipeline`.
4 changes: 4 additions & 0 deletions src/distilabel/utils/__init__.py
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from distilabel.utils.dataset import prepare_dataset

__all__ = ["prepare_dataset"]
223 changes: 222 additions & 1 deletion src/distilabel/utils/dataset.py
@@ -12,17 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Literal, Optional, get_args

import dill as pickle

from distilabel.logger import get_logger

if TYPE_CHECKING:
from distilabel.dataset import CustomDataset
from distilabel.tasks.base import Task


TASK_FILE_NAME = "task.pkl"

logger = get_logger()

BinarizationStrategies = Literal["random", "worst"]


def save_task_to_disk(path: Path, task: "Task") -> None:
"""Saves a task to disk.
@@ -51,3 +60,215 @@ def load_task_from_disk(path: Path) -> "Task":
with open(task_path, "rb") as f:
task = pickle.loads(f.read())
return task


def _binarize_dataset(
dataset: "CustomDataset",
seed: Optional[int] = None,
strategy: BinarizationStrategies = "random",
keep_ties: bool = True,
**kwargs: Any,
) -> "CustomDataset":
"""Binarizes a distilabel dataset.

Args:
dataset (CustomDataset): The distilabel dataset to binarize.
seed (Optional[int], optional): Random seed. Defaults to None.
strategy (BinarizationStrategies, optional): Method to binarize the data. Defaults to "random".
keep_ties (bool, optional):
Whether to keep ties in case the binarization method generated the chosen
and rejected responses to have the same rating. Defaults to True.
kwargs: Extra parameters passed to `datasets.Dataset.map`.

Raises:
ValueError: If the strategy is not implemented.

Returns:
CustomDataset: Dataset binarized.
"""
rating_column = "rating"
responses_column = "generations"

def binarize_random(example):
random.seed(seed)

# First pick the highest rating
prompt = example["input"]
best_rating = max(example[rating_column])
best_response_idx = example[rating_column].index(best_rating)
chosen_response = example[responses_column][best_response_idx]
chosen_model = example["generation_model"][best_response_idx]

# Remove best response
example[rating_column].pop(best_response_idx)
example[responses_column].pop(best_response_idx)
example["generation_model"].pop(best_response_idx)

# Then pick the rejected response from the candidates with a lower score
example_lower = defaultdict(list)
for i, rating in enumerate(example[rating_column]):
if rating < best_rating:
example_lower[responses_column].append(example[responses_column][i])
example_lower[rating_column].append(rating)
example_lower["generation_model"].append(example["generation_model"][i])

if len(example_lower[rating_column]) == 0:
# No response has a lower rating, so we keep the original example (a tie)
example_lower = example

# Pick the rejected response by index so the rating and model columns stay aligned
random_response_idx = random.randrange(len(example_lower[responses_column]))
random_response = example_lower[responses_column][random_response_idx]
random_rating = example_lower[rating_column][random_response_idx]

random_model = example_lower["generation_model"][random_response_idx]

binarized = {
"prompt": prompt,
"chosen": chosen_response,
"rejected": random_response,
"rating_chosen": int(best_rating),
"rating_rejected": int(random_rating),
"chosen_model": chosen_model,
"rejected_model": random_model,
}
return binarized

def binarize_worst(example):
random.seed(seed)

prompt = example["input"]
best_rating = max(example[rating_column])
best_response_idx = example[rating_column].index(best_rating)
chosen_response = example[responses_column][best_response_idx]

chosen_model = example["generation_model"][best_response_idx]

worst_rating = min(example[rating_column])
worst_response_idx = example[rating_column].index(worst_rating)
worst_response = example[responses_column][worst_response_idx]

worst_model = example["generation_model"][worst_response_idx]

binarized = {
"prompt": prompt,
"chosen": chosen_response,
"rejected": worst_response,
"rating_chosen": int(best_rating),
"rating_rejected": int(worst_rating),
"chosen_model": chosen_model,
"rejected_model": worst_model,
}
return binarized

if strategy == "random":
binarization_method = binarize_random
elif strategy == "worst":
binarization_method = binarize_worst
else:
raise ValueError(
f"Strategy `{strategy}` is not implemented, it must be one of: {get_args(BinarizationStrategies)}"
)

if "generation_model" not in dataset.column_names:
# Ensure generation model is found in the dataset, even if empty, to avoid
# errors when calling map
dataset = dataset.add_column(
"generation_model", [[""] * len(dataset[0]["generations"])] * len(dataset)
)

dataset = dataset.map(binarization_method, **kwargs)

if not keep_ties:
dataset = dataset.filter(
lambda example: example["rating_chosen"] != example["rating_rejected"]
)
return dataset


def prepare_dataset(
dataset: "CustomDataset",
strategy: BinarizationStrategies = "random",
seed: Optional[int] = None,
keep_ties: bool = True,
**kwargs: Any,
) -> "CustomDataset":
"""Helper function to prepare a distilabel dataset for training with the standard formats.

Currently supports the `PreferenceTask`, and binarizes the responses assuming
one of two strategies:

- `random`: Selects the *chosen* response based on the highest rating, and the
*rejected* response at random from the remaining, lower-rated ones.
- `worst`: Selects the *chosen* response based on the highest rating, and the
*rejected* response with the lowest rating.

In both cases, ties (examples in which the chosen and rejected ratings are equal)
are removed when `keep_ties=False`.

Take a look at [argilla/ultrafeedback-binarized-preferences](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences)
for more information on binarizing a dataset to prepare it for DPO fine-tuning.

The expected format for a dataset to be trained with DPO is defined in TRL's
[DPO trainer](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format).

Args:
dataset (CustomDataset):
CustomDataset with a PreferenceTask to prepare for Direct Preference Optimization.
strategy (BinarizationStrategies, optional):
Strategy to binarize the data. Defaults to "random".
seed (Optional[int], optional): Seed for the random generator when using the `random` strategy. Defaults to None.
keep_ties (bool, optional):
Whether to keep ties in case the binarization method generated the chosen
and rejected responses to have the same rating. Defaults to True.
kwargs: Extra parameters passed to `datasets.Dataset.map`.

Returns:
CustomDataset: Dataset formatted for training with DPO.

Examples:
>>> from datasets import load_dataset
>>> from distilabel.tasks import UltraFeedbackTask
>>> import os
>>> dataset = load_dataset("argilla/DistiCoder-dpo", token=os.getenv("HF_API_TOKEN"), split="train")
>>> dataset.task = UltraFeedbackTask.for_instruction_following()
>>> dataset_binarized = prepare_dataset(dataset, strategy="worst")
"""
from distilabel.tasks.preference.base import PreferenceTask

if not isinstance(dataset.task, PreferenceTask):
raise ValueError(
"This functionality is currently implemented for `PreferenceTask` only."
)

remove_columns = [
"input",
"generation_model",
"generations",
"rating",
"labelling_model",
"labelling_prompt",
"raw_labelling_response",
"rationale",
]
# Remove the rows for which there is no rating
initial_length = len(dataset)
dataset = dataset.filter(lambda example: example["rating"])
if len(dataset) != initial_length:
logger.info(
f"Found {initial_length - len(dataset)} examples with no rating, removing them."
)

if len(dataset[0]["generations"]) < 2:
raise ValueError("The dataset must contain at least 2 generations per example.")

ds = _binarize_dataset(
dataset, strategy=strategy, seed=seed, keep_ties=keep_ties, **kwargs
)

# Imported here to avoid circular imports
from distilabel.dataset import CustomDataset

ds = ds.remove_columns(remove_columns)
ds.__class__ = CustomDataset
ds.task = dataset.task
return ds
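For reference, the selection logic of the `random` strategy above can be sketched standalone on plain dicts — a simplified mirror of `binarize_random`, not the library implementation (the helper name and the single-row input are illustrative):

```python
import random

def binarize_random_sketch(example: dict, seed: int = 42) -> dict:
    """Simplified mirror of the `random` strategy: chosen is the top-rated
    response, rejected is drawn at random from the lower-rated ones."""
    random.seed(seed)
    ratings = list(example["rating"])
    gens = list(example["generations"])
    best_rating = max(ratings)
    best_idx = ratings.index(best_rating)
    chosen = gens.pop(best_idx)
    ratings.pop(best_idx)
    # Candidates strictly below the best rating; on a full tie fall back to the rest.
    lower = [(g, r) for g, r in zip(gens, ratings) if r < best_rating]
    rejected, rejected_rating = random.choice(lower or list(zip(gens, ratings)))
    return {
        "chosen": chosen,
        "rejected": rejected,
        "rating_chosen": best_rating,
        "rating_rejected": rejected_rating,
    }

# With only one lower-rated candidate the choice is deterministic.
row = {"generations": ["a", "b", "c"], "rating": [9, 9, 3]}
result = binarize_random_sketch(row)
print(result)  # {'chosen': 'a', 'rejected': 'c', 'rating_chosen': 9, 'rating_rejected': 3}
```

Rows where the rejected rating ends up equal to the chosen one are the ties that `keep_ties=False` filters out afterwards.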