diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 49282e0bac..18b3855fe7 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1546,6 +1546,87 @@ def test_vdpo_trainer(self, model_id): continue self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + def test_dpo_trainer_gemma3_vision_model_detection(self): + """CPU-only check that Gemma 3 routes via vision path and preserves pixel tensors.""" + + # Minimal Gemma 3-like model + class _MinimalConfig: + def __init__(self): + self.model_type = "gemma3" + self.is_encoder_decoder = False + self._name_or_path = "dummy" + + class _MinimalModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = _MinimalConfig() + self.warnings_issued = {} + + model = _MinimalModel() + + # Mock tokenizer with required token IDs to avoid external dependencies + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token_id = 0 + mock_tokenizer.eos_token_id = 1 + mock_tokenizer.bos_token_id = 2 + + def _tok_call(text, add_special_tokens=False): + return {"input_ids": [11, 22]} + + mock_tokenizer.side_effect = _tok_call + + # Mock processor that returns pixel_values to simulate vision processing + processor = MagicMock() + processor.tokenizer = mock_tokenizer + # Ensure DPOTrainer reads an integer pad_token_id + processor.pad_token_id = 0 + + def _proc_call(images=None, text=None, add_special_tokens=False): + return { + "input_ids": [[101, 102]], + "pixel_values": [np.zeros((3, 8, 8), dtype=np.float32)], + } + + processor.side_effect = _proc_call + + # Tiny dataset with one 16x16 image + img = Image.fromarray(np.zeros((16, 16, 3), dtype=np.uint8)) + ds = Dataset.from_list( + [ + { + "prompt": "Describe the image.", + "chosen": "Black square.", + "rejected": "White circle.", + "images": [img], + } + ] + ) + # Test-optimized config to avoid multiprocessing and reference model creation + args = DPOConfig( + output_dir=self.tmp_dir, + dataset_num_proc=1, + precompute_ref_log_probs=True, + gradient_checkpointing=False, + report_to="none", + ) + + trainer = DPOTrainer(model=model, args=args, processing_class=processor, train_dataset=ds) + + # 1) Model detected as vision-text + self.assertTrue(trainer.is_vision_model, "Expected Gemma 3 to be detected as vision-text model") + + # 2) Dataset processed via process_row (pixel_values present) + row = trainer.train_dataset[0] + self.assertIn("pixel_values", row, "process_row did not add pixel_values") + + # 3) Signature columns include vision fields + trainer._set_signature_columns_if_needed() + self.assertIn("pixel_values", trainer._signature_columns) + + # 4) Collator preserves pixel tensors + batch = trainer.data_collator([row]) + self.assertIn("pixel_values", batch, "pixel_values missing in collated batch") + if __name__ == "__main__": unittest.main() diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 2cacb26e3f..3892bddd26 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -48,7 +48,7 @@ is_mlflow_available, is_wandb_available, ) -from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES +from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_liger_kernel_available, is_peft_available @@ -330,7 +330,7 @@ def __init__( ) self.is_encoder_decoder = model.config.is_encoder_decoder - self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys() + self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) self.model_adapter_name = args.model_adapter_name self.ref_adapter_name = args.ref_adapter_name @@ -788,6 +788,8 @@ def _set_signature_columns_if_needed(self): "prompt_input_ids", "chosen_input_ids", "rejected_input_ids", + "pixel_values", + "pixel_attention_mask", "image_sizes", "ref_chosen_logps", "ref_rejected_logps",