From 5a094d6dbca9d90ceb4443ad27968306d80129db Mon Sep 17 00:00:00 2001 From: plaguss Date: Tue, 16 Jan 2024 13:54:36 +0100 Subject: [PATCH 1/9] Add prepare dataset function to binarize preference datasets --- src/distilabel/utils/dataset.py | 131 +++++++++++++++++++++++++++++++- 1 file changed, 130 insertions(+), 1 deletion(-) diff --git a/src/distilabel/utils/dataset.py b/src/distilabel/utils/dataset.py index d92f3c8575..059897b978 100644 --- a/src/distilabel/utils/dataset.py +++ b/src/distilabel/utils/dataset.py @@ -12,18 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, get_args import dill as pickle +from distilabel.tasks.preference.base import PreferenceTask + if TYPE_CHECKING: + from distilabel.dataset import CustomDataset from distilabel.tasks.base import Task TASK_FILE_NAME = "task.pkl" +BinarizationStrategies = Literal["random", "worst"] + + def save_task_to_disk(path: Path, task: "Task") -> None: """Saves a task to disk. @@ -51,3 +58,125 @@ def load_task_from_disk(path: Path) -> "Task": with open(task_path, "rb") as f: task = pickle.loads(f.read()) return task + + +def _binarize_dataset( + dataset: "CustomDataset", + seed: int = 42, + strategy: BinarizationStrategies = "random", +) -> "CustomDataset": + rating_column = "rating" + responses_column = "generations" + + def binarize_random(example): + random.seed(seed) + + prompt = example["input"] + best_rating = max(example[rating_column]) + best_response_idx = example[rating_column].index(best_rating) + chosen_response = example[responses_column][best_response_idx] + chosen_model = example["generation_model"][best_response_idx] + + # Remove best response + example[rating_column].pop(best_response_idx) + example[responses_column].pop(best_response_idx) + example["generation_model"].pop(best_response_idx) + + # Select the random response + random_response = random.choice(example[responses_column]) + random_response_idx = example[responses_column].index(random_response) + random_rating = example[rating_column][random_response_idx] + random_model = example["generation_model"][random_response_idx] + + binarized = { + "prompt": prompt, + "chosen": chosen_response, + "rejected": random_response, + "rating_chosen": int(best_rating), + "rating_rejected": int(random_rating), + "chosen_model": chosen_model, + "rejected_model": random_model, + } + return binarized + + def binarize_worst(example): + random.seed(seed) + + prompt = example["input"] + best_rating = max(example[rating_column]) + best_response_idx = example[rating_column].index(best_rating) + chosen_response = example[responses_column][best_response_idx] + chosen_model = example["generation_model"][best_response_idx] + + worst_rating = min(example[rating_column]) + worst_response_idx = example[rating_column].index(worst_rating) + worst_response = example[responses_column][worst_response_idx] + worst_model = example["generation_model"][worst_response_idx] + + binarized = { + "prompt": prompt, + "chosen": chosen_response, + "rejected": worst_response, + "rating_chosen": int(best_rating), + "rating_rejected": int(worst_rating), + "chosen_model": chosen_model, + "rejected_model": worst_model, + } + return binarized + + if strategy == "random": + binarization_method = binarize_random + elif strategy == "worst": + binarization_method = binarize_worst + else: + raise ValueError( + f"Strategy `{strategy}` is not 
implemented, it must be one of: {get_args(BinarizationStrategies)}" + ) + + return dataset.map(binarization_method).filter( + lambda example: example["rating_chosen"] != example["rating_rejected"] + ) + + +def prepare_dataset( + dataset: "CustomDataset", + strategy: BinarizationStrategies = "random", + seed: int = 42, +) -> "CustomDataset": + """Helper function to prepare a dataset for training assuming the standard formats. + + Expected format for a dataset to be trained with DPO as defined in trl's + [dpo trainer](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format). + + Args: + dataset (CustomDataset): Dataset with a PreferenceTask. + strategy (BinarizationStrategies, optional): + Strategy to binarize the data. Defaults to "random". + + Returns: + CustomDataset: Dataset formatted for training with DPO. + """ + if not isinstance(dataset.task, PreferenceTask): + raise ValueError( + "This functionality is currently implemented for `PreferenceTask` only." + ) + + remove_columns = [ + "input", + "generation_model", + "generations", + "rating", + "labelling_model", + "labelling_prompt", + "raw_labelling_response", + "rationale", + ] + + ds = _binarize_dataset(dataset, strategy=strategy, seed=42) + + from distilabel.dataset import CustomDataset + + ds = ds.remove_columns(remove_columns) + ds.__class__ = CustomDataset + ds.task = dataset.task + return ds From 9b04ee804c3d8c0e0de18a01abb920d82b8a78fb Mon Sep 17 00:00:00 2001 From: plaguss Date: Tue, 16 Jan 2024 13:55:24 +0100 Subject: [PATCH 2/9] Add tests for prepare_dataset --- tests/test_dataset.py | 169 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 168 insertions(+), 1 deletion(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 86a246bbd4..ed04ceb8f1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re import tempfile from pathlib import Path +from typing import List import pytest from distilabel.dataset import CustomDataset, DatasetCheckpoint -from distilabel.tasks import UltraFeedbackTask +from distilabel.tasks import TextGenerationTask, UltraFeedbackTask +from distilabel.utils.dataset import prepare_dataset @pytest.fixture @@ -27,6 +30,96 @@ def custom_dataset(): return ds +@pytest.fixture +def sample_preference_dataset(): + ds = CustomDataset.from_dict( + { + "input": ["input 1", "input 2", "input 3"], + "generation_model": [ + [ + "argilla/notus-7b-v1", + "WizardLM/WizardCoder-15B-V1.0", + "ise-uiuc/Magicoder-S-DS-6.7B", + "gpt-3.5-turbo", + ], + [ + "argilla/notus-7b-v1", + "ise-uiuc/Magicoder-S-DS-6.7B", + "WizardLM/WizardCoder-15B-V1.0", + "gpt-3.5-turbo", + ], + [ + "argilla/notus-7b-v1", + "ise-uiuc/Magicoder-S-DS-6.7B", + "WizardLM/WizardCoder-15B-V1.0", + "gpt-3.5-turbo", + ], + ], + "generations": [ + [ + "generation 1 1", + "generation 1 2", + "generation 1 3", + "generation 1 4", + ], + [ + "generation 2 1", + "generation 2 2", + "generation 2 3", + "generation 2 4", + ], + [ + "generation 3 1", + "generation 3 2", + "generation 3 3", + "generation 3 4", + ], + ], + "labelling_model": [ + "gpt-4-1106-preview", + "gpt-4-1106-preview", + "gpt-4-1106-preview", + ], + "labelling_prompt": [ + [ + { + "content": "Your role is to evaluate text quality based on given criteria.", + "role": "system", + }, + {"content": "content", "role": "user"}, + ], + [ + { + "content": "Your role is to evaluate text quality based on given criteria.", + "role": "system", + }, + {"content": "content", "role": "user"}, + ], + [ + { + "content": "Your role is to evaluate text quality based on given criteria.", + "role": "system", + }, + {"content": "content", "role": "user"}, + ], + ], + "raw_labelling_response": ["response", "response", "response"], + "rating": [ + [2.0, 5.0, 4.0, 5.0], + [2.0, 3.0, 1.0, 4.0], + [4.0, 3.0, 5.0, 3.0], + ], + "rationale": [ + ["rationale 1", "rationale 2", "rationale 3", "rationale 4"], + ["rationale 1", "rationale 2", "rationale 3", "rationale 4"], + ["rationale 1", "rationale 2", "rationale 3", "rationale 4"], + ], + } + ) + ds.task = UltraFeedbackTask.for_overall_quality() + return ds + + def test_dataset_save_to_disk(custom_dataset): with tempfile.TemporaryDirectory() as tmpdir: ds_name = Path(tmpdir) / "dataset_folder" @@ -74,3 +167,77 @@ def test_do_checkpoint( ctr += 1 assert ctr == expected == chk._total_checks + + +@pytest.mark.parametrize( + "strategy, chosen, rejected, chosen_model, rejected_model", + [ + ( + "random", + ["generation 2 4", "generation 3 3"], + ["generation 2 3", "generation 3 4"], + ["gpt-3.5-turbo", "WizardLM/WizardCoder-15B-V1.0"], + ["WizardLM/WizardCoder-15B-V1.0", "gpt-3.5-turbo"], + ), + ( + "worst", + ["generation 1 2", "generation 2 4", "generation 3 3"], + ["generation 1 1", "generation 2 3", "generation 3 2"], + [ + "WizardLM/WizardCoder-15B-V1.0", + "gpt-3.5-turbo", + "WizardLM/WizardCoder-15B-V1.0", + ], + [ + "argilla/notus-7b-v1", + "WizardLM/WizardCoder-15B-V1.0", + "ise-uiuc/Magicoder-S-DS-6.7B", + ], + ), + ], +) +def test_prepare_dataset( + sample_preference_dataset: CustomDataset, + strategy: str, + chosen: List[str], + rejected: List[str], + chosen_model: List[str], + rejected_model: List[str], +): + ds = prepare_dataset(sample_preference_dataset, strategy=strategy) + assert isinstance(ds, CustomDataset) + assert ds.column_names == [ + "prompt", + "chosen", + "rejected", + "rating_chosen", + "rating_rejected", 
+ "chosen_model", + "rejected_model", + ] + for i, row in enumerate(ds): + assert row["chosen"] == chosen[i] + assert row["rejected"] == rejected[i] + assert row["chosen_model"] == chosen_model[i] + assert row["rejected_model"] == rejected_model[i] + + +def test_prepare_dataset_wrong_task(sample_preference_dataset: CustomDataset): + sample_preference_dataset.task = TextGenerationTask() + with pytest.raises( + ValueError, + match=re.escape( + "This functionality is currently implemented for `PreferenceTask` only." + ), + ): + prepare_dataset(sample_preference_dataset) + + +def test_dataset_wrong_strategy(sample_preference_dataset: CustomDataset): + with pytest.raises( + ValueError, + match=re.escape( + "Strategy `wrong_strategy` is not implemented, it must be one of: ('random', 'worst')" + ), + ): + prepare_dataset(sample_preference_dataset, strategy="wrong_strategy") From 1887f7f9abe8152f93e6507ff76f89b68b41682a Mon Sep 17 00:00:00 2001 From: plaguss Date: Tue, 16 Jan 2024 21:37:16 +0100 Subject: [PATCH 3/9] Update tests with the new functionality --- tests/test_dataset.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index ed04ceb8f1..978f32a2f9 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -170,14 +170,23 @@ def test_do_checkpoint( @pytest.mark.parametrize( - "strategy, chosen, rejected, chosen_model, rejected_model", + "with_generation_model", + [True], +) +@pytest.mark.parametrize( + "strategy, chosen, rejected, chosen_model, rejected_model, keep_ties", [ ( "random", - ["generation 2 4", "generation 3 3"], - ["generation 2 3", "generation 3 4"], - ["gpt-3.5-turbo", "WizardLM/WizardCoder-15B-V1.0"], - ["WizardLM/WizardCoder-15B-V1.0", "gpt-3.5-turbo"], + ["generation 1 2", "generation 2 4", "generation 3 3"], + ["generation 1 1", "generation 2 3", "generation 3 4"], + [ + "WizardLM/WizardCoder-15B-V1.0", + "gpt-3.5-turbo", + "WizardLM/WizardCoder-15B-V1.0", + ], + ["argilla/notus-7b-v1", "WizardLM/WizardCoder-15B-V1.0", "gpt-3.5-turbo"], + True, ), ( "worst", @@ -193,6 +202,7 @@ def test_do_checkpoint( "WizardLM/WizardCoder-15B-V1.0", "ise-uiuc/Magicoder-S-DS-6.7B", ], + True, ), ], ) @@ -203,8 +213,16 @@ def test_prepare_dataset( rejected: List[str], chosen_model: List[str], rejected_model: List[str], + with_generation_model: bool, + keep_ties: bool, ): - ds = prepare_dataset(sample_preference_dataset, strategy=strategy) + if not with_generation_model: + sample_preference_dataset = sample_preference_dataset.remove_columns( + ["generation_model"] + ) + ds = prepare_dataset( + sample_preference_dataset, strategy=strategy, seed=42, keep_ties=keep_ties + ) assert isinstance(ds, CustomDataset) assert ds.column_names == [ "prompt", From a27f527ddc2c10e87909ecce61a1affe8d8c01ee Mon Sep 17 00:00:00 2001 From: plaguss Date: Tue, 16 Jan 2024 21:37:59 +0100 Subject: [PATCH 4/9] Add initial prepare_dataset function to binarize datasets --- src/distilabel/utils/__init__.py | 4 ++ src/distilabel/utils/dataset.py | 120 +++++++++++++++++++++++++++---- 2 files changed, 110 insertions(+), 14 deletions(-) diff --git a/src/distilabel/utils/__init__.py b/src/distilabel/utils/__init__.py index 2598794f29..bcff2ceb87 100644 --- a/src/distilabel/utils/__init__.py +++ b/src/distilabel/utils/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+
+from distilabel.utils.dataset import prepare_dataset
+
+__all__ = ["prepare_dataset"]
diff --git a/src/distilabel/utils/dataset.py b/src/distilabel/utils/dataset.py
index 059897b978..ae23f896bd 100644
--- a/src/distilabel/utils/dataset.py
+++ b/src/distilabel/utils/dataset.py
@@ -13,12 +13,13 @@
 # limitations under the License.
 
 import random
+from collections import defaultdict
 from pathlib import Path
-from typing import TYPE_CHECKING, Literal, get_args
+from typing import TYPE_CHECKING, Any, Literal, Optional, get_args
 
 import dill as pickle
 
-from distilabel.tasks.preference.base import PreferenceTask
+from distilabel.logger import get_logger
 
 if TYPE_CHECKING:
     from distilabel.dataset import CustomDataset
@@ -27,6 +28,7 @@
 
 TASK_FILE_NAME = "task.pkl"
 
+logger = get_logger()
 
 BinarizationStrategies = Literal["random", "worst"]
 
@@ -62,15 +64,35 @@ def load_task_from_disk(path: Path) -> "Task":
 
 def _binarize_dataset(
     dataset: "CustomDataset",
-    seed: int = 42,
+    seed: Optional[int] = None,
     strategy: BinarizationStrategies = "random",
+    keep_ties: bool = True,
+    **kwargs: Any,
 ) -> "CustomDataset":
+    """Binarizes a distilabel dataset.
+
+    Args:
+        dataset (CustomDataset): The distilabel dataset to binarize.
+        seed (int, optional): Random seed. Defaults to None.
+        strategy (BinarizationStrategies, optional): Method to binarize the data. Defaults to "random".
+        keep_ties (bool, optional):
+            Whether to keep ties in case the binarization method generated the chosen
+            and rejected responses to have the same rating. Defaults to True.
+        kwargs: Extra parameters passed to `datasets.Dataset.map`.
+
+    Raises:
+        ValueError: If the strategy is not implemented.
+
+    Returns:
+        CustomDataset: Dataset binarized.
+    """
     rating_column = "rating"
     responses_column = "generations"
 
     def binarize_random(example):
         random.seed(seed)
 
+        # First pick the highest rating
         prompt = example["input"]
         best_rating = max(example[rating_column])
         best_response_idx = example[rating_column].index(best_rating)
         chosen_response = example[responses_column][best_response_idx]
         chosen_model = example["generation_model"][best_response_idx]
@@ -82,10 +104,23 @@ def binarize_random(example):
         example[rating_column].pop(best_response_idx)
         example[responses_column].pop(best_response_idx)
         example["generation_model"].pop(best_response_idx)
 
-        # Select the random response
-        random_response = random.choice(example[responses_column])
-        random_response_idx = example[responses_column].index(random_response)
-        random_rating = example[rating_column][random_response_idx]
+        # Then you pick the rejected response from the list of candidates with lower scores.
+        example_lower = defaultdict(list)
+        for i, rating in enumerate(example[rating_column]):
+            if rating < best_rating:
+                example_lower[responses_column].append(example[responses_column][i])
+                example_lower[rating_column].append(rating)
+
+        # Otherwise we declare it a tie
+        if len(example_lower[rating_column]) == 0:
+            # In this case we don't have any response with a lower rating, so we just
+            # keep the original example (we have a tie)
+            example_lower = example
+
+        random_response = random.choice(example_lower[responses_column])
+        random_response_idx = example_lower[responses_column].index(random_response)
+        random_rating = example_lower[rating_column][random_response_idx]
+
         random_model = example["generation_model"][random_response_idx]
 
         binarized = {
@@ -106,11 +141,13 @@ def binarize_worst(example):
         best_rating = max(example[rating_column])
         best_response_idx = example[rating_column].index(best_rating)
         chosen_response = example[responses_column][best_response_idx]
+
         chosen_model = example["generation_model"][best_response_idx]
 
         worst_rating = min(example[rating_column])
         worst_response_idx = example[rating_column].index(worst_rating)
         worst_response = example[responses_column][worst_response_idx]
+
         worst_model = example["generation_model"][worst_response_idx]
 
         binarized = {
@@ -133,29 +170,71 @@ def binarize_worst(example):
             f"Strategy `{strategy}` is not implemented, it must be one of: {get_args(BinarizationStrategies)}"
         )
 
-    return dataset.map(binarization_method).filter(
-        lambda example: example["rating_chosen"] != example["rating_rejected"]
-    )
+    if "generation_model" not in dataset.column_names:
+        # Ensure the generation model is found in the dataset, even if empty, to avoid
+        # errors when calling map
+        dataset = dataset.add_column(
+            "generation_model", [[""] * len(dataset[0]["generations"])] * len(dataset)
+        )
+
+    dataset = dataset.map(binarization_method, **kwargs)
+
+    if not keep_ties:
+        dataset = dataset.filter(
+            lambda example: example["rating_chosen"] != example["rating_rejected"]
+        )
+    return dataset
 
 
 def prepare_dataset(
     dataset: "CustomDataset",
     strategy: BinarizationStrategies = "random",
-    seed: int = 42,
+    seed: Optional[int] = None,
+    keep_ties: bool = True,
+    **kwargs: Any,
 ) -> "CustomDataset":
-    """Helper function to prepare a dataset for training assuming the standard formats.
+    """Helper function to prepare a distilabel dataset for training with the standard formats.
+
+    Currently supports the `PreferenceTask`, and binarizes the responses assuming
+    one of two strategies:
+
+    - `random`: Selects the *chosen* response based on the highest rating, and for the
+    *rejected* selects a random response from the remaining ones. Filters the examples in which
+    the chosen rating is equal to the rejected one.
+    - `worst`: Selects the *chosen* response based on the highest rating, and for the
+    *rejected* selects the response with the lowest rating. Filters the examples in which the
+    chosen rating is equal to the rejected one.
+
+    Take a look at [argilla/ultrafeedback-binarized-preferences](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences)
+    for more information on binarizing a dataset to prepare it for DPO fine-tuning.
 
     Expected format for a dataset to be trained with DPO as defined in trl's
     [dpo trainer](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format).
 
     Args:
-        dataset (CustomDataset): Dataset with a PreferenceTask.
+        dataset (CustomDataset):
+            CustomDataset with a PreferenceTask to prepare for Direct Preference Optimization.
strategy (BinarizationStrategies, optional): Strategy to binarize the data. Defaults to "random". + seed (int, optional): Seed for the random generator, in case of `random` strategy. Defaults to None. + keep_ties (bool, optional): + Whether to keep ties in case the binarization method generated the chosen + and rejected responses to have the same rating. Defaults to True. + kwargs: Extra parameters passed to `datasets.Dataset.map`. Returns: CustomDataset: Dataset formatted for training with DPO. + + Examples: + >>> from datasets import load_dataset + >>> from distilabel.tasks import UltraFeedbackTask + >>> import os + >>> dataset = load_dataset("argilla/DistiCoder-dpo", token=os.getenv("HF_API_TOKEN"), split="train") + >>> dataset.task = UltraFeedbackTask.for_instruction_following() + >>> dataset_binarized = prepare_dataset(dataset, strategy="worst") """ + from distilabel.tasks.preference.base import PreferenceTask + if not isinstance(dataset.task, PreferenceTask): raise ValueError( "This functionality is currently implemented for `PreferenceTask` only." @@ -171,9 +250,22 @@ def prepare_dataset( "raw_labelling_response", "rationale", ] + # Remove the rows for which there is no rating + initial_length = len(dataset) + dataset = dataset.filter(lambda example: example["rating"]) + if len(dataset) != initial_length: + logger.info( + f"Found {initial_length - len(dataset)} examples with no rating, removing them." + ) + + if len(dataset[0]["generations"]) < 2: + raise ValueError("The dataset must contain at least 2 generations per example.") - ds = _binarize_dataset(dataset, strategy=strategy, seed=42) + ds = _binarize_dataset( + dataset, strategy=strategy, seed=seed, keep_ties=keep_ties, **kwargs + ) + # Imported here to avoid circular imports from distilabel.dataset import CustomDataset ds = ds.remove_columns(remove_columns) From ae34b0ae9161e94d413d80adadd304528e1f9878 Mon Sep 17 00:00:00 2001 From: plaguss Date: Tue, 16 Jan 2024 21:38:17 +0100 Subject: [PATCH 5/9] Update docs --- .../prepare_dataset_binarize_random.py | 14 ++++++++++++ .../prepare_dataset_binarize_worst.py | 14 ++++++++++++ docs/technical-reference/pipeline.md | 22 +++++++++++++++++++ 3 files changed, 50 insertions(+) create mode 100644 docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py create mode 100644 docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py diff --git a/docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py b/docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py new file mode 100644 index 0000000000..0da4ad65c1 --- /dev/null +++ b/docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py @@ -0,0 +1,14 @@ +from datasets import load_dataset +from distilabel.tasks import JudgeLMTask +from distilabel.dataset import prepare_dataset + +dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train") +dataset.task = JudgeLMTask() +dataset_binarized_random = prepare_dataset(dataset, strategy="random", keep_ties=True) +# >>> len(dataset) +# 12859 +# >>> len(dataset_binarized_random) +# 12817 +dataset_binarized_random = prepare_dataset(dataset, strategy="random", keep_ties=False) +# >>> len(dataset_binarized_random) +# 8850 \ No newline at end of file diff --git a/docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py b/docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py new file mode 100644 index 0000000000..87cd89c97e --- /dev/null +++ 
b/docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py
@@ -0,0 +1,14 @@
+from datasets import load_dataset
+from distilabel.tasks import JudgeLMTask
+from distilabel.dataset import prepare_dataset
+
+dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
+dataset.task = JudgeLMTask()
+dataset_binarized_worst = prepare_dataset(dataset, strategy="worst", keep_ties=True)
+# >>> len(dataset)
+# 12859
+# >>> len(dataset_binarized_worst)
+# 12817
+dataset_binarized_worst = prepare_dataset(dataset, strategy="worst", keep_ties=False)
+# >>> len(dataset_binarized_worst)
+# 8850
\ No newline at end of file
diff --git a/docs/technical-reference/pipeline.md b/docs/technical-reference/pipeline.md
index 4c8734ad8b..881be6e443 100644
--- a/docs/technical-reference/pipeline.md
+++ b/docs/technical-reference/pipeline.md
@@ -141,6 +141,28 @@ The dataset can be regenerated from the checkpoint by simply calling the `Custom
 
 And with the dataset regenerated we can easily call `push_to_argilla` on it to review it.
 
+### Prepare datasets for fine-tuning
+
+The preference datasets generated by distilabel out of the box contain all the raw information generated by the [`Pipeline`][distilabel.pipeline.Pipeline], but some processing is necessary in order to prepare the dataset for alignment fine-tuning, like for [DPO](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format).
+
+`distilabel` offers helper functions to prepare the [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current definition works for datasets labelled using `PreferenceTask`, and prepares them by *binarizing* the data. Take a look at [argilla/ultrafeedback-binarized-preferences](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences) to get an idea of how [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) can be binarized to prepare it for *DPO*.
+
+By default the *ties* (rows for which the rating of the chosen and rejected responses are the same) are kept in the dataset, but those should be removed for fine-tuning. Take a look at [utils.dataset.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.
+
+!!! Binarization
+
+    === "random"
+
+        ```python
+        --8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py"
+        ```
+
+    === "worst"
+
+        ```python
+        --8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py"
+        ```
+
 ## pipeline
 
 Considering recurring patterns in dataset creation, we can facilitate the process by utilizing the [`Pipeline`][distilabel.pipeline.Pipeline]. This is made simpler through the [`pipeline`][distilabel.pipeline.pipeline] function, which provides the necessary parameters for creating a `Pipeline`.
From fa825e2a40d27e844fa1ccf1c8a502816d97d63f Mon Sep 17 00:00:00 2001
From: plaguss
Date: Wed, 17 Jan 2024 09:54:28 +0100
Subject: [PATCH 6/9] Update functions as per code review

---
 src/distilabel/utils/dataset.py | 89 +++++++++++++++++++++------------
 1 file changed, 56 insertions(+), 33 deletions(-)

diff --git a/src/distilabel/utils/dataset.py b/src/distilabel/utils/dataset.py
index ae23f896bd..c1cac3f4ee 100644
--- a/src/distilabel/utils/dataset.py
+++ b/src/distilabel/utils/dataset.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 import random
 from collections import defaultdict
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Optional, get_args
+from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, get_args
 
 import dill as pickle
 
@@ -62,11 +63,45 @@ def load_task_from_disk(path: Path) -> "Task":
     return task
 
 
+def _get_best_response(
+    example: Any, rating_column: str = "rating", responses_column: str = "generations"
+) -> Tuple[str, int, str, str]:
+    """Helper function to get the best response from an example; it can be used
+    independently of the method used to choose the rejected response.
+
+    Also, it removes the best response from the example.
+
+    Args:
+        example (Any): Each row in the dataset as passed when calling the map function on a datasets.Dataset.
+        rating_column (str, optional):
+            Column containing the rating in the CustomDataset. Defaults to "rating".
+        responses_column (str, optional):
+            Column containing the responses from a model in a CustomDataset. Defaults to "generations".
+
+    Returns:
+        Tuple[str, int, str, str]: Contains the prompt, best rating, chosen response, and chosen model.
+    """
+    # Pick the highest rating
+    prompt = example["input"]
+    best_rating = max(example[rating_column])
+    best_response_idx = example[rating_column].index(best_rating)
+    chosen_response = example[responses_column][best_response_idx]
+    chosen_model = example["generation_model"][best_response_idx]
+
+    # Remove best response
+    example[rating_column].pop(best_response_idx)
+    example[responses_column].pop(best_response_idx)
+    example["generation_model"].pop(best_response_idx)
+    return prompt, best_rating, chosen_response, chosen_model
+
+
 def _binarize_dataset(
     dataset: "CustomDataset",
     seed: Optional[int] = None,
     strategy: BinarizationStrategies = "random",
-    keep_ties: bool = True,
+    keep_ties: bool = False,
+    rating_column: str = "rating",
+    responses_column: str = "generations",
     **kwargs: Any,
 ) -> "CustomDataset":
     """Binarizes a distilabel dataset.
@@ -77,7 +112,7 @@ def _binarize_dataset(
         strategy (BinarizationStrategies, optional): Method to binarize the data. Defaults to "random".
         keep_ties (bool, optional):
             Whether to keep ties in case the binarization method generated the chosen
-            and rejected responses to have the same rating. Defaults to True.
+            and rejected responses to have the same rating. Defaults to False.
         kwargs: Extra parameters passed to `datasets.Dataset.map`.
 
@@ -86,24 +121,16 @@ def _binarize_dataset(
     Returns:
         CustomDataset: Dataset binarized.
     """
-    rating_column = "rating"
-    responses_column = "generations"
+    get_best_response = functools.partial(
+        _get_best_response,
+        rating_column=rating_column,
+        responses_column=responses_column,
+    )
 
     def binarize_random(example):
+        prompt, best_rating, chosen_response, chosen_model = get_best_response(example)
         random.seed(seed)
 
-        # First pick the highest rating
-        prompt = example["input"]
-        best_rating = max(example[rating_column])
-        best_response_idx = example[rating_column].index(best_rating)
-        chosen_response = example[responses_column][best_response_idx]
-        chosen_model = example["generation_model"][best_response_idx]
-
-        # Remove best response
-        example[rating_column].pop(best_response_idx)
-        example[responses_column].pop(best_response_idx)
-        example["generation_model"].pop(best_response_idx)
-
         # Then you pick the rejected response from the list of candidates with lower scores.
example_lower = defaultdict(list) for i, rating in enumerate(example[rating_column]): @@ -123,7 +150,7 @@ def binarize_random(example): random_model = example["generation_model"][random_response_idx] - binarized = { + return { "prompt": prompt, "chosen": chosen_response, "rejected": random_response, @@ -132,25 +159,16 @@ def binarize_random(example): "chosen_model": chosen_model, "rejected_model": random_model, } - return binarized def binarize_worst(example): - random.seed(seed) - - prompt = example["input"] - best_rating = max(example[rating_column]) - best_response_idx = example[rating_column].index(best_rating) - chosen_response = example[responses_column][best_response_idx] - - chosen_model = example["generation_model"][best_response_idx] + prompt, best_rating, chosen_response, chosen_model = get_best_response(example) worst_rating = min(example[rating_column]) worst_response_idx = example[rating_column].index(worst_rating) worst_response = example[responses_column][worst_response_idx] - worst_model = example["generation_model"][worst_response_idx] - binarized = { + return { "prompt": prompt, "chosen": chosen_response, "rejected": worst_response, @@ -159,7 +177,6 @@ def binarize_worst(example): "chosen_model": chosen_model, "rejected_model": worst_model, } - return binarized if strategy == "random": binarization_method = binarize_random @@ -190,7 +207,7 @@ def prepare_dataset( dataset: "CustomDataset", strategy: BinarizationStrategies = "random", seed: Optional[int] = None, - keep_ties: bool = True, + keep_ties: bool = False, **kwargs: Any, ) -> "CustomDataset": """Helper function to prepare a distilabel dataset for training with the standard formats. @@ -219,7 +236,7 @@ def prepare_dataset( seed (int, optional): Seed for the random generator, in case of `random` strategy. Defaults to None. keep_ties (bool, optional): Whether to keep ties in case the binarization method generated the chosen - and rejected responses to have the same rating. Defaults to True. + and rejected responses to have the same rating. Defaults to False. kwargs: Extra parameters passed to `datasets.Dataset.map`. Returns: @@ -262,7 +279,13 @@ def prepare_dataset( raise ValueError("The dataset must contain at least 2 generations per example.") ds = _binarize_dataset( - dataset, strategy=strategy, seed=seed, keep_ties=keep_ties, **kwargs + dataset, + strategy=strategy, + seed=seed, + keep_ties=keep_ties, + rating_column="rating", + responses_column="generations", + **kwargs, ) # Imported here to avoid circular imports From e4602d4e2c0a68737915cfe6a743ead343091d4e Mon Sep 17 00:00:00 2001 From: plaguss Date: Wed, 17 Jan 2024 12:00:29 +0100 Subject: [PATCH 7/9] Add information on the binarization process for dpo --- docs/technical-reference/pipeline.md | 96 +++++++++++++++++++++------- src/distilabel/utils/dataset.py | 5 ++ 2 files changed, 79 insertions(+), 22 deletions(-) diff --git a/docs/technical-reference/pipeline.md b/docs/technical-reference/pipeline.md index 881be6e443..fd04cda765 100644 --- a/docs/technical-reference/pipeline.md +++ b/docs/technical-reference/pipeline.md @@ -141,28 +141,6 @@ The dataset can be regenerated from the checkpoint by simply calling the `Custom And with the dataset regenerated we can easily call `push_to_argilla` on it to review it. 
-### Prepare datasets for fine-tuning
-
-The preference datasets generated by distilabel out of the box contain all the raw information generated by the [`Pipeline`][distilabel.pipeline.Pipeline], but some processing is necessary in order to prepare the dataset for alignment fine-tuning, like for [DPO](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format).
-
-`distilabel` offers helper functions to prepare the [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current definition works for datasets labelled using `PreferenceTask`, and prepares them by *binarizing* the data. Take a look at [argilla/ultrafeedback-binarized-preferences](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences) to get an idea of how [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) can be binarized to prepare it for *DPO*.
-
-By default the *ties* (rows for which the rating of the chosen and rejected responses are the same) are kept in the dataset, but those should be removed for fine-tuning. Take a look at [utils.dataset.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.
-
-!!! Binarization
-
-    === "random"
-
-        ```python
-        --8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py"
-        ```
-
-    === "worst"
-
-        ```python
-        --8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py"
-        ```
-
 ## pipeline
 
 Considering recurring patterns in dataset creation, we can facilitate the process by utilizing the [`Pipeline`][distilabel.pipeline.Pipeline]. This is made simpler through the [`pipeline`][distilabel.pipeline.pipeline] function, which provides the necessary parameters for creating a `Pipeline`.
@@ -194,3 +172,77 @@ The [CustomDataset][distilabel.dataset.CustomDataset] generated entirely by AI m
 ```python
 --8<-- "docs/snippets/technical-reference/pipeline/argilla.py"
 ```
+
+## Prepare datasets for fine-tuning
+
+The preference datasets generated by distilabel out of the box contain all the raw information generated by the [`Pipeline`][distilabel.pipeline.Pipeline], but some processing is necessary in order to prepare the dataset for alignment or instruction fine-tuning, like for [DPO](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format) (for now we only cover the case of *DPO*).
+
+`distilabel` offers helper functions to prepare the [CustomDataset][distilabel.dataset.CustomDataset] for *DPO*. The current definition works for datasets labelled using `PreferenceTask`, and prepares them by *binarizing* the data. Go to the following section for an introduction to *dataset binarization*.
+
+By default the *ties* (rows for which the rating of the chosen and rejected responses are the same) are removed from the dataset, as that's expected for fine-tuning, but they can be kept in case they need to be analysed. Take a look at [utils.dataset.prepare_dataset][distilabel.utils.dataset.prepare_dataset] for more information.
+
+!!! Binarization
+
+    === "random"
+
+        ```python
+        --8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_random.py"
+        ```
+
+    === "worst"
+
+        ```python
+        --8<-- "docs/snippets/technical-reference/pipeline/prepare_dataset_binarize_worst.py"
+        ```
+
+### What's binarization?
+
+In the context of preference datasets (datasets for LLM instruction-tuning) one can come up with datasets formatted following the [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) format (the same format one obtains from a `Pipeline` that labels a dataset with a [`PreferenceTask`][distilabel.tasks.preference.base.PreferenceTask]), where for a given instruction we can have multiple completions according to one or more models, rated either by humans or other LLMs.
+
+From a labelling `Pipeline`, distilabel would give us a dataset with the following format:
+
+| input | generations | rating |
+|:--------------------------------------------------------|:-----------------------------------|---------:|
+| Generate an approximately fifteen-word sentence that... | [Midsummer House is a moderately..., Sure! Here's a sentence that...] | [9.0, 7.0] |
+
+Where each column represents the following:
+
+- **input**: Input for the LLM to generate text.
+
+- **generations**: List of generations from the LLM (possibly from an [LLMPool][distilabel.llm.base.LLMPool] with different models).
+
+- **rating**: A list of the ratings for each of the generations obtained by an LLM using one of the `PreferenceTasks`, like [JudgeLMTask][distilabel.tasks.preference.judgelm.JudgeLMTask] or [UltraFeedbackTask][distilabel.tasks.preference.ultrafeedback.UltraFeedbackTask].
+
+This dataset format contains all the raw information, but in order to use it with the common fine-tuning frameworks, the expected format is usually a prompt together with a chosen and a rejected response, used to align the model with those preferences.
+
+We would want the following dataset format for fine-tuning:
+
+| prompt | chosen | rejected |
+|:--------------------------------------------------------|:-----------------------------------|---------:|
+| Generate an approximately fifteen-word sentence that... | Midsummer House is a moderately... | Sure! Here's a sentence that... |
+
+Take a look at this [explanation](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences#dataset-processing) for the binarization of *UltraFeedback* done to train [Notus-7B-v1](https://huggingface.co/argilla/notus-7b-v1).
+
+What each column represents:
+
+- **prompt**: Instruction given to the model.
+
+- **chosen**: Response chosen.
+
+- **rejected**: Response rejected.
+
+This dataset processing is called binarization. In the context of `distilabel`, this transformation (dataset preparation) is done by [`utils.dataset.prepare_dataset`][distilabel.utils.dataset.prepare_dataset], and given that the generated datasets contain additional information, one can also see the following additional columns:
+
+| prompt | chosen | rejected | rating_chosen | rating_rejected | chosen_model | rejected_model |
+|:--------------------------------------------------------|:-----------------------------------|:--------------------------------|----------------:|------------------:|:---------------|:-----------------|
+| Generate an approximately fifteen-word sentence that... | Midsummer House is a moderately... | Sure! Here's a sentence that... | 9 | 7 | | |
+
+- **rating_chosen**: Rating of the chosen instruction.
+
+- **rating_rejected**: Rating of the rejected instruction.
+
+- **chosen_model**: (*Optional*, only returned if the dataset contains it; otherwise it's a null string, as here) The model used to generate the chosen instruction.
+
+- **rejected_model**: (*Optional*, only returned if the dataset contains it; otherwise it's a null string, as here) The model used to generate the rejected instruction.
+
+Need more information? Take a look at [argilla/ultrafeedback-binarized-preferences](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences) to get an idea of how [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) can be binarized to prepare it for *DPO*.
diff --git a/src/distilabel/utils/dataset.py b/src/distilabel/utils/dataset.py
index c1cac3f4ee..255eb89c0a 100644
--- a/src/distilabel/utils/dataset.py
+++ b/src/distilabel/utils/dataset.py
@@ -228,6 +228,11 @@ def prepare_dataset(
     Expected format for a dataset to be trained with DPO as defined in trl's
     [dpo trainer](https://huggingface.co/docs/trl/main/en/dpo_trainer#expected-dataset-format).
 
+    Note:
+        Take a look at the
+        [Prepare datasets for fine-tuning](https://distilabel.argilla.io/latest/technical-reference/pipeline/#prepare-datasets-for-fine-tuning)
+        section in the Concept guides for more information on the binarization process.
+
     Args:
         dataset (CustomDataset):
             CustomDataset with a PreferenceTask to prepare for Direct Preference Optimization.
From d0a029de3eb92ac2d348cf1f6ea31f2ea69a00a6 Mon Sep 17 00:00:00 2001
From: plaguss
Date: Wed, 17 Jan 2024 14:03:43 +0100
Subject: [PATCH 8/9] Update to openai chat messages format

---
 docs/technical-reference/pipeline.md | 10 +++--
 src/distilabel/utils/dataset.py      | 20 ++++++++--
 tests/test_dataset.py                | 60 ++++++++++++++++++++++++++--
 3 files changed, 78 insertions(+), 12 deletions(-)

diff --git a/docs/technical-reference/pipeline.md b/docs/technical-reference/pipeline.md
index fd04cda765..ed0acf6728 100644
--- a/docs/technical-reference/pipeline.md
+++ b/docs/technical-reference/pipeline.md
@@ -219,7 +219,7 @@ We would want the following dataset format for fine-tuning:
 
 | prompt | chosen | rejected |
 |:--------------------------------------------------------|:-----------------------------------|---------:|
-| Generate an approximately fifteen-word sentence that... | Midsummer House is a moderately... | Sure! Here's a sentence that... |
+| Generate an approximately fifteen-word sentence that... | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': 'Midsummer House is a moderately...', 'role': 'assistant'}] | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': ' Sure! Here\'s a sentence that...', 'role': 'assistant'}] |
 
 Take a look at this [explanation](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences#dataset-processing) for the binarization of *UltraFeedback* done to train [Notus-7B-v1](https://huggingface.co/argilla/notus-7b-v1).
 
@@ -227,15 +227,17 @@ What each column represents:
 
 - **prompt**: Instruction given to the model.
 
-- **chosen**: Response chosen.
+- **chosen**: Response chosen following the OpenAI format.
 
-- **rejected**: Response rejected.
+- **rejected**: Response rejected following the OpenAI format.
+
+We refer to the [OpenAI's chat format](https://platform.openai.com/docs/guides/text-generation) for more information on the chosen/rejected format.
 
 This dataset processing is called binarization. 
In the context of `distilabel`, this transformation (dataset preparation) is done by [`utils.dataset.prepare_dataset`][distilabel.utils.dataset.prepare_dataset], and given that the generated datasets contain additional information, one can also see the following additional columns:
 
 | prompt | chosen | rejected | rating_chosen | rating_rejected | chosen_model | rejected_model |
 |:--------------------------------------------------------|:-----------------------------------|:--------------------------------|----------------:|------------------:|:---------------|:-----------------|
-| Generate an approximately fifteen-word sentence that... | Midsummer House is a moderately... | Sure! Here's a sentence that... | 9 | 7 | | |
+| Generate an approximately fifteen-word sentence that... | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': 'Midsummer House is a moderately...', 'role': 'assistant'}] | [{'content': 'Generate an approximately...', 'role': 'user'}, {'content': ' Sure! Here\'s a sentence that...', 'role': 'assistant'}] | 9 | 7 | | |
 
 - **rating_chosen**: Rating of the chosen instruction.
diff --git a/src/distilabel/utils/dataset.py b/src/distilabel/utils/dataset.py
index 255eb89c0a..0c6cf70cee 100644
--- a/src/distilabel/utils/dataset.py
+++ b/src/distilabel/utils/dataset.py
@@ -95,6 +95,18 @@ def _get_best_response(
     return prompt, best_rating, chosen_response, chosen_model
 
 
+def _format_message(prompt: str, response: str) -> list[dict[str, str]]:
+    """Helper function to format the messages (chosen/rejected) in OpenAI format.
+
+    Returns:
+        message: List of dictionaries with the OpenAI format.
+    """
+    return [
+        {"role": "user", "content": prompt},
+        {"role": "assistant", "content": response},
+    ]
+
+
 def _binarize_dataset(
     dataset: "CustomDataset",
     seed: Optional[int] = None,
@@ -152,8 +164,8 @@ def binarize_random(example):
 
         return {
             "prompt": prompt,
-            "chosen": chosen_response,
-            "rejected": random_response,
+            "chosen": _format_message(prompt, chosen_response),
+            "rejected": _format_message(prompt, random_response),
             "rating_chosen": int(best_rating),
             "rating_rejected": int(random_rating),
             "chosen_model": chosen_model,
@@ -170,8 +182,8 @@ def binarize_worst(example):
 
         return {
             "prompt": prompt,
-            "chosen": chosen_response,
-            "rejected": worst_response,
+            "chosen": _format_message(prompt, chosen_response),
+            "rejected": _format_message(prompt, worst_response),
             "rating_chosen": int(best_rating),
             "rating_rejected": int(worst_rating),
             "chosen_model": chosen_model,
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 978f32a2f9..682ad4bfca 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -178,8 +178,34 @@ def test_do_checkpoint(
     [
         (
             "random",
-            ["generation 1 2", "generation 2 4", "generation 3 3"],
-            ["generation 1 1", "generation 2 3", "generation 3 4"],
+            [
+                [
+                    {"content": "input 1", "role": "user"},
+                    {"content": "generation 1 2", "role": "assistant"},
+                ],
+                [
+                    {"content": "input 2", "role": "user"},
+                    {"content": "generation 2 4", "role": "assistant"},
+                ],
+                [
+                    {"content": "input 3", "role": "user"},
+                    {"content": "generation 3 3", "role": "assistant"},
+                ],
+            ],
+            [
+                [
+                    {"content": "input 1", "role": "user"},
+                    {"content": "generation 1 1", "role": "assistant"},
+                ],
+                [
+                    {"content": "input 2", "role": "user"},
+                    {"content": "generation 2 3", "role": "assistant"},
+                ],
+                [
+                    {"content": "input 3", "role": "user"},
+                    {"content": "generation 3 4", "role": "assistant"},
+                ],
+            ],
             [
                 "WizardLM/WizardCoder-15B-V1.0",
                 "gpt-3.5-turbo",
@@ -190,8 
+216,34 @@ def test_do_checkpoint( ), ( "worst", - ["generation 1 2", "generation 2 4", "generation 3 3"], - ["generation 1 1", "generation 2 3", "generation 3 2"], + [ + [ + {"content": "input 1", "role": "user"}, + {"content": "generation 1 2", "role": "assistant"}, + ], + [ + {"content": "input 2", "role": "user"}, + {"content": "generation 2 4", "role": "assistant"}, + ], + [ + {"content": "input 3", "role": "user"}, + {"content": "generation 3 3", "role": "assistant"}, + ], + ], + [ + [ + {"content": "input 1", "role": "user"}, + {"content": "generation 1 1", "role": "assistant"}, + ], + [ + {"content": "input 2", "role": "user"}, + {"content": "generation 2 3", "role": "assistant"}, + ], + [ + {"content": "input 3", "role": "user"}, + {"content": "generation 3 2", "role": "assistant"}, + ], + ], [ "WizardLM/WizardCoder-15B-V1.0", "gpt-3.5-turbo", From 64f28f1c4069dd884775eefefa20b9890964cf30 Mon Sep 17 00:00:00 2001 From: plaguss Date: Fri, 19 Jan 2024 11:01:10 +0100 Subject: [PATCH 9/9] Correct types for python 3.8 --- src/distilabel/utils/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distilabel/utils/dataset.py b/src/distilabel/utils/dataset.py index 0c6cf70cee..4417a538b3 100644 --- a/src/distilabel/utils/dataset.py +++ b/src/distilabel/utils/dataset.py @@ -16,7 +16,7 @@ import random from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, get_args +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, get_args import dill as pickle @@ -95,7 +95,7 @@ def _get_best_response( return prompt, best_rating, chosen_response, chosen_model -def _format_message(prompt: str, response: str) -> list[dict[str, str]]: +def _format_message(prompt: str, response: str) -> List[Dict[str, str]]: """Helper function to format the messages (chosen/rejected) in OpenAI format. Returns:
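
Taken together, the selection rule that PATCH 6 settles on is small enough to illustrate without any distilabel machinery. The following is a plain-Python sketch of that rule, not code from the diff (the name `binarize_row` is illustrative): `chosen` is the highest-rated response; `rejected` is either a random response among those rated strictly lower (`random`) or the lowest-rated one (`worst`); when no lower-rated response exists, the row is a tie and only survives with `keep_ties=True`.

```python
import random
from typing import Dict, List, Optional


def binarize_row(
    generations: List[str],
    ratings: List[float],
    strategy: str = "random",
    seed: Optional[int] = None,
) -> Dict[str, object]:
    """Illustrative restatement of the rule in `_binarize_dataset` (not the real API)."""
    random.seed(seed)
    best_idx = ratings.index(max(ratings))
    chosen, rating_chosen = generations[best_idx], ratings[best_idx]

    # Candidates for `rejected` are the responses rated strictly below the best one.
    pairs = [(g, r) for i, (g, r) in enumerate(zip(generations, ratings)) if i != best_idx]
    candidates = [(g, r) for g, r in pairs if r < rating_chosen]
    if not candidates:
        # Every remaining response ties with the best rating: fall back to the
        # remaining responses, producing a tie (rating_chosen == rating_rejected).
        candidates = pairs

    if strategy == "random":
        rejected, rating_rejected = random.choice(candidates)
    elif strategy == "worst":
        rejected, rating_rejected = min(candidates, key=lambda pair: pair[1])
    else:
        raise ValueError(f"Strategy `{strategy}` is not implemented")

    return {
        "chosen": chosen,
        "rejected": rejected,
        "rating_chosen": int(rating_chosen),
        "rating_rejected": int(rating_rejected),
    }


# For the first fixture row in the tests (ratings [2.0, 5.0, 4.0, 5.0]), "worst"
# picks the first 5.0-rated response as chosen and the 2.0-rated one as rejected,
# matching the expected values in test_prepare_dataset.
print(binarize_row(["g1", "g2", "g3", "g4"], [2.0, 5.0, 4.0, 5.0], strategy="worst"))
```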
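For end-to-end usage, the sketch below assembles the docstring example from PATCH 4 with the flags introduced across the series. The dataset name, task, and `HF_API_TOKEN` environment variable are taken from the source's own examples and should be treated as placeholders; assigning `.task` on the loaded dataset mirrors those examples as well.

```python
import os

from datasets import load_dataset
from distilabel.tasks import UltraFeedbackTask
from distilabel.utils.dataset import prepare_dataset

# Load a distilabel-labelled preference dataset and attach the task that
# produced it, exactly as the docstring example in PATCH 4 does.
dataset = load_dataset(
    "argilla/DistiCoder-dpo", token=os.getenv("HF_API_TOKEN"), split="train"
)
dataset.task = UltraFeedbackTask.for_instruction_following()

# Binarize: `chosen` is the top-rated response; with "worst" the lowest-rated
# one becomes `rejected`. keep_ties=False (the final default) drops rows where
# both ratings are equal.
dataset_binarized = prepare_dataset(dataset, strategy="worst", keep_ties=False)

# Rows now match trl's expected DPO format, with chosen/rejected stored as
# OpenAI-style chat messages (PATCH 8).
print(dataset_binarized.column_names)
# ['prompt', 'chosen', 'rejected', 'rating_chosen', 'rating_rejected',
#  'chosen_model', 'rejected_model']
```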