Merge branch 'main' into cgpo_trainer
gaetanlop authored Oct 21, 2024
2 parents eab05d7 + 84dab85 commit 7231abb
Showing 15 changed files with 628 additions and 846 deletions.
2 changes: 1 addition & 1 deletion docs/source/dataset_formats.mdx
@@ -589,7 +589,7 @@ dataset = dataset.remove_columns(["chosen", "rejected"])

### From explicit to implicit prompt preference dataset

To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt.
To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt.

```python
from datasets import Dataset
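# Below is a minimal sketch of the conversion described above, on hypothetical
# data; it assumes only the standard `datasets` map/remove_columns API.
dataset = Dataset.from_dict(
    {"prompt": ["The sky is"], "chosen": [" blue."], "rejected": [" green."]}
)

def concat_prompt_to_completions(example):
    # Prepend the prompt to both completions so the prompt becomes implicit.
    return {
        "chosen": example["prompt"] + example["chosen"],
        "rejected": example["prompt"] + example["rejected"],
    }

dataset = dataset.map(concat_prompt_to_completions, remove_columns=["prompt"])
```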
4 changes: 4 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -276,3 +276,7 @@ dpo_trainer = DPOTrainer(
## DPOConfig

[[autodoc]] DPOConfig

## PreferenceCollator

[[autodoc]] trainer.dpo_trainer.PreferenceCollator
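
A usage sketch for the new collator. Hedged: this assumes `PreferenceCollator` is importable from `trl.trainer.dpo_trainer` (as the autodoc path above suggests), takes a `pad_token_id`, and pads the `prompt_input_ids`/`chosen_input_ids`/`rejected_input_ids` fields produced by `tokenize_row` elsewhere in this commit.

```python
from trl.trainer.dpo_trainer import PreferenceCollator

collator = PreferenceCollator(pad_token_id=0)
examples = [
    {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]},
    {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]},
]
batch = collator(examples)
# Expected (under the assumptions above): padded tensors plus matching
# attention masks for the prompt, chosen, and rejected fields.
```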
12 changes: 0 additions & 12 deletions examples/scripts/dpo_vlm.py
@@ -27,7 +27,6 @@
"""

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor

@@ -106,17 +105,6 @@
################
dataset = load_dataset(script_args.dataset_name)

def process(row):
row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)
return row

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
dataset = dataset.map(process, num_proc=training_args.dataset_num_proc)

################
# Training
################
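The block removed above relied on `accelerate`'s `local_main_process_first` so that one process per node runs the `map` first and the rest reuse the on-disk cache. The general pattern, as a sketch with a hypothetical dataset and processing function:

```python
from accelerate import PartialState
from datasets import load_dataset

dataset = load_dataset("imdb", split="train")  # hypothetical choice of dataset

with PartialState().local_main_process_first():
    # The local main process maps (and caches) first; the remaining ranks
    # enter afterwards and load the cached result instead of recomputing.
    dataset = dataset.map(lambda row: {"text": row["text"].lower()})
```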
2 changes: 1 addition & 1 deletion tests/test_cli.py
@@ -32,7 +32,7 @@ def test_sft_cli():
def test_dpo_cli():
try:
subprocess.run(
"trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-lib/ultrafeedback_binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
"trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
shell=True,
check=True,
)
260 changes: 109 additions & 151 deletions tests/test_dpo_trainer.py
@@ -14,6 +14,7 @@

import tempfile
import unittest
from unittest.mock import MagicMock

import numpy as np
import pytest
@@ -27,172 +28,129 @@
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerBase,
)
from transformers.testing_utils import require_bitsandbytes, require_peft

from trl import DPOConfig, DPOTrainer, FDivergenceType
from trl.trainer.dpo_trainer import _build_tokenized_answer, _truncate_tokens

from .testing_utils import require_no_wandb


class TestBuildTokenizedAnswer(unittest.TestCase):
class TestTokenizeRow(unittest.TestCase):
def setUp(self):
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
self.tokenizer.pad_token = self.tokenizer.eos_token

def test_basic_functionality(self):
prompt = "Hello, how are you?"
answer = "I'm doing well, thank you!"

result = _build_tokenized_answer(prompt, answer, tokenizer=self.tokenizer)

self.assertIn("prompt_input_ids", result)
self.assertIn("prompt_attention_mask", result)
self.assertIn("input_ids", result)
self.assertIn("attention_mask", result)

self.assertEqual(len(result["prompt_input_ids"]), len(result["prompt_attention_mask"]))
self.assertEqual(len(result["input_ids"]), len(result["attention_mask"]))

decoded_prompt = self.tokenizer.decode(result["prompt_input_ids"])
self.assertTrue(prompt in decoded_prompt)

decoded_answer = self.tokenizer.decode(result["input_ids"])
self.assertTrue(answer in decoded_answer)

def test_with_processor(self):
def mock_processor(text, images=None, add_special_tokens=True):
return {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])}

prompt = "Describe this image:"
answer = "A beautiful sunset over the ocean."

result = _build_tokenized_answer(prompt, answer, processor=mock_processor)

self.assertIn("prompt_input_ids", result)
self.assertIn("prompt_attention_mask", result)
self.assertIn("input_ids", result)
self.assertIn("attention_mask", result)

self.assertEqual(result["prompt_input_ids"], [1, 2, 3])
self.assertEqual(result["prompt_attention_mask"], [1, 1, 1])

def test_token_merging(self):
prompt = "The quick brown"
answer = " fox jumps over the lazy dog."

result = _build_tokenized_answer(prompt, answer, tokenizer=self.tokenizer)

full_text = prompt + answer
full_tokenized = self.tokenizer(full_text, add_special_tokens=False)

self.assertEqual(result["prompt_input_ids"] + result["input_ids"], full_tokenized["input_ids"])
# Set up the mock tokenizer with specific behaviors
self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
self.tokenizer.bos_token_id = 0
self.tokenizer.eos_token_id = 2

# Define mock return values for the tokenizer's 'input_ids' for the different text inputs
self.tokenizer.return_value = {
"input_ids": {"The sky is": [464, 6766, 318], " blue": [4171], " green": [4077]}
}

def test_vision_model(self):
def mock_vision_processor(text, images=None, add_special_tokens=True):
return {
"input_ids": torch.tensor([[1, 2, 3]]),
"attention_mask": torch.tensor([[1, 1, 1]]),
"pixel_values": torch.rand(1, 3, 224, 224),
"pixel_attention_mask": torch.ones(1, 224, 224),
# Define tokenizer behavior when called
def mock_tokenizer_call(text, add_special_tokens):
token_map = {
"The sky is": {"input_ids": [464, 6766, 318]},
" blue": {"input_ids": [4171]},
" green": {"input_ids": [4077]},
}
return token_map[text]

prompt = "Describe this image:"
answer = "A cat sitting on a windowsill."
self.tokenizer.side_effect = mock_tokenizer_call

result = _build_tokenized_answer(prompt, answer, processor=mock_vision_processor)
def test_tokenize_row_no_truncation_no_special_tokens(self):
# Define the input features
features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}

self.assertIn("prompt_pixel_values", result)
self.assertIn("prompt_pixel_attention_mask", result)
self.assertTrue(torch.is_tensor(result["prompt_pixel_values"]))
self.assertTrue(torch.is_tensor(result["prompt_pixel_attention_mask"]))
# Call the method with no truncation and no special tokens
result = DPOTrainer.tokenize_row(
features=features,
processing_class=self.tokenizer,
max_prompt_length=None,
max_completion_length=None,
add_special_tokens=False,
)

# Assert the correct output without truncation or special tokens
self.assertEqual(
result,
{
"prompt_input_ids": [464, 6766, 318],
"chosen_input_ids": [4171, 2], # eos_token added
"rejected_input_ids": [4077, 2], # eos_token added
},
)

class TestTruncateTokens(unittest.TestCase):
def setUp(self):
with tempfile.TemporaryDirectory() as tmp_dir:
self.training_args = DPOConfig(
max_length=20, max_prompt_length=10, truncation_mode="keep_start", output_dir=tmp_dir
)
def test_tokenize_row_with_truncation(self):
# Define the input features
features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}

# Call the method with truncation
result = DPOTrainer.tokenize_row(
features=features,
processing_class=self.tokenizer,
max_prompt_length=2,
max_completion_length=1,
add_special_tokens=False,
)

def test_truncate_tokens(self):
chosen_tokens = [
{
"prompt_input_ids": list(range(15)),
"prompt_attention_mask": [1] * 15,
"input_ids": list(range(10)),
"attention_mask": [1] * 10,
}
]
rejected_tokens = [
{
"prompt_input_ids": list(range(15)),
"prompt_attention_mask": [1] * 15,
"input_ids": list(range(12)),
"attention_mask": [1] * 12,
}
]
prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}]

_truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.training_args)

# Check if prompt is truncated correctly
self.assertEqual(len(chosen_tokens[0]["prompt_input_ids"]), 10)
self.assertEqual(len(chosen_tokens[0]["prompt_attention_mask"]), 10)
self.assertEqual(len(rejected_tokens[0]["prompt_input_ids"]), 10)
self.assertEqual(len(rejected_tokens[0]["prompt_attention_mask"]), 10)
self.assertEqual(len(prompt_tokens[0]["prompt_input_ids"]), 10)
self.assertEqual(len(prompt_tokens[0]["prompt_attention_mask"]), 10)

# Check if responses are truncated correctly
self.assertEqual(len(chosen_tokens[0]["input_ids"]), 10)
self.assertEqual(len(chosen_tokens[0]["attention_mask"]), 10)
self.assertEqual(len(rejected_tokens[0]["input_ids"]), 10)
self.assertEqual(len(rejected_tokens[0]["attention_mask"]), 10)

def test_truncation_mode_keep_end(self):
self.training_args.truncation_mode = "keep_end"
chosen_tokens = [
{
"prompt_input_ids": list(range(15)),
"prompt_attention_mask": [1] * 15,
"input_ids": list(range(15, 25)),
"attention_mask": [1] * 10,
}
]
rejected_tokens = [
# Assert the correct output with truncation applied
self.assertEqual(
result,
{
"prompt_input_ids": list(range(15)),
"prompt_attention_mask": [1] * 15,
"input_ids": list(range(15, 28)),
"attention_mask": [1] * 13,
}
]
prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}]

_truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.training_args)
"prompt_input_ids": [6766, 318], # truncated to the last 2 tokens
"chosen_input_ids": [4171], # truncated to 1 token
"rejected_input_ids": [4077], # truncated to 1 token
},
)

# Check if prompt is truncated correctly from the end
self.assertEqual(prompt_tokens[0]["prompt_input_ids"], list(range(5, 15)))
self.assertEqual(prompt_tokens[0]["prompt_attention_mask"], [1] * 10)
def test_tokenize_row_with_special_tokens(self):
# Define the input features
features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}

# Call the method with special tokens
result = DPOTrainer.tokenize_row(
features=features,
processing_class=self.tokenizer,
max_prompt_length=None,
max_completion_length=None,
add_special_tokens=True,
)

# Check if chosen tokens are truncated correctly
self.assertEqual(chosen_tokens[0]["prompt_input_ids"], list(range(5, 15)))
self.assertEqual(chosen_tokens[0]["prompt_attention_mask"], [1] * 10)
self.assertEqual(chosen_tokens[0]["input_ids"], list(range(15, 25)))
self.assertEqual(chosen_tokens[0]["attention_mask"], [1] * 10)
# Assert the correct output with special tokens added
self.assertEqual(
result,
{
"prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added
"chosen_input_ids": [4171, 2], # eos_token added
"rejected_input_ids": [4077, 2], # eos_token added
},
)

# Check if rejected tokens are truncated correctly
self.assertEqual(rejected_tokens[0]["prompt_input_ids"], list(range(5, 15)))
self.assertEqual(rejected_tokens[0]["prompt_attention_mask"], [1] * 10)
self.assertEqual(rejected_tokens[0]["input_ids"], list(range(15, 25)))
self.assertEqual(rejected_tokens[0]["attention_mask"], [1] * 10)
def test_tokenize_row_with_truncation_and_special_tokens(self):
# Define the input features
features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}

# Call the method with both truncation and special tokens
result = DPOTrainer.tokenize_row(
features=features,
processing_class=self.tokenizer,
max_prompt_length=4,
max_completion_length=1,
add_special_tokens=True,
)

def test_invalid_truncation_mode(self):
self.training_args.truncation_mode = "invalid_mode"
with self.assertRaises(ValueError):
_truncate_tokens([], [], [], self.training_args)
# Assert the correct output with both truncation and special tokens
self.assertEqual(
result,
{
"prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token
"chosen_input_ids": [4171], # truncated to 1 token
"rejected_input_ids": [4077], # truncated to 1 token
},
)
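
For orientation, the same call with a real tokenizer instead of the mock; a sketch that assumes `tokenize_row` is callable as a static method, exactly as the tests above invoke it (`gpt2` uses its own special-token ids, e.g. `eos_token_id` 50256 rather than the mocked 2):

```python
from transformers import AutoTokenizer
from trl import DPOTrainer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
result = DPOTrainer.tokenize_row(
    features=features,
    processing_class=tokenizer,
    max_prompt_length=None,
    max_completion_length=None,
    add_special_tokens=False,
)
# result["chosen_input_ids"] should end with tokenizer.eos_token_id (50256).
```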


class DPOTrainerTester(unittest.TestCase):
@@ -461,9 +419,9 @@ def test_dpo_trainer_padding_token_is_none(self):

with self.assertRaisesRegex(
ValueError,
expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token."
r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)"
r" before calling the trainer.",
expected_regex=r"Can't find `pad_token_id` in the `processing_class`. "
r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) "
r"before instantiating the trainer.",
):
trainer = DPOTrainer(
model=self.model,
@@ -498,9 +456,9 @@ def test_dpo_trainer_w_dataset_num_proc(self):

with self.assertRaisesRegex(
ValueError,
expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token."
r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)"
r" before calling the trainer.",
expected_regex=r"Can't find `pad_token_id` in the `processing_class`. "
r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) "
r"before instantiating the trainer.",
):
trainer = DPOTrainer(
model=self.model,
@@ -1139,7 +1097,7 @@ def test_vdpo_trainer(self, model_id):
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=512,
max_prompt_length=128,
max_prompt_length=512,
remove_unused_columns=False,
report_to="none",
)
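The reworded error above also tells users how to recover. A minimal illustration, assuming a tokenizer such as `gpt2` that ships without a pad token:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
assert tokenizer.pad_token is None  # gpt2 defines no pad token by default
tokenizer.pad_token = tokenizer.eos_token  # the fix the error message suggests
```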
5 changes: 3 additions & 2 deletions trl/trainer/bco_trainer.py
@@ -1260,6 +1260,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
@@ -1290,7 +1291,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
return None
return SequentialSampler(self.train_dataset)

def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1407,7 +1408,7 @@ def evaluation_loop(
"prompt_attention_mask": itemgetter(*target_indicies)(random_batch["prompt_attention_mask"]),
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
}
policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

self.log(
{
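The `evaluation_loop` hunk above uses `operator.itemgetter` to pull several randomly chosen indices out of a batch in a single call; the pattern in isolation:

```python
from operator import itemgetter

batch = ["a", "b", "c", "d", "e"]
target_indices = [0, 2, 4]
# itemgetter(*indices) builds a callable that returns a tuple of the selected items.
subset = itemgetter(*target_indices)(batch)  # ("a", "c", "e")
```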
(Diff for the remaining 9 changed files not loaded in this view.)