Merged
84 changes: 74 additions & 10 deletions scripts/data/rlvr/filter_existing_dataset_correctness.py
@@ -10,6 +10,8 @@

If you have code data, you might have to launch the code server too before running:
source configs/beaker_configs/code_api_setup.sh

You might have to install nginx explicitly first: sudo apt-get update && sudo apt-get install -y --no-install-recommends nginx
"""
import os
import argparse
@@ -21,20 +23,35 @@
import matplotlib.pyplot as plt
from tqdm import tqdm

from open_instruct.ground_truth_utils import build_all_verifiers
from open_instruct.ground_truth_utils import build_all_verifiers, VerifierFunction
from open_instruct import logger_utils

logger = logger_utils.setup_logger(__name__)


def _avg_correctness(sample, reward_fn_mapping):
def _avg_correctness(sample: dict, reward_fn_mapping: dict[str, VerifierFunction], judge_override: str | None = None) -> tuple[float, int]:
"""
Compute the mean correctness for one sample (called in worker).
Args:
sample: The sample to compute the correctness for. Should have "dataset", "ground_truth", and "output" keys. Output should be a list of strings (list of completions for the sample).
reward_fn_mapping: The reward function mapping. Should be a dictionary of verifier names to verifier functions objects.
judge_override: If specified, use this judge/verifier for all samples instead of the dataset-provided one.
Returns:
The average score of outputs as judged by the verifier function. If there are no outputs, return 0.0.
The number of outputs.
"""
dataset = sample["dataset"][0] if isinstance(sample["dataset"], list) else sample["dataset"]
gt = sample["ground_truth"][0] if isinstance(sample["ground_truth"], list) else sample["ground_truth"]
outputs = sample["output"]
outputs = sample["output"] if "output" in sample else sample["outputs"]

reward_fn = reward_fn_mapping[dataset]
scores = [reward_fn(None, o, gt).score for o in outputs]
return sum(scores) / len(scores) if scores else 0.0
query = None
if "messages" in sample:
query = "\n".join(f"{msg['role']}: {msg['content']}" for msg in sample["messages"])

key = judge_override if judge_override is not None else dataset
reward_fn = reward_fn_mapping[key]
scores = [reward_fn(None, o, gt, query).score for o in outputs]
return sum(scores) / len(scores) if scores else 0.0, len(scores)
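For context, here is a minimal sketch of how _avg_correctness consumes a sample. The StubExactMatchVerifier below is hypothetical and stands in for the VerifierFunction objects returned by build_all_verifiers(); it only mimics the call convention used above, reward_fn(tokenized_prediction, prediction, ground_truth, query) returning an object with a .score field.

# Hypothetical stub verifier, for illustration only; real runs use the verifier
# objects from build_all_verifiers(). Assumes _avg_correctness (above) is in scope.
from dataclasses import dataclass


@dataclass
class _Result:
    score: float


class StubExactMatchVerifier:
    def __call__(self, tokenized_prediction, prediction, ground_truth, query=None):
        # 1.0 on exact string match, 0.0 otherwise.
        return _Result(score=float(prediction.strip() == ground_truth.strip()))


sample = {
    "dataset": ["stub_exact_match"],   # verifier key; may also be a plain string
    "ground_truth": "42",
    "output": ["42", "41", "42"],      # three sampled completions
    "messages": [{"role": "user", "content": "What is 6 * 7?"}],
}

avg, n = _avg_correctness(sample, {"stub_exact_match": StubExactMatchVerifier()})
# avg == 2/3, n == 3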


def load_samples(files):
@@ -162,14 +179,48 @@ def main():
default="train",
help="Split to use on upload"
)
parser.add_argument(
"--annotate_original_dataset",
type=str,
default=None,
help="If set, annotate the original dataset with the passrates, and save to this path."
)
parser.add_argument(
"--judge_override",
type=str,
default=None,
help=(
"If set, use this judge/verifier for all samples instead of the dataset-provided one. "
"Accepts keys from build_all_verifiers(), e.g. 'math', 'string_f1', 'code', or 'general-quality'. "
"For LLM judges, you may also pass just the judge type like 'quality' which will map to 'general-quality'."
),
)
args = parser.parse_args()
if args.lower_bound == 0 and args.upper_bound == 1:
print("Upper bound is 1 and lower bound is 0. No filtering will be done, is this intended?")
logger.warning("Upper bound is 1 and lower bound is 0. No filtering will be done; is this intended?")

reward_fn_mapping = build_all_verifiers(args)

# Resolve judge override if provided
override_key = None
if args.judge_override is not None:
candidate = args.judge_override.lower()
if candidate not in reward_fn_mapping and f"general-{candidate}" in reward_fn_mapping:
candidate = f"general-{candidate}"
if candidate not in reward_fn_mapping:
raise ValueError(
f"Judge override '{args.judge_override}' not found in available verifiers. "
f"Try one of: {', '.join(sorted(reward_fn_mapping.keys()))}"
)
override_key = candidate
logger.info(f"Using judge override: {override_key}")

# currying the avg_correctness function
avg_correctness = partial(_avg_correctness, reward_fn_mapping=reward_fn_mapping)
avg_correctness = partial(
_avg_correctness,
reward_fn_mapping=reward_fn_mapping,
judge_override=override_key,
)
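Pool.imap (used below) calls its function with exactly one argument per item, so the extra keyword arguments have to be pre-bound here; strictly speaking this is partial application rather than currying. Assuming a deterministic verifier, the two calls in this sketch are equivalent for any element of samples:

# Equivalent ways to score samples[0]; the pre-bound version is what the pool workers receive.
score, n = avg_correctness(samples[0])
score, n = _avg_correctness(samples[0], reward_fn_mapping=reward_fn_mapping, judge_override=override_key)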

# Prefer 'spawn' for better safety on macOS / Jupyter
try:
@@ -184,13 +235,16 @@
chunk_size = 1 # Tune for workload size

with Pool(processes=workers) as pool:
avg_scores = list(
results = list(
tqdm(
pool.imap(avg_correctness, samples, chunksize=chunk_size),
total=len(samples),
desc="Scoring"
)
)
# results is a list of tuples: (avg_score, num_rollouts)
avg_scores = [score for score, _ in results]
num_rollouts = [n for _, n in results]
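For reference, the try block above follows a common pattern, shown here as a generic sketch rather than the collapsed lines themselves: multiprocessing.set_start_method raises RuntimeError if a start method was already chosen, and 'spawn' avoids fork-related issues on macOS and in notebooks.

# Generic sketch of the spawn-selection pattern; illustrative, not the elided hunk.
import multiprocessing as mp

try:
    mp.set_start_method("spawn")
except RuntimeError:
    pass  # start method already set, e.g. when re-running in the same interpreter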

# Simple diagnostic plot
plt.hist(avg_scores, bins=100)
@@ -206,7 +260,7 @@ def main():
sample for sample, score in zip(samples, avg_scores)
if lower_bound <= score <= upper_bound
]
print(
logger.info(
f"Filtered {len(samples) - len(filtered_samples)} samples out of {len(samples)}"
)

@@ -216,6 +270,16 @@
for sample in filtered_samples:
f.write(json.dumps(sample) + "\n")

# Annotate the original dataset with the passrates if requested, and save.
if args.annotate_original_dataset is not None:
for sample, num_r, score in zip(samples, num_rollouts, avg_scores):
sample["total_rollouts"] = num_r
sample["total_correct_rollouts"] = score * num_r
sample["passrate"] = score
with open(args.annotate_original_dataset, "w") as f:
for sample in samples:
f.write(json.dumps(sample) + "\n")
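
Each line of the annotated JSONL then carries the three extra fields set above; a minimal sketch of reading the file back and selecting low pass-rate samples (the path and the 0.2 threshold are illustrative placeholders):

# Illustrative reader for the annotated JSONL written above; path and cutoff are placeholders.
import json

with open("annotated_dataset.jsonl") as f:
    annotated = [json.loads(line) for line in f]

hard = [s for s in annotated if s["passrate"] <= 0.2]
print(f"{len(hard)} of {len(annotated)} samples have passrate <= 0.2")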

if args.push_to_hub is not None:
dataset = Dataset.from_list(filtered_samples)
dataset.push_to_hub(args.push_to_hub)