Log validation generations to wandb
corbt committed Jan 31, 2025
1 parent 679798c commit d68d481
Showing 2 changed files with 77 additions and 2 deletions.
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -158,6 +158,7 @@ trainer:
  project_name: verl_examples
  experiment_name: gsm8k
  logger: [ 'console', 'wandb' ]
  val_generations_to_log_to_wandb: 0
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
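With the default of 0 the table is never logged; any positive value logs that many validation samples per validation pass. A hedged example of the same trainer block with the feature switched on (the value here is illustrative, not a default from this commit):

trainer:
  project_name: verl_examples
  experiment_name: gsm8k
  logger: [ 'console', 'wandb' ]
  # log 10 (input, output, score) triples to the wandb table each validation
  val_generations_to_log_to_wandb: 10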
78 changes: 76 additions & 2 deletions verl/trainer/ppo/ray_trainer.py
@@ -394,17 +394,76 @@ def _create_dataloader(self):
        self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
        self.config.critic.optim.total_training_steps = total_training_steps

    def _maybe_log_val_generations_to_wandb(self, inputs, outputs, scores):
        """Log a table of validation samples to wandb"""

        generations_to_log = self.config.trainer.val_generations_to_log_to_wandb

        if generations_to_log == 0:
            return

        import wandb
        import numpy as np

        # Create tuples of (input, output, score) and sort by input text
        samples = list(zip(inputs, outputs, scores))
        samples.sort(key=lambda x: x[0])  # Sort by input text

        # Use fixed random seed for deterministic shuffling
        rng = np.random.RandomState(42)
        rng.shuffle(samples)

        # Take first N samples after shuffling
        samples = samples[:generations_to_log]

        # Create column names for all samples
        columns = ["step"] + sum([[
            f"input_{i+1}",
            f"output_{i+1}",
            f"score_{i+1}"
        ] for i in range(len(samples))], [])

        if not hasattr(self, 'validation_table'):
            # Initialize the table on first call
            self.validation_table = wandb.Table(columns=columns)

        # Create a new table with same columns and existing data
        # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
        new_table = wandb.Table(columns=columns, data=self.validation_table.data)

        # Add new row with all data
        row_data = []
        row_data.append(self.global_steps)
        for sample in samples:
            row_data.extend(sample)

        new_table.add_data(*row_data)

        # Update reference and log
        wandb.log({"val/generations": new_table}, step=self.global_steps)
        self.validation_table = new_table
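The issue linked in the comment is that wandb does not support appending to a Table that has already been logged, so the method rebuilds a fresh Table from the accumulated `.data` on every call and re-logs it. A minimal standalone sketch of that rebuild pattern outside the trainer (assuming an active wandb run; `log_row` and its row layout are illustrative names, not verl API):

import wandb

def log_row(prev_table, step, row):
    # Illustrative sketch of the rebuild-and-relog workaround above.
    columns = ["step"] + [f"col_{i}" for i in range(len(row))]
    if prev_table is None:
        prev_table = wandb.Table(columns=columns)
    # A logged Table is frozen, so copy its accumulated rows into a
    # fresh Table, append the new row, then log the fresh object.
    table = wandb.Table(columns=columns, data=prev_table.data)
    table.add_data(step, *row)
    wandb.log({"val/generations": table}, step=step)
    return table  # caller keeps this as the new accumulator

Each call re-uploads the full history, which stays cheap at the small sample counts this option is meant for.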

    def _validate(self):
        reward_tensor_lst = []
        data_source_lst = []

        # Lists to collect samples for the table
        sample_inputs = []
        sample_outputs = []
        sample_scores = []

        for test_data in self.val_dataloader:
            test_batch = DataProto.from_single_dict(test_data)
            # test_batch = test_batch.to('cuda')

            # we only do validation on rule-based rm
            if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
                return {}

            # Store original inputs
            input_ids = test_batch.batch['input_ids']
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

            test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids'])
            test_gen_batch.meta_info = {
                'eos_token_id': self.tokenizer.eos_token_id,
@@ -421,17 +480,32 @@ def _validate(self):
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
            print('validation generation end')

            # Store generated outputs
            output_ids = test_output_gen_batch.batch['responses']
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
            sample_outputs.extend(output_texts)

            test_batch = test_batch.union(test_output_gen_batch)

            # evaluate using reward_function
            # for certain reward function (e.g. sandbox), the generation can overlap with reward
            reward_tensor = self.val_reward_fn(test_batch)

            # Store scores
            scores = reward_tensor.sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            reward_tensor_lst.append(reward_tensor)
            data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))

        self._maybe_log_val_generations_to_wandb(
            inputs=sample_inputs,
            outputs=sample_outputs,
            scores=sample_scores
        )

        reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu()  # (batch_size,)
        data_sources = np.concatenate(data_source_lst, axis=0)

        # evaluate test_score based on data source
        data_source_reward = {}
        for i in range(reward_tensor.shape[0]):
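The diff view collapses the remainder of _validate at this point. Judging from the data_source_reward dict and the loop header above, the hidden tail most likely groups per-sample rewards by data source and reports a mean per source; the sketch below is written under that assumption and is not the collapsed code itself:

        # Hypothetical continuation of the collapsed loop above
        for i in range(reward_tensor.shape[0]):
            data_source = data_sources[i]
            data_source_reward.setdefault(data_source, []).append(reward_tensor[i].item())

        metric_dict = {}
        for data_source, rewards in data_source_reward.items():
            metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards)
        return metric_dict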
