diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 9c4f4b3a686..9cd43b7e3f2 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -635,6 +635,14 @@ class GRPOConfig(TrainingArguments): "all prompts are logged." }, ) + wandb_log_extra_columns: Optional[list[str]] = field( + default=None, + metadata={ + "help": "Extra dataset columns to include in W&B logging. If `None` (default), no extra columns are " + "logged. If `[]` (empty list), all available dataset columns are logged. If `['col1', 'col2']`, only " + "the specified columns are logged. Columns that don't exist in the dataset will be ignored." + }, + ) def __post_init__(self): self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b3e1c716cd1..e8c09475a28 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -461,6 +461,7 @@ def __init__( self._total_train_tokens = 0 self.log_completions = args.log_completions self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.wandb_log_extra_columns = args.wandb_log_extra_columns self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. 
self._logs = { @@ -469,6 +470,7 @@ def __init__( "completion": deque(maxlen=args.generation_batch_size), "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), "advantages": deque(maxlen=args.generation_batch_size), + "extra_columns": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), } # Ensure each process receives a unique seed to prevent duplicate completions when generating with @@ -1067,6 +1069,22 @@ def _generate_and_score_completions( prompts = [x["prompt"] for x in inputs] + # Collect extra columns if configured + extra_columns_data = {} + if self.wandb_log_extra_columns is not None: + # Determine which columns to collect + if self.wandb_log_extra_columns == []: + # Empty list means collect all available columns except 'prompt' and 'image' + all_keys = set(inputs[0].keys()) if inputs else set() + columns_to_collect = all_keys - {"prompt", "image"} + else: + # Specific list of columns; drop any not present in the dataset, as documented + columns_to_collect = set(self.wandb_log_extra_columns) & (set(inputs[0].keys()) if inputs else set()) + + # Collect the data from each input + for col in columns_to_collect: + extra_columns_data[col] = [x.get(col) for x in inputs] + # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the # VLM chat template. @@ -1157,6 +1175,12 @@ def _generate_and_score_completions( if has_images: all_images = gather_object(images) + # Gather extra columns data for vLLM server mode + all_extra_columns_data = {} + if self.wandb_log_extra_columns is not None and extra_columns_data: + for col, values in extra_columns_data.items(): + all_extra_columns_data[col] = gather_object(values) + if self.accelerator.is_main_process: # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate # num_generations outputs for each one. 
This is faster than generating outputs for each duplicate @@ -1198,6 +1222,12 @@ def _generate_and_score_completions( completion_ids = completion_ids[process_slice] all_logprobs = all_logprobs[process_slice] + # Slice extra columns data back to process-specific portion + if self.wandb_log_extra_columns is not None and all_extra_columns_data: + extra_columns_data = {} + for col, all_values in all_extra_columns_data.items(): + extra_columns_data[col] = all_values[process_slice] + # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": if self.guided_decoding_regex: @@ -1234,9 +1264,18 @@ def _generate_and_score_completions( all_images = [img for sublist in gathered_images for img in sublist] else: all_images = None + + # Gather extra columns for tensor parallel + all_extra_columns_data = {} + if self.wandb_log_extra_columns is not None and extra_columns_data: + for col, values in extra_columns_data.items(): + gathered_col_values = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_col_values, values, group=self.tp_group) + all_extra_columns_data[col] = [v for sublist in gathered_col_values for v in sublist] else: all_prompts_text = prompts_text all_images = images if has_images else None + all_extra_columns_data = extra_columns_data # No gathering needed for single GPU if has_images and all_images: vllm_inputs = [] @@ -1266,6 +1305,14 @@ def _generate_and_score_completions( completion_ids = completion_ids[tp_slice] all_logprobs = all_logprobs[tp_slice] + # Slice extra columns back to this rank's portion + if self.wandb_log_extra_columns is not None and all_extra_columns_data: + extra_columns_data = {} + for col, all_values in all_extra_columns_data.items(): + extra_columns_data[col] = all_values[tp_slice] + else: + extra_columns_data = all_extra_columns_data + if self.args.vllm_enable_sleep_mode: 
self.llm.sleep(level=1) @@ -1512,6 +1559,11 @@ def _generate_and_score_completions( if has_images: self._logs["image"].extend(gather_object(images)) + # Log extra columns if configured + if self.wandb_log_extra_columns is not None and extra_columns_data: + for col, values in extra_columns_data.items(): + self._logs["extra_columns"][col].extend(gather_object(values)) + if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) delta = delta[completion_mask.bool()] @@ -1801,6 +1853,13 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non else: table["image"].append(None) + # Add extra columns to the table if configured + if self.wandb_log_extra_columns is not None and self._logs["extra_columns"]: + for col, values in self._logs["extra_columns"].items(): + # Only add columns matching the table row count; pandas raises on length mismatch + if len(values) == len(table["prompt"]): + table[col] = list(values) + df = pd.DataFrame(table) if self.wandb_log_unique_prompts: df = df.drop_duplicates(subset=["prompt"])