diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 9c4f4b3a68..af8ea9fd5a 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -37,6 +37,9 @@ class GRPOConfig(TrainingArguments):
         model_init_kwargs (`str`, `dict[str, Any]`, *optional*):
             Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
             argument of the [`GRPOTrainer`] is provided as a string.
+        config_init_kwargs (`str`, `dict[str, Any]`, *optional*):
+            Keyword arguments for [`~transformers.AutoConfig.from_pretrained`], used when the `model`
+            argument of the [`GRPOTrainer`] is provided as a string.
         disable_dropout (`bool`, *optional*, defaults to `False`):
             Whether to disable dropout in the model. This is useful for training with a reference model, as it
             prevents the model from generating different logprobs for the same input.
@@ -242,7 +245,7 @@ class GRPOConfig(TrainingArguments):
             are logged.
     """
 
-    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
+    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "config_init_kwargs"]
 
     # Parameters whose default values are overridden from TrainingArguments
     learning_rate: float = field(
@@ -279,6 +282,13 @@ class GRPOConfig(TrainingArguments):
             "argument of the `GRPOTrainer` is provided as a string."
         },
     )
+    config_init_kwargs: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={
+            "help": "Keyword arguments for `transformers.AutoConfig.from_pretrained`, used when the `model` "
+            "argument of the `GRPOTrainer` is provided as a string."
+        },
+    )
     disable_dropout: bool = field(
         default=False,
         metadata={
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index b3e1c716cd..02ddd2ef2f 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -225,6 +225,7 @@ def __init__(
         # Models
         # Trained model
         model_init_kwargs = args.model_init_kwargs or {}
+        config_init_kwargs = args.config_init_kwargs or {}
         if isinstance(model, str):
             model_id = model
             dtype = model_init_kwargs.get("dtype")
@@ -239,7 +240,7 @@ def __init__(
                     f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
                 )
             # Disable caching if gradient checkpointing is enabled (not supported)
-            config = AutoConfig.from_pretrained(model_id)
+            config = AutoConfig.from_pretrained(model_id, **config_init_kwargs)
             architecture = getattr(transformers, config.architectures[0])
             model = architecture.from_pretrained(model_id, **model_init_kwargs)
         else:
@@ -427,7 +428,7 @@ def __init__(
             self.ref_model = None
         else:
             # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
-            config = AutoConfig.from_pretrained(model_id)
+            config = AutoConfig.from_pretrained(model_id, **config_init_kwargs)
             architecture = getattr(transformers, config.architectures[0])
             self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
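
A minimal usage sketch, assuming the option lands as in this diff: `config_init_kwargs` set on `GRPOConfig` is forwarded to `AutoConfig.from_pretrained` when the trainer builds the policy and reference model from a string model id. The model name, reward function, dataset, and the specific kwargs below are illustrative only, not part of this patch.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Illustrative reward: prefer completions close to 20 characters.
def reward_len(completions, **kwargs):
    return [-abs(20 - len(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO",
    model_init_kwargs={"dtype": "bfloat16"},         # forwarded to from_pretrained
    config_init_kwargs={"trust_remote_code": True},  # forwarded to AutoConfig.from_pretrained (new in this patch)
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # string model id, so the *_init_kwargs above apply
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```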