reordering init
gaetanlop committed Oct 13, 2024
1 parent c3eb08e commit 898f621
Showing 5 changed files with 38 additions and 49 deletions.
32 changes: 16 additions & 16 deletions docs/source/dataset_formats.mdx
@@ -220,22 +220,22 @@ steps_preference_example = {"prompt": "Two apples and one orange cost 1.5 euros.

Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.

| Trainer | Expected dataset type |
| ----------------------- | ------------------------------------------------------- |
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`StepwiseRewardTrainer`] | [Name to find] |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
| Trainer | Expected dataset type |
| ------------------------- | ------------------------------------------------------- |
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`StepwiseRewardTrainer`] | [Name to find] |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |

<Tip>

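For quick reference alongside the table above, here is a minimal sketch of single records for the main TRL dataset types; the values are illustrative, and the authoritative schemas live earlier in dataset_formats.mdx.

```python
# Illustrative single records for the dataset types named in the table above.
# Values are made up; see dataset_formats.mdx for the authoritative schemas.
language_modeling_example = {"text": "The sky is blue."}

prompt_only_example = {"prompt": "The sky is"}

prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}

preference_example_explicit_prompt = {
    "prompt": "The sky is",
    "chosen": " blue.",
    "rejected": " green.",
}

preference_example_implicit_prompt = {
    "chosen": "The sky is blue.",
    "rejected": "The sky is green.",
}

unpaired_preference_example = {
    "prompt": "The sky is",
    "completion": " blue.",
    "label": True,
}
```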
10 changes: 5 additions & 5 deletions examples/scripts/stepwise_reward_modeling.py
@@ -64,7 +64,7 @@

if __name__ == "__main__":
parser = HfArgumentParser((RewardScriptArguments, StepwiseRewardConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()
script_args, training_args, model_config = parser.parse_args_into_dataclasses()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

################
@@ -104,7 +104,7 @@
##############
# Load dataset
##############
dataset = load_dataset(args.dataset_name, data_dir="conversational_stepwise_preference")
dataset = load_dataset(script_args.dataset_name)

##########
# Training
@@ -113,8 +113,8 @@
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=dataset[args.dataset_train_split],
eval_dataset=dataset[args.dataset_test_split],
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split],
peft_config=get_peft_config(model_config),
)
trainer.train()
@@ -130,4 +130,4 @@
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=args.dataset_name)
trainer.push_to_hub(dataset_name=script_args.dataset_name)
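As context for the `args` → `script_args` rename in this script, here is a minimal sketch of how `HfArgumentParser` splits command-line flags across dataclasses; the `ScriptArguments` fields below are illustrative stand-ins, not TRL's actual `RewardScriptArguments`.

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class ScriptArguments:
    # Illustrative stand-in for RewardScriptArguments; field names are assumptions.
    dataset_name: str = field(default="trl-internal-testing/zen")
    dataset_train_split: str = field(default="train")
    dataset_test_split: str = field(default="test")


parser = HfArgumentParser((ScriptArguments, TrainingArguments))
# An explicit argv keeps the sketch self-contained; on the command line these
# would be passed as regular --flags.
script_args, training_args = parser.parse_args_into_dataclasses(
    ["--output_dir", "/tmp/stepwise-rm", "--dataset_name", "my-org/my-stepwise-dataset"]
)
print(script_args.dataset_name)  # my-org/my-stepwise-dataset
print(training_args.output_dir)  # /tmp/stepwise-rm
```

Naming the first result `script_args` keeps script-level options visually distinct from `training_args` and `model_config` in the rest of the file.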
37 changes: 13 additions & 24 deletions tests/test_stepwise_reward_trainer.py
@@ -16,13 +16,15 @@

import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForTokenClassification, AutoTokenizer, EvalPrediction
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import StepwiseRewardConfig, StepwiseRewardTrainer, maybe_apply_chat_template
from trl.trainer import compute_accuracy
from trl.trainer.stepwise_reward_trainer import _tokenize
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


if is_peft_available():
@@ -31,10 +33,14 @@

class StepwiseRewardTrainerTester(unittest.TestCase):
def setUp(self):
self.model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
self.model = AutoModelForTokenClassification.from_pretrained(self.model_id)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
self.model = AutoModelForTokenClassification.from_pretrained(self.model_id, num_labels=2)
self.tokenizer.pad_token = self.tokenizer.eos_token

# Ensure the tokenizer has a chat template
if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None:
self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

def test_token_level_accuracy(self):
dummy_eval_predictions = EvalPrediction(
@@ -44,11 +50,10 @@ def test_token_level_accuracy(self):
accuracy = compute_accuracy(dummy_eval_predictions)
self.assertEqual(accuracy["accuracy"], 0.5)

def test_preprocessing_conversational(self):
@parameterized.expand(["conversational_stepwise_preference", "standard_stepwise_preference"])
def test_preprocessing(self, dataset_type):
with tempfile.TemporaryDirectory() as tmp_dir:
dummy_dataset = load_dataset(
"trl-internal-testing/zen", "conversational_stepwise_preference", split="train"
)
dummy_dataset = load_dataset("trl-internal-testing/zen", dataset_type, split="train")
training_args = StepwiseRewardConfig(output_dir=tmp_dir, report_to="none", max_length=512)
trainer = StepwiseRewardTrainer(
model=self.model,
@@ -60,23 +65,7 @@ def test_preprocessing_conversational(self):
dummy_dataset = dummy_dataset.map(
_tokenize,
batched=True,
fn_kwargs={"tokenizer": self.tokenizer, "max_length": 512, "post_step_separator": "\n"},
)
self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])

def test_preprocessing_standard(self):
# No chat template, so we load a fresh tokenizer
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_preference", split="train")
training_args = StepwiseRewardConfig(output_dir=tmp_dir, report_to="none", max_length=512)
trainer = StepwiseRewardTrainer(
model=self.model, args=training_args, tokenizer=tokenizer, train_dataset=dummy_dataset
)
dummy_dataset = self.dummy_dataset.map(
_tokenize,
batched=True,
fn_kwargs={"tokenizer": tokenizer, "max_length": 512, "post_step_separator": "\n"},
fn_kwargs={"tokenizer": self.tokenizer, "max_length": 512, "step_separator": "\n"},
)
self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])

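For readers unfamiliar with `parameterized.expand`, which replaces the two separate preprocessing tests above with a single parameterized test, here is a minimal sketch of the pattern; the class name and assertion are illustrative.

```python
import unittest

from parameterized import parameterized


class DatasetConfigTest(unittest.TestCase):
    # parameterized.expand generates one test method per entry, so both
    # dataset configurations run through the same test body.
    @parameterized.expand(["conversational_stepwise_preference", "standard_stepwise_preference"])
    def test_config_name(self, dataset_type):
        self.assertTrue(dataset_type.endswith("stepwise_preference"))


if __name__ == "__main__":
    unittest.main()
```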
4 changes: 2 additions & 2 deletions trl/__init__.py
@@ -86,12 +86,12 @@
"RandomRankJudge",
"RewardConfig",
"RewardTrainer",
"StepwiseRewardConfig",
"StepwiseRewardTrainer",
"RLOOConfig",
"RLOOTrainer",
"SFTConfig",
"SFTTrainer",
"StepwiseRewardConfig",
"StepwiseRewardTrainer",
"WinRateCallback",
"XPOConfig",
"XPOTrainer",
4 changes: 2 additions & 2 deletions trl/trainer/__init__.py
@@ -58,12 +58,12 @@
"ppov2_trainer": ["PPOv2Trainer"],
"reward_config": ["RewardConfig"],
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
"stepwise_reward_config": ["StepwiseRewardConfig"],
"stepwise_reward_trainer": ["StepwiseRewardTrainer"],
"rloo_config": ["RLOOConfig"],
"rloo_trainer": ["RLOOTrainer"],
"sft_config": ["SFTConfig"],
"sft_trainer": ["SFTTrainer"],
"stepwise_reward_config": ["StepwiseRewardConfig"],
"stepwise_reward_trainer": ["StepwiseRewardTrainer"],
"utils": [
"AdaptiveKLController",
"ConstantLengthDataset",
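The two `__init__.py` hunks move the `StepwiseReward*` entries to where a case-insensitive alphabetical sort places them, which matches the commit title. A small stand-alone check over the fragment visible above:

```python
# Verify that the visible fragment of the export list is in case-insensitive
# alphabetical order after moving the StepwiseReward* entries.
exports = [
    "RandomRankJudge",
    "RewardConfig",
    "RewardTrainer",
    "RLOOConfig",
    "RLOOTrainer",
    "SFTConfig",
    "SFTTrainer",
    "StepwiseRewardConfig",
    "StepwiseRewardTrainer",
    "WinRateCallback",
    "XPOConfig",
    "XPOTrainer",
]
assert exports == sorted(exports, key=str.lower), "export list is out of order"
print("ordering OK")
```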
