From cca53fd01ffdac52d197e8bc12f4339a0a295a33 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 29 Sep 2025 10:23:19 -0700 Subject: [PATCH 1/5] updated filtering script whoops remove things --- .../filter_existing_dataset_correctness.py | 72 ++++++++++++++++--- 1 file changed, 61 insertions(+), 11 deletions(-) diff --git a/scripts/data/rlvr/filter_existing_dataset_correctness.py b/scripts/data/rlvr/filter_existing_dataset_correctness.py index 65359176c..7303a5cb1 100644 --- a/scripts/data/rlvr/filter_existing_dataset_correctness.py +++ b/scripts/data/rlvr/filter_existing_dataset_correctness.py @@ -4,12 +4,10 @@ requires reward functions setup. we use multiprocessing to make things actually fast. -to run: -python scripts/data/rlvr/filter_existing_dataset_correctness.py \ - --files data/*.jsonl --output_file filtered.jsonl - If you have code data, you might have to launch code server too before running: source configs/beaker_configs/code_api_setup.sh + +You might have to explicitly install nginx first: sudo apt-get update && apt-get install -y --no-install-recommends nginx """ import os import argparse @@ -24,17 +22,22 @@ from open_instruct.ground_truth_utils import build_all_verifiers -def _avg_correctness(sample, reward_fn_mapping): +def _avg_correctness(sample, reward_fn_mapping, judge_override=None): """ Compute the mean correctness for one sample (called in worker). """ dataset = sample["dataset"][0] if isinstance(sample["dataset"], list) else sample["dataset"] gt = sample["ground_truth"][0] if isinstance(sample["ground_truth"], list) else sample["ground_truth"] - outputs = sample["output"] + outputs = sample["output"] if "output" in sample else sample["outputs"] - reward_fn = reward_fn_mapping[dataset] - scores = [reward_fn(None, o, gt).score for o in outputs] - return sum(scores) / len(scores) if scores else 0.0 + query = None + if "messages" in sample: + query = "\n".join(f"{msg['role']}: {msg['content']}" for msg in sample["messages"]) + + key = judge_override if judge_override is not None else dataset + reward_fn = reward_fn_mapping[key] + scores = [reward_fn(None, o, gt, query).score for o in outputs] + return sum(scores) / len(scores) if scores else 0.0, len(scores) def load_samples(files): @@ -162,14 +165,48 @@ def main(): default="train", help="Split to use on upload" ) + parser.add_argument( + "--annotate_original_dataset", + type=str, + default=None, + help="If set, annotate the original dataset with the passrates, and save to this path." + ) + parser.add_argument( + "--judge_override", + type=str, + default=None, + help=( + "If set, use this judge/verifier for all samples instead of the dataset-provided one. " + "Accepts keys from build_all_verifiers(), e.g. 'math', 'string_f1', 'code', or 'general-quality'. " + "For LLM judges, you may also pass just the judge type like 'quality' which will map to 'general-quality'." + ), + ) args = parser.parse_args() if args.lower_bound == 0 and args.upper_bound == 1: print("Upper bound is 1 and lower bound is 0. No filtering will be done, is this intended?") reward_fn_mapping = build_all_verifiers(args) + # Resolve judge override if provided + override_key = None + if args.judge_override is not None: + candidate = args.judge_override.lower() + if candidate not in reward_fn_mapping and f"general-{candidate}" in reward_fn_mapping: + candidate = f"general-{candidate}" + if candidate not in reward_fn_mapping: + raise ValueError( + f"Judge override '{args.judge_override}' not found in available verifiers. 
" + f"Try one of: {', '.join(sorted(reward_fn_mapping.keys()))}" + ) + override_key = candidate + print(f"Using judge override: {override_key}") + # currying the avg_correctness function - avg_correctness = partial(_avg_correctness, reward_fn_mapping=reward_fn_mapping) + avg_correctness = partial( + _avg_correctness, + reward_fn_mapping=reward_fn_mapping, + judge_override=override_key, + ) # Prefer 'spawn' for better safety on macOS / Jupyter try: @@ -184,13 +221,16 @@ def main(): chunk_size = 1 # Tune for workload size with Pool(processes=workers) as pool: - avg_scores = list( + results = list( tqdm( pool.imap(avg_correctness, samples, chunksize=chunk_size), total=len(samples), desc="Scoring" ) ) + # results is a list of tuples: (avg_score, num_rollouts) + avg_scores = [score for score, _ in results] + num_rollouts = [n for _, n in results] # Simple diagnostic plot plt.hist(avg_scores, bins=100) @@ -216,6 +256,16 @@ def main(): for sample in filtered_samples: f.write(json.dumps(sample) + "\n") + # Annotate the original dataset with the passrates if requested, and save. + if args.annotate_original_dataset is not None: + for sample, num_r, score in zip(samples, num_rollouts, avg_scores): + sample["total_rollouts"] = num_r + sample["total_correct_rollouts"] = score * num_r + sample["passrate"] = score + with open(args.annotate_original_dataset, "w") as f: + for sample in samples: + f.write(json.dumps(sample) + "\n") + if args.push_to_hub is not None: dataset = Dataset.from_list(filtered_samples) dataset.push_to_hub(args.push_to_hub) From fa35a1e9f5a7350665ca826a141e893473ef4dd3 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 29 Sep 2025 10:25:18 -0700 Subject: [PATCH 2/5] add command --- scripts/data/rlvr/filter_existing_dataset_correctness.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/data/rlvr/filter_existing_dataset_correctness.py b/scripts/data/rlvr/filter_existing_dataset_correctness.py index 7303a5cb1..c79c58f7e 100644 --- a/scripts/data/rlvr/filter_existing_dataset_correctness.py +++ b/scripts/data/rlvr/filter_existing_dataset_correctness.py @@ -4,6 +4,10 @@ requires reward functions setup. we use multiprocessing to make things actually fast. 
+to run: +python scripts/data/rlvr/filter_existing_dataset_correctness.py \ + --files data/*.jsonl --output_file filtered.jsonl + If you have code data, you might have to launch code server too before running: source configs/beaker_configs/code_api_setup.sh From 007e306882120f8f8d9f3b56b7b1410f794c0eea Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Tue, 7 Oct 2025 09:06:54 -0400 Subject: [PATCH 3/5] address comments --- .../filter_existing_dataset_correctness.py | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/scripts/data/rlvr/filter_existing_dataset_correctness.py b/scripts/data/rlvr/filter_existing_dataset_correctness.py index c79c58f7e..456689746 100644 --- a/scripts/data/rlvr/filter_existing_dataset_correctness.py +++ b/scripts/data/rlvr/filter_existing_dataset_correctness.py @@ -23,12 +23,22 @@ import matplotlib.pyplot as plt from tqdm import tqdm -from open_instruct.ground_truth_utils import build_all_verifiers +from open_instruct.ground_truth_utils import build_all_verifiers, VerifierFunction +from open_instruct import logger_utils +logger = logger_utils.setup_logger(__name__) -def _avg_correctness(sample, reward_fn_mapping, judge_override=None): + +def _avg_correctness(sample: dict, reward_fn_mapping: dict[str, VerifierFunction], judge_override: str | None = None): """ Compute the mean correctness for one sample (called in worker). + Args: + sample: The sample to compute the correctness for. Should have "dataset", "ground_truth", and "output" keys. Output should be a list of strings (list of completions for the sample). + reward_fn_mapping: The reward function mapping. Should be a dictionary of verifier names to verifier functions objects. + judge_override: If specified, use this judge/verifier for all samples instead of the dataset-provided one. + Returns: + The average score of outputs as judged by the verifier function. If there are no outputs, return 0.0. + The number of outputs. """ dataset = sample["dataset"][0] if isinstance(sample["dataset"], list) else sample["dataset"] gt = sample["ground_truth"][0] if isinstance(sample["ground_truth"], list) else sample["ground_truth"] @@ -187,7 +197,7 @@ def main(): ) args = parser.parse_args() if args.lower_bound == 0 and args.upper_bound == 1: - print("Upper bound is 1 and lower bound is 0. No filtering will be done, is this intended?") + logger.warning("Upper bound is 1 and lower bound is 0. 
No filtering will be done, is this intended?") reward_fn_mapping = build_all_verifiers(args) @@ -203,7 +213,7 @@ def main(): f"Try one of: {', '.join(sorted(reward_fn_mapping.keys()))}" ) override_key = candidate - print(f"Using judge override: {override_key}") + logger.info(f"Using judge override: {override_key}") # currying the avg_correctness function avg_correctness = partial( @@ -225,12 +235,10 @@ def main(): chunk_size = 1 # Tune for workload size with Pool(processes=workers) as pool: - results = list( - tqdm( - pool.imap(avg_correctness, samples, chunksize=chunk_size), - total=len(samples), - desc="Scoring" - ) + results = tqdm( + pool.imap(avg_correctness, samples, chunksize=chunk_size), + total=len(samples), + desc="Scoring" ) # results is a list of tuples: (avg_score, num_rollouts) avg_scores = [score for score, _ in results] @@ -250,7 +258,7 @@ def main(): sample for sample, score in zip(samples, avg_scores) if lower_bound <= score <= upper_bound ] - print( + logger.info( f"Filtered {len(samples) - len(filtered_samples)} samples out of {len(samples)}" ) From 14b5a21f866d95dd4a994a949e919d4d5a794e46 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Tue, 7 Oct 2025 09:10:37 -0400 Subject: [PATCH 4/5] need list --- .../data/rlvr/filter_existing_dataset_correctness.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scripts/data/rlvr/filter_existing_dataset_correctness.py b/scripts/data/rlvr/filter_existing_dataset_correctness.py index 456689746..6c524d12f 100644 --- a/scripts/data/rlvr/filter_existing_dataset_correctness.py +++ b/scripts/data/rlvr/filter_existing_dataset_correctness.py @@ -235,10 +235,12 @@ def main(): chunk_size = 1 # Tune for workload size with Pool(processes=workers) as pool: - results = tqdm( - pool.imap(avg_correctness, samples, chunksize=chunk_size), - total=len(samples), - desc="Scoring" + results = list( + tqdm( + pool.imap(avg_correctness, samples, chunksize=chunk_size), + total=len(samples), + desc="Scoring" + ) ) # results is a list of tuples: (avg_score, num_rollouts) avg_scores = [score for score, _ in results] From ae6de99acb23a1bbd240601be5fe65838f1d9d7e Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Tue, 7 Oct 2025 10:13:14 -0400 Subject: [PATCH 5/5] return type --- scripts/data/rlvr/filter_existing_dataset_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/data/rlvr/filter_existing_dataset_correctness.py b/scripts/data/rlvr/filter_existing_dataset_correctness.py index 6c524d12f..94c257f08 100644 --- a/scripts/data/rlvr/filter_existing_dataset_correctness.py +++ b/scripts/data/rlvr/filter_existing_dataset_correctness.py @@ -29,7 +29,7 @@ logger = logger_utils.setup_logger(__name__) -def _avg_correctness(sample: dict, reward_fn_mapping: dict[str, VerifierFunction], judge_override: str | None = None): +def _avg_correctness(sample: dict, reward_fn_mapping: dict[str, VerifierFunction], judge_override: str | None = None) -> tuple[float, int]: """ Compute the mean correctness for one sample (called in worker). Args:
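
A few notes on how the pieces of this series fit together. Patch 1 makes _avg_correctness return a (avg_score, num_rollouts) tuple, and the new --annotate_original_dataset option writes total_rollouts, total_correct_rollouts, and passrate back onto each sample. The sketch below is a minimal illustration of that bookkeeping; exact_match_score is a hypothetical stand-in for the verifiers built by build_all_verifiers(), not the project's actual reward functions.

    def exact_match_score(output: str, ground_truth: str) -> float:
        # Hypothetical stand-in verifier: 1.0 on exact string match, else 0.0.
        return 1.0 if output.strip() == ground_truth.strip() else 0.0


    def avg_correctness(sample: dict) -> tuple[float, int]:
        # Mirrors the patched _avg_correctness return type:
        # (average score over rollouts, number of rollouts).
        outputs = sample["output"] if "output" in sample else sample["outputs"]
        gt = sample["ground_truth"][0] if isinstance(sample["ground_truth"], list) else sample["ground_truth"]
        scores = [exact_match_score(o, gt) for o in outputs]
        return (sum(scores) / len(scores) if scores else 0.0, len(scores))


    sample = {"ground_truth": "42", "output": ["42", "41", "42", "42"]}
    score, n = avg_correctness(sample)
    sample["total_rollouts"] = n                  # 4
    sample["total_correct_rollouts"] = score * n  # 3.0
    sample["passrate"] = score                    # 0.75
    print(sample)

The annotated fields are redundant by construction (total_correct_rollouts = passrate * total_rollouts), which makes the annotated dataset easy to sanity-check after the fact.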
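
The --judge_override flag added in patch 1 accepts either a full verifier key or a bare LLM-judge type such as 'quality'. The resolution logic can be exercised on its own; the sketch below reproduces it against a toy dictionary that stands in for the real registry returned by build_all_verifiers().

    def resolve_judge_override(override: str, reward_fn_mapping: dict) -> str:
        # Mirrors the resolution in patch 1: try the key as given, then fall
        # back to the 'general-' prefixed form before giving up.
        candidate = override.lower()
        if candidate not in reward_fn_mapping and f"general-{candidate}" in reward_fn_mapping:
            candidate = f"general-{candidate}"
        if candidate not in reward_fn_mapping:
            raise ValueError(
                f"Judge override '{override}' not found in available verifiers. "
                f"Try one of: {', '.join(sorted(reward_fn_mapping.keys()))}"
            )
        return candidate


    toy_verifiers = {"math": None, "code": None, "general-quality": None}
    print(resolve_judge_override("quality", toy_verifiers))  # -> general-quality
    print(resolve_judge_override("MATH", toy_verifiers))     # -> math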
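
Patch 4 restores the list(...) wrapper that patch 3 dropped around tqdm(pool.imap(...)): pool.imap returns a lazy iterator, and the two downstream list comprehensions need a materialised list, since here iteration would otherwise only start after the pool context has already exited and the second pass would see an exhausted iterator. A minimal, self-contained sketch of the working pattern, using a toy score_sample worker and assuming tqdm is installed:

    from multiprocessing import Pool, cpu_count

    from tqdm import tqdm


    def score_sample(x: int) -> tuple[float, int]:
        # Stand-in for avg_correctness: pretend every sample has 10 rollouts.
        return (x / 10.0, 10)


    if __name__ == "__main__":
        samples = list(range(8))
        workers = min(cpu_count(), 4)
        with Pool(processes=workers) as pool:
            # list() forces evaluation inside the pool context; imap alone is lazy.
            results = list(
                tqdm(pool.imap(score_sample, samples, chunksize=1), total=len(samples), desc="Scoring")
            )
        avg_scores = [score for score, _ in results]
        num_rollouts = [n for _, n in results]
        print(avg_scores, num_rollouts)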