diff --git a/scripts/data/rlvr/filter_existing_dataset_correctness.py b/scripts/data/rlvr/filter_existing_dataset_correctness.py
index 65359176c..94c257f08 100644
--- a/scripts/data/rlvr/filter_existing_dataset_correctness.py
+++ b/scripts/data/rlvr/filter_existing_dataset_correctness.py
@@ -10,6 +10,8 @@
 If you have code data, you might have to launch code server too before running:
 
 source configs/beaker_configs/code_api_setup.sh
+
+You might have to explicitly install nginx first: sudo apt-get update && sudo apt-get install -y --no-install-recommends nginx
 """
 import os
 import argparse
@@ -21,20 +23,35 @@
 import matplotlib.pyplot as plt
 from tqdm import tqdm
 
-from open_instruct.ground_truth_utils import build_all_verifiers
+from open_instruct.ground_truth_utils import build_all_verifiers, VerifierFunction
+from open_instruct import logger_utils
+
+logger = logger_utils.setup_logger(__name__)
 
 
-def _avg_correctness(sample, reward_fn_mapping):
+def _avg_correctness(sample: dict, reward_fn_mapping: dict[str, VerifierFunction], judge_override: str | None = None) -> tuple[float, int]:
     """
     Compute the mean correctness for one sample (called in worker).
+    Args:
+        sample: The sample to compute the correctness for. Should have "dataset", "ground_truth", and "output" (or "outputs") keys. The outputs value should be a list of strings (the completions for the sample).
+        reward_fn_mapping: The reward function mapping. Should be a dictionary mapping verifier names to verifier function objects.
+        judge_override: If specified, use this judge/verifier for all samples instead of the dataset-provided one.
+    Returns:
+        A tuple of the average score of the outputs as judged by the verifier function (0.0 if there are no outputs)
+        and the number of outputs.
     """
     dataset = sample["dataset"][0] if isinstance(sample["dataset"], list) else sample["dataset"]
     gt = sample["ground_truth"][0] if isinstance(sample["ground_truth"], list) else sample["ground_truth"]
-    outputs = sample["output"]
+    outputs = sample["output"] if "output" in sample else sample["outputs"]
 
-    reward_fn = reward_fn_mapping[dataset]
-    scores = [reward_fn(None, o, gt).score for o in outputs]
-    return sum(scores) / len(scores) if scores else 0.0
+    query = None
+    if "messages" in sample:
+        query = "\n".join(f"{msg['role']}: {msg['content']}" for msg in sample["messages"])
+
+    key = judge_override if judge_override is not None else dataset
+    reward_fn = reward_fn_mapping[key]
+    scores = [reward_fn(None, o, gt, query).score for o in outputs]
+    return sum(scores) / len(scores) if scores else 0.0, len(scores)
 
 
 def load_samples(files):
@@ -162,14 +179,48 @@ def main():
         default="train",
         help="Split to use on upload"
     )
+    parser.add_argument(
+        "--annotate_original_dataset",
+        type=str,
+        default=None,
+        help="If set, annotate the original dataset with the passrates, and save it to this path."
+    )
+    parser.add_argument(
+        "--judge_override",
+        type=str,
+        default=None,
+        help=(
+            "If set, use this judge/verifier for all samples instead of the dataset-provided one. "
+            "Accepts keys from build_all_verifiers(), e.g. 'math', 'string_f1', 'code', or 'general-quality'. "
+            "For LLM judges, you may also pass just the judge type like 'quality', which will map to 'general-quality'."
+        ),
+    )
     args = parser.parse_args()
     if args.lower_bound == 0 and args.upper_bound == 1:
-        print("Upper bound is 1 and lower bound is 0. No filtering will be done, is this intended?")
+        logger.warning("Upper bound is 1 and lower bound is 0. No filtering will be done, is this intended?")
 
     reward_fn_mapping = build_all_verifiers(args)
 
+    # Resolve judge override if provided
+    override_key = None
+    if args.judge_override is not None:
+        candidate = args.judge_override.lower()
+        if candidate not in reward_fn_mapping and f"general-{candidate}" in reward_fn_mapping:
+            candidate = f"general-{candidate}"
+        if candidate not in reward_fn_mapping:
+            raise ValueError(
+                f"Judge override '{args.judge_override}' not found in available verifiers. "
+                f"Try one of: {', '.join(sorted(reward_fn_mapping.keys()))}"
+            )
+        override_key = candidate
+        logger.info(f"Using judge override: {override_key}")
+
     # currying the avg_correctness function
-    avg_correctness = partial(_avg_correctness, reward_fn_mapping=reward_fn_mapping)
+    avg_correctness = partial(
+        _avg_correctness,
+        reward_fn_mapping=reward_fn_mapping,
+        judge_override=override_key,
+    )
 
     # Prefer 'spawn' for better safety on macOS / Jupyter
     try:
@@ -184,13 +235,16 @@ def main():
         chunk_size = 1  # Tune for workload size
 
     with Pool(processes=workers) as pool:
-        avg_scores = list(
+        results = list(
             tqdm(
                 pool.imap(avg_correctness, samples, chunksize=chunk_size),
                 total=len(samples),
                 desc="Scoring"
             )
         )
+    # results is a list of tuples: (avg_score, num_rollouts)
+    avg_scores = [score for score, _ in results]
+    num_rollouts = [n for _, n in results]
 
     # Simple diagnostic plot
     plt.hist(avg_scores, bins=100)
@@ -206,7 +260,7 @@ def main():
         sample for sample, score in zip(samples, avg_scores)
         if lower_bound <= score <= upper_bound
     ]
-    print(
+    logger.info(
         f"Filtered {len(samples) - len(filtered_samples)} samples out of {len(samples)}"
     )
 
@@ -216,6 +270,16 @@ def main():
         for sample in filtered_samples:
             f.write(json.dumps(sample) + "\n")
 
+    # Annotate the original dataset with the passrates if requested, and save.
+    if args.annotate_original_dataset is not None:
+        for sample, num_r, score in zip(samples, num_rollouts, avg_scores):
+            sample["total_rollouts"] = num_r
+            sample["total_correct_rollouts"] = score * num_r
+            sample["passrate"] = score
+        with open(args.annotate_original_dataset, "w") as f:
+            for sample in samples:
+                f.write(json.dumps(sample) + "\n")
+
     if args.push_to_hub is not None:
         dataset = Dataset.from_list(filtered_samples)
         dataset.push_to_hub(args.push_to_hub)
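Note: with --annotate_original_dataset set, every input sample is written back out with three extra fields ("total_rollouts", "total_correct_rollouts", "passrate"). A minimal sketch of consuming that file downstream; the path and the 0.0/1.0 threshold are illustrative and not part of this patch:

    import json

    # Path is whatever was passed to --annotate_original_dataset (hypothetical name here).
    with open("annotated_dataset.jsonl") as f:
        samples = [json.loads(line) for line in f]

    # Keep only prompts that are neither trivially easy nor never solved.
    kept = [s for s in samples if 0.0 < s["passrate"] < 1.0]
    print(f"kept {len(kept)} of {len(samples)} samples")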
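For reference, the --judge_override resolution added in main() promotes a bare LLM-judge type to its 'general-' prefixed key when only the prefixed key is registered. A self-contained sketch of that rule, re-implemented stand-alone for illustration rather than imported from the script:

    def resolve_override(requested: str, available: set[str]) -> str:
        # Mirror the lookup in main(): try the key as given, then 'general-<key>'.
        candidate = requested.lower()
        if candidate not in available and f"general-{candidate}" in available:
            candidate = f"general-{candidate}"
        if candidate not in available:
            raise ValueError(f"Judge override '{requested}' not found in available verifiers.")
        return candidate

    assert resolve_override("quality", {"math", "code", "general-quality"}) == "general-quality"
    assert resolve_override("math", {"math", "code", "general-quality"}) == "math"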