66 changes: 64 additions & 2 deletions open_instruct/grpo_fast.py
@@ -375,6 +375,8 @@ class Args:
"""Where to save the model"""
save_traces: bool = False
"""Whether to save learning data traces"""
error_on_all_filtered: bool = False
"""Whether to raise an error if all responses are filtered out during packing"""
cache_dataset_only: bool = False
"""Immediately exit after caching the dataset"""
keep_last_n_checkpoints: int = 3
@@ -1526,6 +1528,18 @@ def data_preparation_thread(
result.responses[i].append(tokenizer.eos_token_id)
result.masks[i].append(1) # never mask the eos token for now?

# Debug: preview the first 20 tokens of up to 10 responses before filtering
logger.info(
f"[debug] [Data Prep] Step {training_step}: Received {len(result.responses)} responses before filtering"
)
for i in range(min(10, len(result.responses))): # Show first 10 responses to avoid spam
response_tokens = result.responses[i][:20] # First 20 tokens
decoded_preview = tokenizer.decode(response_tokens, skip_special_tokens=False)
logger.info(
f"[debug] Response {i}: finish_reason='{result.finish_reasons[i]}', "
f"len={len(result.responses[i])}, preview: {decoded_preview!r}"
)

with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
decoded_queries = tokenizer.batch_decode(batch.queries, skip_special_tokens=True)
@@ -1572,14 +1586,41 @@ def data_preparation_thread(
real_batch_size_ratio = non_zero_std_mask.sum() * args.num_samples_per_prompt_rollout / len(scores)
expanded_mask = np.repeat(non_zero_std_mask, args.num_samples_per_prompt_rollout)
non_zero_gradient_index = np.where(expanded_mask)[0]

# Debug: Show examples of filtered responses
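# Zero std means every rollout for a prompt received the same score, so the GRPO advantages for those responses are all zero and carry no gradient.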
filtered_indices = np.where(~expanded_mask)[0]
if len(filtered_indices) > 0:
logger.info("[debug] Showing examples of filtered responses (zero std):")
# Show up to 5 filtered prompts
filtered_prompt_indices = np.where(~non_zero_std_mask)[0][:5]
for prompt_idx in filtered_prompt_indices:
# Get the responses for this prompt
start_idx = prompt_idx * args.num_samples_per_prompt_rollout
end_idx = start_idx + args.num_samples_per_prompt_rollout
prompt_scores = scores_per_prompt[prompt_idx]
prompt_responses = decoded_responses[start_idx:end_idx]  # use decoded text; result.responses holds token-id lists
logger.info(f" Prompt {prompt_idx}: scores={prompt_scores.tolist()}, std={prompt_scores.std():.6f}")
for i, resp in enumerate(prompt_responses[:3]): # Show first 3 responses
# Truncate response for display
resp_preview = resp[:100] + "..." if len(resp) > 100 else resp
logger.info(f" Response {i}: {repr(resp_preview)}")

advantages = advantages[non_zero_gradient_index]
original_batch_size = len(scores)
scores = scores[non_zero_gradient_index]
responses = [result.responses[i] for i in non_zero_gradient_index]
masks = [result.masks[i] for i in non_zero_gradient_index]
batch = batch[non_zero_gradient_index.tolist()]
finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index]

# Debug: Log filtering stats
logger.info(
f"[debug] [Data Prep] Step {training_step}: After zero-std filtering: "
f"{original_batch_size} -> {len(scores)} responses "
f"({original_batch_size - len(scores)} filtered due to zero std)"
)
if args.mask_truncated_completions:
before_truncate_filter = len(scores)
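# finish_reason == "stop" means generation ended at a stop token; other values (e.g. "length") indicate truncation.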
stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"])
scores = scores[stop_idxes]
advantages = advantages[stop_idxes]
@@ -1588,6 +1629,13 @@
batch = batch[stop_idxes.tolist()]
finish_reasons = [finish_reasons[i] for i in stop_idxes]

# Debug: Log truncated completion filtering
logger.info(
f"[debug] [Data Prep] Step {training_step}: After truncated completion filter: "
f"{before_truncate_filter} -> {len(scores)} responses "
f"({before_truncate_filter - len(scores)} filtered due to non-stop finish reason)"
)

if args.fill_completions:
with Timer("⏱ [Data Preparation Thread] Refill completions"):
current_batch_size = len(scores)
@@ -1787,6 +1835,11 @@ def data_preparation_thread(
f.write("\n")

if len(responses) == 0:
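# Every response in this step was removed by the zero-std and/or truncation filters above.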
if args.error_on_all_filtered:
raise ValueError(
f"All responses were filtered out in batch {training_step}. "
f"Set error_on_all_filtered=False to continue with a warning instead."
)
logger.warning(f"No responses in batch {training_step}.")

# Put the packed sequences and metrics into the output queue
Expand Down Expand Up @@ -2089,7 +2142,11 @@ def prepare_prompts(


def load_data_from_packing_thread(
packed_sequences_Q: Queue, num_total_tokens: int, stop_event: threading.Event, health_check_fn: Callable[[], None]
packed_sequences_Q: Queue,
num_total_tokens: int,
stop_event: threading.Event,
health_check_fn: Callable[[], None],
args: Args,
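# threaded through so error_on_all_filtered is visible in the main thread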
):
"""Get the packed sequences with advantages from the packing thread."""
with Timer("[Main Thread] 📦 Getting packed sequences from thread") as timer:
@@ -2111,6 +2168,11 @@

data_thread_metrics["time/trainer_idling"] = timer.duration
if B == 0:
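# B == 0 means the packing thread produced no trainable sequences for this step.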
if args.error_on_all_filtered:
raise ValueError(
"[Main Thread] After packing, there is not enough data to train. "
"All responses were filtered out. Set error_on_all_filtered=False to continue with a warning instead."
)
logger.warning("[Main Thread] 🤡 After packing, there is not enough data to train")
return None, data_thread_metrics, num_total_tokens
return collated_data, data_thread_metrics, num_total_tokens
@@ -2684,7 +2746,7 @@ def health_check_fn():

# The generate_thread is now handling vLLM processing asynchronously
collated_data, data_thread_metrics, num_total_tokens = load_data_from_packing_thread(
packed_sequences_Q, num_total_tokens, stop_event, health_check_fn
packed_sequences_Q, num_total_tokens, stop_event, health_check_fn, args
)
if collated_data is None:
continue
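
For readers skimming the diff, here is a minimal, self-contained sketch of the zero-std filter that the new debug logs and the `error_on_all_filtered` guard are concerned with. It mirrors the names used above (`scores`, `num_samples_per_prompt_rollout`) but is illustrative only, not the repo's exact code:

```python
import numpy as np

def filter_zero_std(scores: np.ndarray, num_samples_per_prompt_rollout: int,
                    error_on_all_filtered: bool = False) -> np.ndarray:
    """Return indices of responses whose prompt had non-uniform scores."""
    # One row per prompt, one column per rollout of that prompt.
    scores_per_prompt = scores.reshape(-1, num_samples_per_prompt_rollout)
    # Prompts where every rollout scored the same contribute zero advantage.
    non_zero_std_mask = scores_per_prompt.std(axis=1) != 0
    # Repeat the per-prompt keep/drop decision across that prompt's responses.
    expanded_mask = np.repeat(non_zero_std_mask, num_samples_per_prompt_rollout)
    keep = np.where(expanded_mask)[0]
    if len(keep) == 0 and error_on_all_filtered:
        raise ValueError("All responses were filtered out (zero std on every prompt).")
    return keep

# 2 prompts x 2 rollouts: the first prompt's rollouts tie, so it is dropped.
print(filter_zero_std(np.array([1.0, 1.0, 0.0, 1.0]), 2))  # -> [2 3]
```

With the flag left at its default of False, an all-filtered step only logs a warning and the training loop skips ahead, which is the pre-existing behavior this PR makes opt-out.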