12 changes: 11 additions & 1 deletion trl/trainer/grpo_config.py
@@ -37,6 +37,9 @@ class GRPOConfig(TrainingArguments):
model_init_kwargs (`str`, `dict[str, Any]`, *optional*):
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
argument of the [`GRPOTrainer`] is provided as a string.
config_init_kwargs (`str`, `dict[str, Any]`, *optional*):
Keyword arguments for [`~transformers.AutoConfig.from_pretrained`], used when the `model`
argument of the [`GRPOTrainer`] is provided as a string.
disable_dropout (`bool`, *optional*, defaults to `False`):
Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents
the model from generating different logprobs for the same input.
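A minimal usage sketch of the two dict fields from the caller's side; the output directory and kwarg values below are illustrative, not taken from this diff:

```python
from trl import GRPOConfig

# Both dict fields only take effect when the `model` argument of GRPOTrainer is a
# string (a model id or path): model_init_kwargs is forwarded to
# AutoModelForCausalLM.from_pretrained and config_init_kwargs to AutoConfig.from_pretrained.
training_args = GRPOConfig(
    output_dir="grpo-output",                        # illustrative
    model_init_kwargs={"dtype": "bfloat16"},         # illustrative
    config_init_kwargs={"trust_remote_code": True},  # illustrative
)
```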
@@ -242,7 +245,7 @@ class GRPOConfig(TrainingArguments):
are logged.
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "config_init_kwargs"]

# Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
@@ -279,6 +282,13 @@ class GRPOConfig(TrainingArguments):
"argument of the `GRPOTrainer` is provided as a string."
},
)
config_init_kwargs: Optional[Union[dict, str]] = field(
default=None,
metadata={
"help": "Keyword arguments for `transformers.AutoConfig.from_pretrained`, used when the `config` "
"argument of the `GRPOTrainer` is provided as a string."
},
)
disable_dropout: bool = field(
default=False,
metadata={
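Because `config_init_kwargs` is registered in `_VALID_DICT_FIELDS` and typed as `Optional[Union[dict, str]]`, it should also accept a JSON string (for example from a CLI flag or config file), which the `TrainingArguments` post-init parses into a dict. A rough sketch; exact parsing behavior depends on the installed transformers version:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-output",
    config_init_kwargs='{"trust_remote_code": true}',  # JSON string instead of a dict
)
# Expected to be parsed into a dict by the _VALID_DICT_FIELDS handling.
print(type(training_args.config_init_kwargs))
```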
5 changes: 3 additions & 2 deletions trl/trainer/grpo_trainer.py
@@ -225,6 +225,7 @@ def __init__(
# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
config_init_kwargs = args.config_init_kwargs or {}
if isinstance(model, str):
model_id = model
dtype = model_init_kwargs.get("dtype")
@@ -239,7 +240,7 @@ def __init__(
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
config = AutoConfig.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id, **config_init_kwargs)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_init_kwargs)
else:
@@ -427,7 +428,7 @@ def __init__(
self.ref_model = None
else:
# For deepspeed, fsdp or non-distributed models, create a reference model from scratch
config = AutoConfig.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id, **config_init_kwargs)
architecture = getattr(transformers, config.architectures[0])
self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
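For reference, the resolution pattern that both call sites now share, written outside the trainer. The checkpoint id and kwargs are illustrative, and the sketch assumes `config.architectures[0]` names a class exported by `transformers`:

```python
import transformers
from transformers import AutoConfig

model_id = "Qwen/Qwen2.5-0.5B-Instruct"    # illustrative checkpoint
config_init_kwargs = {"revision": "main"}  # illustrative AutoConfig kwargs

# Resolve the concrete architecture class from the (now customizable) config,
# then load the weights. This mirrors the two-step pattern used above for both
# the trained model and the reference model.
config = AutoConfig.from_pretrained(model_id, **config_init_kwargs)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id)
```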
