Skip to content

Commit

Permalink
Push changes
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn authored Feb 19, 2025
1 parent 11521d1 commit 30a55df
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions open_instruct/grpo_vllm_thread_ray_gtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,7 @@ def vllm_generate(
ground_truths_next = data[GROUND_TRUTHS_KEY]
datasets_next = data[DATASET_SOURCE_KEY]
if accelerator.is_main_process:
learning_traces = []
param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id)))

# for _ in range(1, resume_training_step): # we didn't store scheduler state
Expand Down Expand Up @@ -1071,6 +1072,11 @@ def vllm_generate(
]
g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device)
g_vllm_responses[:] = g_padded_response_ids
current_reasnoning_traces = {
"queries": remove_padding(global_queries, tokenizer.pad_token_id),
"responses": g_response_token_ids,
}

dist.broadcast(g_vllm_responses, src=0)
local_vllm_responses = g_vllm_responses[
accelerator.process_index * queries.shape[0] : (accelerator.process_index + 1) * queries.shape[0]
Expand Down Expand Up @@ -1159,6 +1165,11 @@ def vllm_generate(
gc.collect()
torch.cuda.empty_cache()

global_scores = dist.all_gather(scores)
if accelerator.is_main_process:
current_reasnoning_traces["scores"] = global_scores
learning_traces.append(current_reasnoning_traces)

# Response Processing 3. Filter responses: ensure that each sample contains stop_token_id.
# Responses that do not pass this filter receive a low (fixed) score;
# only query humans on responses that pass the filter.
Expand Down Expand Up @@ -1340,6 +1351,9 @@ def vllm_generate(
self.save_model(self.model, step_dir)
if args.try_launch_beaker_eval_jobs_on_weka:
self.launch_ai2_evals_on_weka(step_dir, training_step)
# Save the accumulated `learning_traces` as JSON (`learning_traces.json`) inside `step_dir`.
with open(os.path.join(step_dir, "learning_traces.json"), "w") as f:
json.dump(learning_traces, f)
print(f"Saving final model at step {training_step} to {args.output_dir}")
self.save_model(self.model, args.output_dir)
if args.try_launch_beaker_eval_jobs_on_weka:
Expand Down

0 comments on commit 30a55df

Please sign in to comment.