From e42d40fb73616eee7de4713801c51f81e2c4b969 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Tue, 1 Oct 2024 14:04:16 -0700
Subject: [PATCH] Add better logging and functionality with instructions to CLI (#193)
---
 README.md                  |  29 +++
 rewardbench/__init__.py    |   4 +-
 rewardbench/rewardbench.py | 489 +++++++++++++++++++++++++------------
 rewardbench/utils.py       | 160 +++++++-----
 tests/test_data.py         |  22 ++
 tests/test_package.py      |  16 +-
 6 files changed, 497 insertions(+), 223 deletions(-)

diff --git a/README.md b/README.md
index 6e43e253..b3dfd3ce 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,8 @@ The two primary scripts to generate results (more in `scripts/`):
 
 ## Quick Usage
 RewardBench let's you quickly evaluate any reward model on any preference set.
+It will also detect if an instruction dataset is passed (by checking that the data has `messages` and no `chosen`/`rejected` columns) -- for these, only the model outputs are logged (no accuracy is computed).
+
 To install for quick usage, install with pip as:
 ```
 pip install rewardbench
@@ -70,6 +72,33 @@ rewardbench-gen --model={}
 For more information, see `scripts/run_generative.py`.
 The extra requirement for local models is VLLM and the requesite API for API models (OpenAI, Anthropic, and Together are supported).
 
+### Logging
+
+The CLI comes with multiple advanced saving features for **model outputs** and **accuracy scores**.
+Results can be attached as metadata to reward models you own, or uploaded as separate datasets to HuggingFace, e.g. for rejection sampling.
+For example, the following command does both:
+```
+rewardbench --model vwxyzjn/reward_modeling__EleutherAI_pythia-14m --batch_size 128 --tokenizer=EleutherAI/pythia-14m --push_results_to_hub --upload_model_metadata_to_hf --chat_template raw
+```
+Or, for an instruction dataset:
+```
+rewardbench --model vwxyzjn/reward_modeling__EleutherAI_pythia-14m --dataset HuggingFaceH4/no_robots --split test --batch_size 128 --tokenizer=EleutherAI/pythia-14m --push_results_to_hub --chat_template raw
+```
+(Note that chat templates only need to be specified for older models.)
+
+The key flags are:
+* `--push_results_to_hub`, which uploads a dataset of per-prompt scores and correctness.
+* `--upload_model_metadata_to_hf`, which adds the results directly to the model card.
+
+For an example of a model with accuracy metadata, look [here](https://huggingface.co/vwxyzjn/rm_zephyr_new).
+For an example of the outputs from a preference dataset, look [here](https://huggingface.co/datasets/natolambert/rewardbench_eval_2339270924_2339270924), and for an instruction dataset, look [here](https://huggingface.co/datasets/natolambert/rewardbench_eval_0329290924).
+
+For DPO models, this currently only works with preference datasets, such as:
+```
+rewardbench --model Qwen/Qwen1.5-0.5B-Chat --ref_model Qwen/Qwen1.5-0.5B --batch_size 128 --tokenizer=EleutherAI/pythia-14m --push_results_to_hub --upload_model_metadata_to_hf --chat_template raw
+```
+Open an issue if you would like complete functionality.
+
 ## Full Installation
 To install from source, please install `torch` on your system, and then install the following requirements.
``` diff --git a/rewardbench/__init__.py b/rewardbench/__init__.py index 9d7c7434..05d7ea92 100644 --- a/rewardbench/__init__.py +++ b/rewardbench/__init__.py @@ -18,9 +18,9 @@ from .models import DPO_MODEL_CONFIG, REWARD_MODEL_CONFIG from .utils import ( check_tokenizer_chat_template, + load_and_process_dataset, load_bon_dataset, load_eval_dataset, - load_preference_dataset, prepare_dialogue, prepare_dialogue_from_tokenizer, save_to_hub, @@ -33,7 +33,7 @@ DPO_MODEL_CONFIG, load_bon_dataset, load_eval_dataset, - load_preference_dataset, + load_and_process_dataset, prepare_dialogue, prepare_dialogue_from_tokenizer, REWARD_MODEL_CONFIG, diff --git a/rewardbench/rewardbench.py b/rewardbench/rewardbench.py index 682db656..6b52223e 100644 --- a/rewardbench/rewardbench.py +++ b/rewardbench/rewardbench.py @@ -12,22 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Run RewardBench (evaluate any reward model on any dataet) +# Run RewardBench (evaluate any reward model on any dataset) import json import logging import os import sys +import time from dataclasses import dataclass -from typing import Optional +from pprint import pformat +from typing import Dict, List, Optional, Union import numpy as np +import pkg_resources import torch import transformers import wandb from accelerate import Accelerator from accelerate.logging import get_logger -from huggingface_hub import EvalResult, ModelCard, ModelCardData +from datasets import Dataset +from huggingface_hub import EvalResult, HfApi, ModelCard, ModelCardData +from huggingface_hub.repocard import RepoCard from tqdm import tqdm from transformers import AutoTokenizer, HfArgumentParser @@ -35,7 +40,7 @@ DPO_MODEL_CONFIG, REWARD_MODEL_CONFIG, check_tokenizer_chat_template, - load_preference_dataset, + load_and_process_dataset, ) @@ -58,12 +63,22 @@ class Args: """The chat template to use (defaults to from tokenizer, from chattemplate).""" not_quantized: bool = False """Disable quantization for models that are quantized by default.""" + prioritize_scoring: bool = False + """Prioritize scoring of the messages key, rather than accuracy rankings.""" + + # hf saving args + push_results_to_hub: bool = False + """Push distribution of scores and labels to randomly generated HuggingFace dataset.""" + upload_model_metadata_to_hf: bool = False + """Upload metadata to Hugging Face Hub.""" + hf_entity: Optional[str] = None + """The Hugging Face entity to push results to.""" + hf_name: Optional[str] = None + """[Default is random] The Hugging Face dataset name to push results to.""" # wandb args wandb_run: Optional[str] = None """The wandb run to extract model and revision from.""" - upload_metadata_to_hf: bool = False - """Upload metadata to Hugging Face Hub.""" # inference args batch_size: int = 8 @@ -86,15 +101,108 @@ class Args: """Force truncation (for if model errors).""" +def save_jsonl(save_filename: str, table: Dict[str, List[Union[int, float, str]]]): + # Ensure directory exists + dirname = os.path.dirname(save_filename) + if dirname: + os.makedirs(dirname, exist_ok=True) + + # Write the dictionary data to JSONL file + with open(save_filename, "w") as outfile: + # Iterate through each index and write corresponding row as JSON + for i in range(len(next(iter(table.values())))): # Get the first key's length + json.dump({key: table[key][i] for key in table}, outfile) + outfile.write("\n") + + +def push_results_to_hub(args, results, accuracy=None): + """ + Push dataset to Hugging Face Hub. 
+ + Args: + args: Argument object with the following attributes: + - hf_entity: Hugging Face entity (e.g., username or organization). + - hf_name: ID of the repository to create or use. + """ + api = HfApi() + + if args.hf_entity is None: + args.hf_entity = api.whoami()["name"] + + timestamp = time.strftime("%H%M%d%m%y") + # Generate default hf_name if not set + if not args.hf_name: + args.hf_name = f"rewardbench_eval_{timestamp}" + + full_repo_id = f"{args.hf_entity}/{args.hf_name}" + + # Create repository on Hugging Face Hub + api.create_repo(full_repo_id, repo_type="dataset", exist_ok=True) + + # Print and prepare the repository URL + repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}" + + # Generate the command that was run + run_command = " ".join(["python"] + sys.argv) + + # Get package versions as a dictionary + package_versions = {package.key: package.version for package in pkg_resources.working_set} + + # If accuracy is provided, create a string adding it to the results + if accuracy is not None: + accuracy_str = f"Accuracy: {accuracy}" + else: + accuracy_str = "" + + # Create and push a repo card + rm_card = RepoCard( + content=f"""\ +# {args.hf_name}: RewardBench CLI Eval. Outputs + +See https://github.com/allenai/reward-bench for more details + +Built with the `rewardbench` CLI tool. +{accuracy_str} + +Command used to run: +``` +{run_command} +``` + +## Configs +``` +args: {pformat(vars(args))} +``` + +## Package Versions +``` +{pformat(package_versions)} +``` +""" + ) + rm_card.push_to_hub( + full_repo_id, + repo_type="dataset", + ) + print(f"Pushed to {repo_full_url}") + + # Upload the dataset (after to add metadata to card) + data_to_upload = Dataset.from_dict(results) + data_to_upload.push_to_hub(full_repo_id) + + return full_repo_id + + def main(): parser = HfArgumentParser((Args)) - actual_main(*parser.parse_args_into_dataclasses()) + rewardbench(*parser.parse_args_into_dataclasses()) -def actual_main(args: Args): +# Secondary function structure needed to accomodate HuggingFace Args with CLI binding +def rewardbench(args: Args): if args.wandb_run is not None: wandb_run = wandb.Api().run(args.wandb_run) - args.model = wandb_run.config["hf_repo_id"] + args.model = wandb_run.config["hf_name"] args.revision = wandb_run.config["hf_repo_revision"] ############### @@ -192,22 +300,42 @@ def actual_main(args: Args): custom_dialogue_formatting=False, tokenizer=tokenizer, logger=logger, - keep_columns=["text_chosen", "text_rejected", "prompt"], + return_extra_data=True, ) else: - dataset = load_preference_dataset( - args.dataset, split=args.split, json=args.load_json, tokenizer=tokenizer, conv=conv + dataset = load_and_process_dataset( + args.dataset, + split=args.split, + json=args.load_json, + tokenizer=tokenizer, + conv=conv, + prioritize_instructions=args.prioritize_scoring, ) + # check if "chosen" and "rejected" in the dataset features + if "text_chosen" in dataset.features and "text_rejected" in dataset.features: + is_preference_ranking = True + else: + is_preference_ranking = False + if args.debug: dataset = dataset.select(range(10)) + # Move extra columns to extra metadata (merged later) + keep_columns = ["prompt", "text_chosen", "text_rejected"] if is_preference_ranking else ["prompt", "text"] + all_cols = dataset.column_names + metadata = dataset.remove_columns(keep_columns) + dataset = dataset.remove_columns([c for c in all_cols if c not in keep_columns]) + logger.info("*** Load reward model ***") ############################ # Load DPO model pipeline 
############################ if is_dpo: + # if not preference data, raise NotImplementedError (only implemented for pairwise) + if not is_preference_ranking: + raise NotImplementedError("DPO only implemented for pairwise preference data.") tokenizer.pad_token = tokenizer.eos_token # if no BOS token, set as pad token, e.g. QWEN models if tokenizer.bos_token is None: @@ -322,100 +450,119 @@ def actual_main(args: Args): ############################ results = [] - scores_chosen = [] - scores_rejected = [] + if is_preference_ranking: + scores_chosen = [] + scores_rejected = [] + for step, batch in enumerate(tqdm(dataloader, desc="RM batch steps")): logger.info(f"RM inference step {step}/{len(dataloader)}") - if is_dpo: - rewards_chosen, rewards_rejected = dpo.inference_step(batch) - else: - rewards_chosen = reward_pipe(batch["text_chosen"], **reward_pipeline_kwargs) - rewards_rejected = reward_pipe(batch["text_rejected"], **reward_pipeline_kwargs) - - # for each item in batch, record 1 if chosen > rejected - # extra score from dict within batched results (e.g. logits) - # [{'label': 'LABEL_1', 'score': 0.6826171875},... ] - if isinstance(rewards_chosen[0], dict): - score_chosen_batch = [result["score"] for result in rewards_chosen] - score_rejected_batch = [result["score"] for result in rewards_rejected] - # for classes that directly output scores (custom code) + if is_preference_ranking: + if is_dpo: + rewards_chosen, rewards_rejected = dpo.inference_step(batch) + else: + rewards_chosen = reward_pipe(batch["text_chosen"], **reward_pipeline_kwargs) + rewards_rejected = reward_pipe(batch["text_rejected"], **reward_pipeline_kwargs) + + # for each item in batch, record 1 if chosen > rejected + # extra score from dict within batched results (e.g. logits) + # [{'label': 'LABEL_1', 'score': 0.6826171875},... 
] + if isinstance(rewards_chosen[0], dict): + score_chosen_batch = [result["score"] for result in rewards_chosen] + score_rejected_batch = [result["score"] for result in rewards_rejected] + # for classes that directly output scores (custom code) + else: + score_chosen_batch = rewards_chosen.cpu().numpy().tolist() + score_rejected_batch = rewards_rejected.cpu().numpy().tolist() + + # log results + [ + results.append(1) if chosen > rejected else results.append(0) + for chosen, rejected in zip(score_chosen_batch, score_rejected_batch) + ] + scores_chosen.extend(score_chosen_batch) + scores_rejected.extend(score_rejected_batch) else: - score_chosen_batch = rewards_chosen.cpu().numpy().tolist() - score_rejected_batch = rewards_rejected.cpu().numpy().tolist() - - # log results - [ - results.append(1) if chosen > rejected else results.append(0) - for chosen, rejected in zip(score_chosen_batch, score_rejected_batch) - ] - scores_chosen.extend(score_chosen_batch) - scores_rejected.extend(score_rejected_batch) + rewards = reward_pipe(batch["text"], **reward_pipeline_kwargs) + if isinstance(rewards[0], dict): + scores = [result["score"] for result in rewards] + else: + scores = rewards.cpu().numpy().tolist() + results.extend(scores) ############################ - # compile scores + # save outputs directly ############################ - # calculate accuracy - accuracy = sum(results) / len(results) - logger.info(f"Results: {accuracy}, on {len(results)} prompts") - # compute mean and std of scores, chosen and rejected, then margin between them - logger.info(f"Mean chosen: {np.mean(scores_chosen)}, std: {np.std(scores_chosen)}") - logger.info(f"Mean rejected: {np.mean(scores_rejected)}, std: {np.std(scores_rejected)}") - logger.info(f"Mean margin: {np.mean(np.array(scores_chosen) - np.array(scores_rejected))}") + def unwrap_if_list_of_lists(data): + if isinstance(data, list): + if isinstance(data[0], list): + return [item for sublist in data for item in sublist] + return data - if args.dataset == "allenai/reward-bench": - out_dataset = dataset.add_column("results", results) - if args.debug: - subsets = subsets[:10] - out_dataset = out_dataset.add_column("subsets", subsets) - out_dataset = out_dataset.to_pandas() # I know this is meh - - results_grouped = {} - present_subsets = np.unique(out_dataset["subsets"]) - for subset in present_subsets: - subset_dataset = out_dataset[out_dataset["subsets"] == subset] - num_correct = sum(subset_dataset["results"]) - num_total = len(subset_dataset["results"]) - logger.info(f"{subset}: {num_correct}/{num_total} ({num_correct/num_total})") - results_grouped[subset] = num_correct / num_total - - results_section = calculate_scores_per_section(EXAMPLE_COUNTS, SUBSET_MAPPING, results_grouped) - logger.info(f"Results: {results_section}") + combined_data = { + "prompt": dataset["prompt"], # Assuming `prompts` is a list of prompts matching scores + "results": unwrap_if_list_of_lists(results), + } + + # Consolidate chosen and rejected scores along with prompts and texts + if is_preference_ranking: + combined_data["scores_chosen"] = unwrap_if_list_of_lists(scores_chosen) + combined_data["scores_rejected"] = unwrap_if_list_of_lists(scores_rejected) + combined_data["text_chosen"] = dataset["text_chosen"] + combined_data["text_rejected"] = dataset["text_rejected"] + # or take instruction + else: + combined_data["text"] = dataset["text"] + + # add columns in metadata to combined_data + for col in metadata.column_names: + combined_data[col] = metadata[col] + + # Save combined scores 
and metadata to JSONL + scores_output_path = os.path.join(args.output_dir, f"{args.model}_outputs.jsonl") + save_jsonl(scores_output_path, combined_data) ############################ - # compile scores + # the rest is just for preferences (accuracies) ############################ - # save score in json to args.output_dir + args.model + ".json" - output_path = args.output_dir + args.model + ".json" - dirname = os.path.dirname(output_path) - os.makedirs(dirname, exist_ok=True) - - # remove old data - if os.path.exists(output_path): - os.remove(output_path) - - final_results = { - "accuracy": accuracy, - "num_prompts": len(results), - "model": args.model, - "ref_model": args.ref_model, - "tokenizer": tokenizer_path, - "chat_template": args.chat_template, - "extra_results": results_grouped if args.dataset == "allenai/reward-bench" else None, - } - with open(output_path, "w") as f: - json.dump(final_results, f) - - if args.wandb_run is not None: - for key in final_results: - wandb_run.summary[f"rewardbench/{key}"] = final_results[key] - wandb_run.update() - print(f"Logged metrics to {wandb_run.url}") - - # if save_all is passed, save a large jsonl with all scores_chosen, scores_rejected - if args.save_all: - output_path = args.output_dir + args.model + "_all.jsonl" + if is_preference_ranking: + ############################ + # compile scores + ############################ + # calculate accuracy + accuracy = sum(results) / len(results) + logger.info(f"Results: {accuracy}, on {len(results)} prompts") + + # compute mean and std of scores, chosen and rejected, then margin between them + logger.info(f"Mean chosen: {np.mean(scores_chosen)}, std: {np.std(scores_chosen)}") + logger.info(f"Mean rejected: {np.mean(scores_rejected)}, std: {np.std(scores_rejected)}") + logger.info(f"Mean margin: {np.mean(np.array(scores_chosen) - np.array(scores_rejected))}") + + if args.dataset == "allenai/reward-bench": + out_dataset = dataset.add_column("results", results) + if args.debug: + subsets = subsets[:10] + out_dataset = out_dataset.add_column("subsets", subsets) + out_dataset = out_dataset.to_pandas() # I know this is meh + + results_grouped = {} + present_subsets = np.unique(out_dataset["subsets"]) + for subset in present_subsets: + subset_dataset = out_dataset[out_dataset["subsets"] == subset] + num_correct = sum(subset_dataset["results"]) + num_total = len(subset_dataset["results"]) + logger.info(f"{subset}: {num_correct}/{num_total} ({num_correct/num_total})") + results_grouped[subset] = num_correct / num_total + + results_section = calculate_scores_per_section(EXAMPLE_COUNTS, SUBSET_MAPPING, results_grouped) + logger.info(f"Results: {results_section}") + + ############################ + # save scores + ############################ + # save score in json to args.output_dir + args.model + ".json" + output_path = args.output_dir + args.model + ".json" dirname = os.path.dirname(output_path) os.makedirs(dirname, exist_ok=True) @@ -423,68 +570,106 @@ def actual_main(args: Args): if os.path.exists(output_path): os.remove(output_path) + final_results = { + "accuracy": accuracy, + "num_prompts": len(results), + "model": args.model, + "ref_model": args.ref_model, + "tokenizer": tokenizer_path, + "chat_template": args.chat_template, + "extra_results": results_grouped if args.dataset == "allenai/reward-bench" else None, + } with open(output_path, "w") as f: - for chosen, rejected in zip(scores_chosen, scores_rejected): - f.write(json.dumps({"chosen": chosen, "rejected": rejected}) + "\n") + json.dump(final_results, 
f) + + if args.wandb_run is not None: + for key in final_results: + wandb_run.summary[f"rewardbench/{key}"] = final_results[key] + wandb_run.update() + print(f"Logged metrics to {wandb_run.url}") + + # if save_all is passed, save a large jsonl with all scores_chosen, scores_rejected + if args.save_all: + output_path = args.output_dir + args.model + "_all.jsonl" + dirname = os.path.dirname(output_path) + os.makedirs(dirname, exist_ok=True) + + # remove old data + if os.path.exists(output_path): + os.remove(output_path) + + with open(output_path, "w") as f: + for chosen, rejected in zip(scores_chosen, scores_rejected): + f.write(json.dumps({"chosen": chosen, "rejected": rejected}) + "\n") + + ############################ + # Upload metadata to Hugging Face Hub + ############################ + if args.upload_model_metadata_to_hf: + logger.info("*** Uploading metadata to Hugging Face Hub ***") + try: + # Initialize ModelCardData with basic metadata + card_data = ModelCardData( + language="en", + model_name=args.model, + eval_results=[ + EvalResult( + task_type="preference_evaluation", + dataset_type=args.dataset, + dataset_name=args.dataset.split("/")[-1], # Assuming dataset ID is like 'owner/dataset' + metric_type="accuracy", + metric_value=accuracy, + ) + ], + ) + + # If there are extra results (per subset), add them as separate EvalResults + if args.dataset == "allenai/reward-bench" and results_grouped: + for section, section_accuracy in results_section.items(): + print(f"Adding section {section} with accuracy {section_accuracy}") + section_eval = EvalResult( + task_type="preference_evaluation", + dataset_type=section.replace(" ", "_"), + dataset_name=section, + metric_type="accuracy", + metric_value=section_accuracy, + ) + card_data.eval_results.append(section_eval) + + for subset, subset_accuracy in results_grouped.items(): + print(f"Adding subset {subset} with accuracy {subset_accuracy}") + subset_eval = EvalResult( + task_type="preference_evaluation", + dataset_type=subset, + dataset_name=subset, + metric_type="accuracy", + metric_value=subset_accuracy, + ) + card_data.eval_results.append(subset_eval) + + # Create a ModelCard + card = ModelCard.from_template( + card_data, + model_id=args.model, + ) + + # Push the updated ModelCard to the Hugging Face Hub + card.push_to_hub( + args.model, revision=args.revision, commit_message="Update evaluation results via RewardBench" + ) + logger.info(f"Successfully pushed updated ModelCard to Hugging Face Hub for {args.model}") + except Exception as e: + logger.error(f"Failed to upload metadata to Hugging Face Hub: {e}") + logger.info("(The most common issue is a model you do not have write permissions on).") + else: + accuracy = None ############################ - # Upload metadata to Hugging Face Hub + # Upload results to HF (as dataset) ############################ - if args.upload_metadata_to_hf: - logger.info("*** Uploading metadata to Hugging Face Hub ***") - try: - # Initialize ModelCardData with basic metadata - card_data = ModelCardData( - language="en", - model_name=args.model, - eval_results=[ - EvalResult( - task_type="preference_evaluation", - dataset_type=args.dataset, - dataset_name=args.dataset.split("/")[-1], # Assuming dataset ID is like 'owner/dataset' - metric_type="accuracy", - metric_value=accuracy, - ) - ], - ) - - # If there are extra results (per subset), add them as separate EvalResults - if args.dataset == "allenai/reward-bench" and results_grouped: - for section, section_accuracy in results_section.items(): - 
print(f"Adding section {section} with accuracy {section_accuracy}") - section_eval = EvalResult( - task_type="preference_evaluation", - dataset_type=section.replace(" ", "_"), - dataset_name=section, - metric_type="accuracy", - metric_value=section_accuracy, - ) - card_data.eval_results.append(section_eval) - - for subset, subset_accuracy in results_grouped.items(): - print(f"Adding subset {subset} with accuracy {subset_accuracy}") - subset_eval = EvalResult( - task_type="preference_evaluation", - dataset_type=subset, - dataset_name=subset, - metric_type="accuracy", - metric_value=subset_accuracy, - ) - card_data.eval_results.append(subset_eval) - - # Create a ModelCard - card = ModelCard.from_template( - card_data, - model_id=args.model, - ) - - # Push the updated ModelCard to the Hugging Face Hub - card.push_to_hub( - args.model, revision=args.revision, commit_message="Update evaluation results via RewardBench" - ) - logger.info(f"Successfully pushed updated ModelCard to Hugging Face Hub for {args.model}") - except Exception as e: - logger.error(f"Failed to upload metadata to Hugging Face Hub: {e}") + if args.push_results_to_hub: + hf_repo = push_results_to_hub(args, combined_data, accuracy=accuracy) + logger.info(f"Pushed results to Hugging Face Hub for https://huggingface.co/datasets/{hf_repo}") if __name__ == "__main__": diff --git a/rewardbench/utils.py b/rewardbench/utils.py index e15a7c2e..956bda20 100644 --- a/rewardbench/utils.py +++ b/rewardbench/utils.py @@ -150,31 +150,41 @@ def map_conversations_testsets(example): return example -def load_preference_dataset( +def load_and_process_dataset( dataset_name: str, split: str = "train", json: bool = False, conv: Conversation = None, tokenizer: PreTrainedTokenizer = None, logger: logging.Logger = None, + prioritize_instructions: bool = False, ) -> Dataset: """ - Load a preference dataset from the datasets library. + Load a preference dataset or an instruction dataset from the datasets library. + Works for both preference datasets (with chosen/rejected) and SFT datasets (with messages). - Expects the data the following schema. - - prompt (string): question - - chosen (list): all turns of the conversation (including the prompt), chosen answer - - rejected (list): all turns of the conversation (including the prompt), rejected answer + Expects the data to follow one of these schemas: + 1. Preference data: + - prompt (string): question + - chosen (list): all turns of the conversation (including the prompt), chosen answer + - rejected (list): all turns of the conversation (including the prompt), rejected answer + 2. Instruction data: + - messages (list): all turns of the conversation - Removes all excess columns, only returns scores over the provided data in order. + Removes all excess columns, only returns processed data in order. Args: dataset_name (str): The name of the dataset to load (HuggingFace or local directory) split (str): The split of the dataset to load (train, validation, test, ...) + json (bool): Whether to load the dataset from a JSON file + conv (Conversation): FastChat conversation template + tokenizer (PreTrainedTokenizer): HuggingFace tokenizer + logger (logging.Logger): Logger object + prioritize_instructions (bool): If True, prioritize processing as instruction data when both types are present Returns: - dataset (Dataset): The loaded dataset with prompt, text_chosen, and text_rejected columns. 
- text_ indicates a full conversation ending with that turn + dataset (Dataset): The loaded dataset with prompt, text_chosen, and text_rejected columns for preference data, + or prompt and response columns for instruction data. """ if json: dataset = load_dataset("json", data_files=dataset_name) @@ -187,86 +197,87 @@ def load_preference_dataset( datasets_to_combine = [dataset[split] for split in available_splits] dataset = concatenate_datasets(datasets_to_combine) - # if has column question without prompt, rename question column to prompt - if "question" in dataset.column_names: - assert "prompt" not in dataset.column_names, "Both prompt and question columns found" + # Handle column renaming to track prompts + if "question" in dataset.column_names and "prompt" not in dataset.column_names: dataset = dataset.rename_column("question", "prompt") - if "input" in dataset.column_names: - assert "prompt" not in dataset.column_names, "Both prompt and question columns found" + if "input" in dataset.column_names and "prompt" not in dataset.column_names: dataset = dataset.rename_column("input", "prompt") - # switch to format used for data utils - # e.g. for evaluating this data https://huggingface.co/datasets/allenai/preference-test-sets - # python -m rewardbench/rewardbench.py --dataset-name allenai/preference-test-sets --split shp features = dataset.features - def switch_format(example): - # chosen/rejected append {"role": "assistnat", "content": chosen} + # Determine if it's preference data or instruction data + has_preference_data = "chosen" in dataset.column_names and "rejected" in dataset.column_names + has_instruction_data = "messages" in dataset.column_names + + # Decide which processing to use based on the prioritize_instructions flag + if prioritize_instructions and has_instruction_data: + is_preference_data = False + if logger: + logger.info("Processing as instruction data (prioritized)") + elif has_preference_data: + is_preference_data = True + if logger: + logger.info("Processing as preference data") + elif has_instruction_data: + is_preference_data = False + if logger: + logger.info("Processing as instruction data") + else: + raise ValueError( + "Dataset format not recognized. It should contain either 'chosen' and 'rejected'" + " columns for preference data, or a 'messages' column for instruction data." + ) + + # Process the data for input to RM + def process_preference_data(example): example["prompt"] = example["chosen"][:-1] example["chosen"] = example["chosen"][-1]["content"] example["rejected"] = example["rejected"][-1]["content"] return example - # NOTE: We do NOT want to support every schema. These are the main three to start with - # 1. Prompt is in a list of previous turns, chosen and rejected are final message from assistant - # 2. Prompt is a string, chosen and rejected are full conversations with different final turns - # 3. 
Prompt is not existent, chosen and rejected are full conversations with different final turns - # TODO implement system prompts correctly (though, often doesn't work for Reward Models) + def process_instruction_data(example): + messages = example["messages"] + example["prompt"] = messages[0]["content"] + return example - # if prompt isn't a column, - if "prompt" not in dataset.column_names: - dataset = dataset.map( - switch_format, - num_proc=8, - load_from_cache_file=False, - ) - # elif prompt is a list and not a str, same function works - elif not isinstance(features["prompt"], list): + if is_preference_data: + if "prompt" not in dataset.column_names or not isinstance(features["prompt"], list): + dataset = dataset.map( + process_preference_data, + num_proc=8, + load_from_cache_file=False, + ) + else: dataset = dataset.map( - switch_format, + process_instruction_data, num_proc=8, load_from_cache_file=False, ) - # update features - features = dataset.features - - # assert the correct types - assert features["chosen"].dtype == "string", f"chosen is wrong type (should be string): {features['chosen']}" - assert features["rejected"].dtype == "string", f"rejected is wrong type (should be string): {features['rejected']}" - - # tokenize the data + # Tokenize the data usable_tokenizer = check_tokenizer_chat_template(tokenizer) - # assert either conv is passed or tokenizer has chat_template - assert conv is not None or usable_tokenizer + assert conv is not None or usable_tokenizer, "Either conv or a tokenizer with a chat template must be provided." if usable_tokenizer: if logger is not None: logger.info("*** Preparing dataset with HF Transformers ***") - # docs https://huggingface.co/docs/transformers/main/en/chat_templating dataset = dataset.map( prepare_dialogue_from_tokenizer, - fn_kwargs={"tokenizer": tokenizer}, + fn_kwargs={"tokenizer": tokenizer, "ift": not is_preference_data}, num_proc=8, load_from_cache_file=False, ) - - # else use FastChat to get chat template else: if logger is not None: logger.info("*** Preparing dataset with FastChat ***") dataset = dataset.map( prepare_dialogue, - fn_kwargs={"dialogue_template": conv}, + fn_kwargs={"dialogue_template": conv, "ift": not is_preference_data}, num_proc=8, load_from_cache_file=False, ) - # remove excess data - keep_columns = ["prompt", "text_chosen", "text_rejected"] - all_cols = dataset.column_names - dataset = dataset.remove_columns([c for c in all_cols if c not in keep_columns]) return dataset @@ -277,19 +288,21 @@ def load_eval_dataset( tokenizer: PreTrainedTokenizer = None, logger: logging.Logger = None, keep_columns: List[str] = ["text_chosen", "text_rejected", "id"], + return_extra_data: bool = False, max_turns: int = None, ) -> tuple[Dataset, list[str]]: """ - Loads either the core eval set for HERM or the existing preference data test sets. + Loads either the core eval set for RewardBench or the existing preference data test sets. Args: - core_set: if True, load the core eval set for HERM. + core_set: if True, load the core eval set for RewardBench. custom_dialogue_formatting: if True, format the dialogue as needed for custom models (e.g. SHP and PairRM). conv: fastchat conversation template. If None (default) the passed tokenizer needs to have a usable chat template. tokenizer: HuggingFace tokenizer to use. The tokenizer's chat template, if available, has precedence over conv. logger: logger to use for logging. If None (default), no logging is done. keep_columns: list of columns to keep in the dataset. 
+ return_extra_data: return extra metadata for expanded logging (mostly in CLI) max_turns: maximum number of turns in the dialogue (usually even). If None (default), no filtering is done. Returns: @@ -384,6 +397,8 @@ def filter_long_turns(batch): # take column subset from dataset subsets = dataset["subset"] + if return_extra_data: + return dataset, subsets # remove columns if set and not custom_dialogue_formatting all_cols = dataset.column_names @@ -580,11 +595,13 @@ def prepare_dialogue_from_tokenizer( ) example["prompt"] = temp_prompt elif ift: - # TODO adapt this for DPO models with tokenize_row function - messages = [ - {"role": "user", "content": example["prompt"]}, - {"role": "assistant", "content": example["input"]}, - ] + if "messages" in example: + messages = example["messages"] + else: + messages = [ + {"role": "user", "content": example["prompt"]}, + {"role": "assistant", "content": example["input"]}, + ] example["text"] = tokenizer.apply_chat_template( messages, tokenize=False, @@ -655,14 +672,29 @@ def prepare_dialogue( if isinstance(example["prompt"], list): example["prompt"] = example["prompt"][0] + # get prompt dialogue_template.messages = [ [dialogue_template.roles[0], example["prompt"]], ] temp_prompt = dialogue_template.get_prompt() - dialogue_template.messages = [ - [dialogue_template.roles[0], example["prompt"]], - [dialogue_template.roles[1], example["input"]], - ] + + # get messages + if "messages" in example: + # convert to FastChat format (list of list) + # original format: + # [ + # {"role": "user", "content": example["prompt"]}, + # {"role": "assistant", "content": example["rejected"]}, + # ] + dialogue_template.messages = [] + for i, line in enumerate(example["messages"]): + role = dialogue_template.roles[0] if i % 2 == 0 else dialogue_template.roles[1] + dialogue_template.messages.append([role, line["content"]]) + else: + dialogue_template.messages = [ + [dialogue_template.roles[0], example["prompt"]], + [dialogue_template.roles[1], example["input"]], + ] example["text"] = dialogue_template.get_prompt() example["prompt"] = temp_prompt # needed for DPO diff --git a/tests/test_data.py b/tests/test_data.py index a6ec6920..0756f03f 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -78,6 +78,17 @@ def test_prepare_dialogue_from_tokenizer_ift(self): desired_text = "<|user|>\nWhat are different drawers I should have for clothes?<|endoftext|>\n<|assistant|>\nUtensils!<|endoftext|>\n" # noqa assert prepared["text"] == desired_text + def test_prepare_dialogue_from_tokenizer_messages_ift(self): + example = {} + example["messages"] = [ + {"role": "user", "content": "Who are you?"}, + {"role": "assistant", "content": "I am a bot."}, + ] + example["prompt"] = "Who are you?" + prepared = prepare_dialogue_from_tokenizer(example, self.tokenizer, ift=True) + desired_text = "<|user|>\nWho are you?<|endoftext|>\n<|assistant|>\nI am a bot.<|endoftext|>\n" + assert prepared["text"] == desired_text + def test_prepare_dialogue_single_turn(self): example = {} example["prompt"] = "What are different drawers I should have for clothes?" @@ -126,6 +137,17 @@ def test_prepare_dialogue_ift(self): desired_text = "<|user|>\nWhat are different drawers I should have for clothes?\n<|assistant|>\nUtensils!\n" assert prepared["text"] == desired_text + def test_prepare_dialogue_messages_ift(self): + example = {} + example["messages"] = [ + {"role": "user", "content": "Who are you?"}, + {"role": "assistant", "content": "I am a bot."}, + ] + example["prompt"] = "Who are you?" 
+ prepared = prepare_dialogue(example, self.conv, ift=True) + desired_text = "<|user|>\nWho are you?\n<|assistant|>\nI am a bot.\n" + assert prepared["text"] == desired_text + class DatasetTest(unittest.TestCase): def test_core_dataset_lens(self): diff --git a/tests/test_package.py b/tests/test_package.py index 80f01423..4d6a6a39 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -18,7 +18,7 @@ from fastchat.conversation import get_conv_template from transformers import AutoTokenizer -from rewardbench import load_preference_dataset +from rewardbench import load_and_process_dataset class LoadAnyDataTest(unittest.TestCase): @@ -31,15 +31,21 @@ def setUp(self): self.conv = get_conv_template("tulu") def test_load_standard_tokenizer(self): - load_preference_dataset( + load_and_process_dataset( "allenai/ultrafeedback_binarized_cleaned", split="test_prefs", tokenizer=self.tokenizer ) def test_load_standard_conv(self): - load_preference_dataset("allenai/ultrafeedback_binarized_cleaned", split="test_prefs", conv=self.conv) + load_and_process_dataset("allenai/ultrafeedback_binarized_cleaned", split="test_prefs", conv=self.conv) def test_load_alt_tokenizer(self): - load_preference_dataset("allenai/preference-test-sets", split="shp", tokenizer=self.tokenizer) + load_and_process_dataset("allenai/preference-test-sets", split="shp", tokenizer=self.tokenizer) def test_load_alt_conv(self): - load_preference_dataset("allenai/preference-test-sets", split="shp", conv=self.conv) + load_and_process_dataset("allenai/preference-test-sets", split="shp", conv=self.conv) + + def test_load_sft_tokenizer(self): + load_and_process_dataset("HuggingFaceH4/no_robots", split="test", tokenizer=self.tokenizer) + + def test_load_sft_conv(self): + load_and_process_dataset("HuggingFaceH4/no_robots", split="test", conv=self.conv)
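For readers of this patch, the new dataset-type detection in `load_and_process_dataset` reduces to a column check on the loaded `datasets.Dataset`. A minimal standalone sketch of that logic (the helper name `detect_dataset_type` and the toy data are illustrative, not part of the patch):

```python
from datasets import Dataset


def detect_dataset_type(dataset: Dataset, prioritize_instructions: bool = False) -> str:
    """Sketch of the column check used to pick preference vs. instruction handling."""
    has_preference = "chosen" in dataset.column_names and "rejected" in dataset.column_names
    has_instructions = "messages" in dataset.column_names

    if prioritize_instructions and has_instructions:
        return "instruction"  # only model outputs are logged, no accuracy
    if has_preference:
        return "preference"  # chosen vs. rejected scores, accuracy reported
    if has_instructions:
        return "instruction"
    raise ValueError("Expected either chosen/rejected columns or a messages column.")


# Toy instruction-style dataset (no chosen/rejected, has messages)
toy = Dataset.from_dict({"messages": [[{"role": "user", "content": "Who are you?"}]]})
assert detect_dataset_type(toy) == "instruction"
```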
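Similarly, the per-prompt outputs that `save_jsonl` writes (and that `--push_results_to_hub` mirrors to the Hub) are ordinary JSON lines, so downstream use only needs standard loading. A short sketch, assuming a hypothetical local path and Hub repo id in place of the ones the CLI actually generates:

```python
import json

from datasets import load_dataset

# Read the local JSONL written by the CLI (one record per prompt).
with open("results/my-model_outputs.jsonl") as f:  # hypothetical output path
    records = [json.loads(line) for line in f]

# Preference runs store 1/0 correctness in "results" plus chosen/rejected scores.
if records and "scores_chosen" in records[0]:
    accuracy = sum(r["results"] for r in records) / len(records)
    print(f"accuracy over {len(records)} prompts: {accuracy:.3f}")

# The same table can be pulled back from the Hub after --push_results_to_hub.
remote = load_dataset("your-entity/rewardbench_eval_XXXX", split="train")  # hypothetical repo id
print(remote.column_names)
```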
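And since `--upload_model_metadata_to_hf` stores accuracy as `EvalResult` entries on the model card, the uploaded metadata can be read back with `huggingface_hub` (shown here with the example model linked in the README above):

```python
from huggingface_hub import ModelCard

# Load the card of a model whose metadata was updated via --upload_model_metadata_to_hf.
card = ModelCard.load("vwxyzjn/rm_zephyr_new")

# eval_results holds the EvalResult entries written by the CLI
# (overall accuracy, plus per-section/per-subset results when available).
for result in card.data.eval_results or []:
    print(result.dataset_name, result.metric_type, result.metric_value)
```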