Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,14 @@ class GRPOConfig(TrainingArguments):
"all prompts are logged."
},
)
# `Optional[list[str]]` replaces the redundant `Optional[Union[None, list[str]]]`:
# `Optional[X]` already means `Union[X, None]`, so wrapping `Union[None, ...]` inside
# it duplicated the None member and obscured the intended type.
wandb_log_extra_columns: Optional[list[str]] = field(
    default=None,
    metadata={
        "help": "Extra dataset columns to include in W&B logging. If `None` (default), no extra columns are "
        "logged. If `[]` (empty list), all available dataset columns are logged. If `['col1', 'col2']`, only "
        "the specified columns are logged. Columns that don't exist in the dataset will be ignored."
    },
)

def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
Expand Down
59 changes: 59 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def __init__(
self._total_train_tokens = 0
self.log_completions = args.log_completions
self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
self.wandb_log_extra_columns = args.wandb_log_extra_columns
self.num_completions_to_print = args.num_completions_to_print
# Keep logs sized to the generation batch to record only outputs from the latest model update.
self._logs = {
Expand All @@ -469,6 +470,7 @@ def __init__(
"completion": deque(maxlen=args.generation_batch_size),
"rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
"advantages": deque(maxlen=args.generation_batch_size),
"extra_columns": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
}

# Ensure each process receives a unique seed to prevent duplicate completions when generating with
Expand Down Expand Up @@ -1067,6 +1069,22 @@ def _generate_and_score_completions(

prompts = [x["prompt"] for x in inputs]

# Collect extra columns if configured
extra_columns_data = {}
if self.wandb_log_extra_columns is not None:
# Determine which columns to collect
if self.wandb_log_extra_columns == []:
# Empty list means collect all available columns except 'prompt' and 'image'
all_keys = set(inputs[0].keys()) if inputs else set()
columns_to_collect = all_keys - {"prompt", "image"}
else:
# Specific list of columns
columns_to_collect = set(self.wandb_log_extra_columns)

# Collect the data from each input
for col in columns_to_collect:
extra_columns_data[col] = [x.get(col) for x in inputs]

# We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
# later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
# VLM chat template.
Expand Down Expand Up @@ -1157,6 +1175,12 @@ def _generate_and_score_completions(
if has_images:
all_images = gather_object(images)

# Gather extra columns data for vLLM server mode
all_extra_columns_data = {}
if self.wandb_log_extra_columns is not None and extra_columns_data:
for col, values in extra_columns_data.items():
all_extra_columns_data[col] = gather_object(values)

if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
Expand Down Expand Up @@ -1198,6 +1222,12 @@ def _generate_and_score_completions(
completion_ids = completion_ids[process_slice]
all_logprobs = all_logprobs[process_slice]

# Slice extra columns data back to process-specific portion
if self.wandb_log_extra_columns is not None and all_extra_columns_data:
extra_columns_data = {}
for col, all_values in all_extra_columns_data.items():
extra_columns_data[col] = all_values[process_slice]

# Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
elif self.vllm_mode == "colocate":
if self.guided_decoding_regex:
Expand Down Expand Up @@ -1234,9 +1264,18 @@ def _generate_and_score_completions(
all_images = [img for sublist in gathered_images for img in sublist]
else:
all_images = None

# Gather extra columns for tensor parallel
all_extra_columns_data = {}
if self.wandb_log_extra_columns is not None and extra_columns_data:
for col, values in extra_columns_data.items():
gathered_col_values = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_col_values, values, group=self.tp_group)
all_extra_columns_data[col] = [v for sublist in gathered_col_values for v in sublist]
else:
all_prompts_text = prompts_text
all_images = images if has_images else None
all_extra_columns_data = extra_columns_data # No gathering needed for single GPU

if has_images and all_images:
vllm_inputs = []
Expand Down Expand Up @@ -1266,6 +1305,14 @@ def _generate_and_score_completions(
completion_ids = completion_ids[tp_slice]
all_logprobs = all_logprobs[tp_slice]

# Slice extra columns back to this rank's portion
if self.wandb_log_extra_columns is not None and all_extra_columns_data:
extra_columns_data = {}
for col, all_values in all_extra_columns_data.items():
extra_columns_data[col] = all_values[tp_slice]
else:
extra_columns_data = all_extra_columns_data

if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=1)

Expand Down Expand Up @@ -1512,6 +1559,11 @@ def _generate_and_score_completions(
if has_images:
self._logs["image"].extend(gather_object(images))

# Log extra columns if configured
if self.wandb_log_extra_columns is not None and extra_columns_data:
for col, values in extra_columns_data.items():
self._logs["extra_columns"][col].extend(gather_object(values))

if self.use_vllm and self.vllm_importance_sampling_correction:
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
delta = delta[completion_mask.bool()]
Expand Down Expand Up @@ -1801,6 +1853,13 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
else:
table["image"].append(None)

# Add extra columns to the table if configured
if self.wandb_log_extra_columns is not None and self._logs["extra_columns"]:
for col, values in self._logs["extra_columns"].items():
# Ensure the column has the same length as prompts
if values:
table[col] = list(values)

df = pd.DataFrame(table)
if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])
Expand Down