Skip to content

Commit

Permalink
it runs
Browse files Browse the repository at this point in the history
  • Loading branch information
natolambert committed Sep 29, 2024
1 parent fc2865e commit 2c054fe
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 28 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,18 @@ For example, the following command does both:
```
rewardbench --model vwxyzjn/reward_modeling__EleutherAI_pythia-14m --batch_size 128 --tokenizer=EleutherAI/pythia-14m --push_results_to_hub --upload_model_metadata_to_hf --chat_template raw
```
Or, for an instruction dataset:
```
rewardbench --model vwxyzjn/reward_modeling__EleutherAI_pythia-14m --dataset HuggingFaceH4/no_robots --split test --batch_size 128 --tokenizer=EleutherAI/pythia-14m --push_results_to_hub --chat_template raw
```
(Note that chat templates only need to be specified for older models)

The key commands are:
* `--push_results_to_hub` which uploads a dataset of scores and correctness.
* ` --upload_model_metadata_to_hf` adds results directly to model.

For an example of a model with accuracy metadata, look [here](https://huggingface.co/vwxyzjn/rm_zephyr_new).
For an example of the outputs from a preference dataset, look [here]().
For an example of the outputs from a preference dataset, look [here](https://huggingface.co/datasets/natolambert/rewardbench_eval_2339270924_2339270924), and for instructions, look [here](https://huggingface.co/datasets/natolambert/rewardbench_eval_0329290924).

## Full Installation
To install from source, please install `torch` on your system, and then install the following requirements.
Expand Down
8 changes: 4 additions & 4 deletions rewardbench/rewardbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def main():
rewardbench(*parser.parse_args_into_dataclasses())


# Structure eeded to accomodate HuggingFace Args with CLI binding
# Secondary function structure needed to accommodate HuggingFace Args with CLI binding
def rewardbench(args: Args):
if args.wandb_run is not None:
wandb_run = wandb.Api().run(args.wandb_run)
Expand Down Expand Up @@ -471,7 +471,7 @@ def rewardbench(args: Args):

combined_data = {
"prompt": dataset["prompt"], # Assuming `prompts` is a list of prompts matching scores
"results": results,
"results": [item for sublist in results for item in sublist],
}

# Consolidate chosen and rejected scores along with prompts and texts
Expand All @@ -482,7 +482,7 @@ def rewardbench(args: Args):
combined_data["text_rejected"] = dataset["text_rejected"]
# or take instruction
else:
combined_data["messages"] = dataset["messages"]
combined_data["text"] = dataset["text"]

# Save combined scores and metadata to JSONL
scores_output_path = os.path.join(args.output_dir, f"{args.model}_outputs.jsonl")
Expand All @@ -491,7 +491,7 @@ def rewardbench(args: Args):
# Upload to HF
if args.push_results_to_hub:
hf_repo = push_results_to_hub(args, combined_data)
logger.info(f"Pushed results to Hugging Face Hub for {hf_repo}")
logger.info(f"Pushed results to Hugging Face Hub for https://huggingface.co/datasets/{hf_repo}")

############################
# the rest is just for preferences (accuracies)
Expand Down
64 changes: 41 additions & 23 deletions rewardbench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,25 +197,14 @@ def load_and_process_dataset(
datasets_to_combine = [dataset[split] for split in available_splits]
dataset = concatenate_datasets(datasets_to_combine)

# Handle column renaming
# Handle column renaming to track prompts
if "question" in dataset.column_names and "prompt" not in dataset.column_names:
dataset = dataset.rename_column("question", "prompt")
if "input" in dataset.column_names and "prompt" not in dataset.column_names:
dataset = dataset.rename_column("input", "prompt")

features = dataset.features

def process_preference_data(example):
example["prompt"] = example["chosen"][:-1]
example["chosen"] = example["chosen"][-1]["content"]
example["rejected"] = example["rejected"][-1]["content"]
return example

def process_instruction_data(example):
messages = example["messages"]
example["prompt"] = messages[0]
return example

# Determine if it's preference data or instruction data
has_preference_data = "chosen" in dataset.column_names and "rejected" in dataset.column_names
has_instruction_data = "messages" in dataset.column_names
Expand All @@ -239,6 +228,18 @@ def process_instruction_data(example):
" columns for preference data, or a 'messages' column for instruction data."
)

# Process the data for input to RM
def process_preference_data(example):
    """Split a multi-turn preference row into prompt context plus final-turn texts.

    The shared conversation (all turns except the last) becomes ``prompt``;
    ``chosen``/``rejected`` are reduced to the text of their final turns.
    Mutates and returns ``example`` (datasets.map convention).
    """
    chosen_turns = example["chosen"]
    # Everything before the last turn is the prompt shared by both completions.
    example["prompt"] = chosen_turns[:-1]
    example["chosen"] = chosen_turns[-1]["content"]
    example["rejected"] = example["rejected"][-1]["content"]
    return example

def process_instruction_data(example):
    """Expose the first message's text as ``prompt`` for instruction rows.

    Mutates and returns ``example`` (datasets.map convention).
    """
    example["prompt"] = example["messages"][0]["content"]
    return example

if is_preference_data:
if "prompt" not in dataset.column_names or not isinstance(features["prompt"], list):
dataset = dataset.map(
Expand Down Expand Up @@ -272,13 +273,13 @@ def process_instruction_data(example):
logger.info("*** Preparing dataset with FastChat ***")
dataset = dataset.map(
prepare_dialogue,
fn_kwargs={"dialogue_template": conv},
fn_kwargs={"dialogue_template": conv, "ift": not is_preference_data},
num_proc=8,
load_from_cache_file=False,
)

# Remove excess data
keep_columns = ["prompt", "text_chosen", "text_rejected"] if is_preference_data else ["prompt", "messages"]
keep_columns = ["prompt", "text_chosen", "text_rejected"] if is_preference_data else ["prompt", "text"]
all_cols = dataset.column_names
dataset = dataset.remove_columns([c for c in all_cols if c not in keep_columns])
return dataset
Expand Down Expand Up @@ -594,11 +595,13 @@ def prepare_dialogue_from_tokenizer(
)
example["prompt"] = temp_prompt
elif ift:
# TODO adapt this for DPO models with tokenize_row function
messages = [
{"role": "user", "content": example["prompt"]},
{"role": "assistant", "content": example["input"]},
]
if "messages" in example:
messages = example["messages"]
else:
messages = [
{"role": "user", "content": example["prompt"]},
{"role": "assistant", "content": example["input"]},
]
example["text"] = tokenizer.apply_chat_template(
messages,
tokenize=False,
Expand Down Expand Up @@ -669,14 +672,29 @@ def prepare_dialogue(
if isinstance(example["prompt"], list):
example["prompt"] = example["prompt"][0]

# get prompt
dialogue_template.messages = [
[dialogue_template.roles[0], example["prompt"]],
]
temp_prompt = dialogue_template.get_prompt()
dialogue_template.messages = [
[dialogue_template.roles[0], example["prompt"]],
[dialogue_template.roles[1], example["input"]],
]

# get messages
if "messages" in example:
# convert to FastChat format (list of list)
# original format:
# [
# {"role": "user", "content": example["prompt"]},
# {"role": "assistant", "content": example["rejected"]},
# ]
dialogue_template.messages = []
for i, line in enumerate(example["messages"]):
role = dialogue_template.roles[0] if i % 2 == 0 else dialogue_template.roles[1]
dialogue_template.messages.append([role, line["content"]])
else:
dialogue_template.messages = [
[dialogue_template.roles[0], example["prompt"]],
[dialogue_template.roles[1], example["input"]],
]
example["text"] = dialogue_template.get_prompt()
example["prompt"] = temp_prompt # needed for DPO

Expand Down
22 changes: 22 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ def test_prepare_dialogue_from_tokenizer_ift(self):
desired_text = "<|user|>\nWhat are different drawers I should have for clothes?<|endoftext|>\n<|assistant|>\nUtensils!<|endoftext|>\n" # noqa
assert prepared["text"] == desired_text

def test_prepare_dialogue_from_tokenizer_messages_ift(self):
    """IFT mode should render a pre-built `messages` list via the tokenizer's chat template."""
    example = {
        "messages": [
            {"role": "user", "content": "Who are you?"},
            {"role": "assistant", "content": "I am a bot."},
        ],
        "prompt": "Who are you?",
    }
    prepared = prepare_dialogue_from_tokenizer(example, self.tokenizer, ift=True)
    expected = "<|user|>\nWho are you?<|endoftext|>\n<|assistant|>\nI am a bot.<|endoftext|>\n"
    assert prepared["text"] == expected

def test_prepare_dialogue_single_turn(self):
example = {}
example["prompt"] = "What are different drawers I should have for clothes?"
Expand Down Expand Up @@ -126,6 +137,17 @@ def test_prepare_dialogue_ift(self):
desired_text = "<|user|>\nWhat are different drawers I should have for clothes?\n<|assistant|>\nUtensils!\n"
assert prepared["text"] == desired_text

def test_prepare_dialogue_messages_ift(self):
    """IFT mode should render a pre-built `messages` list via the FastChat template."""
    example = {
        "messages": [
            {"role": "user", "content": "Who are you?"},
            {"role": "assistant", "content": "I am a bot."},
        ],
        "prompt": "Who are you?",
    }
    prepared = prepare_dialogue(example, self.conv, ift=True)
    expected = "<|user|>\nWho are you?\n<|assistant|>\nI am a bot.\n"
    assert prepared["text"] == expected


class DatasetTest(unittest.TestCase):
def test_core_dataset_lens(self):
Expand Down

0 comments on commit 2c054fe

Please sign in to comment.