# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional

from transformers import TrainingArguments


class FDivergenceType(Enum):
REVERSE_KL = "reverse_kl"
JS_DIVERGENCE = "js_divergence"
    ALPHA_DIVERGENCE = "alpha_divergence"


class FDivergenceConstants:
ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef"
    ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0


@dataclass
class DPOConfig(TrainingArguments):
r"""
    Configuration class for the [`DPOTrainer`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line (see the usage sketch at the end of this file).

    Parameters:
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
beta (`float`, *optional*, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
the [paper](https://huggingface.co/papers/2310.12036).
label_smoothing (`float`, *optional*, defaults to `0.0`):
Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and
[Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`.
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
Type of loss to use. Possible values are:
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
- `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
- `"exo_pair"`: pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper.
- `"nca_pair"`: pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper.
- `"robust"`: unbiased estimate of the DPO loss that is robust to preference noise from the [Robust DPO](https://huggingface.co/papers/2403.00409) paper.
- `"bco_pair"`: pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper.
- `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) paper.
- `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
- `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
- `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
use_weighting (`bool`, *optional*, defaults to `False`):
Whether or not to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
label_pad_token_id (`int`, *optional*, defaults to `-100`):
Label pad token id. This argument is required if you want to use the default data collator.
padding_value (`Optional[int]`, *optional*, defaults to `None`):
Padding value to use. If `None`, the padding value of the tokenizer is used.
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
Truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the
default data collator.
max_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
max_prompt_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the target. This argument is required if you want to use the default data collator and
your model is an encoder-decoder.
        is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`):
            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
            you need to specify whether the model returned by the callable is an encoder-decoder model.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model and reference model.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            Whether to sample and log completions from both the model and the reference model during evaluation.
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
useful when training without the reference model to reduce the total GPU memory needed.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
model_init_kwargs (`Optional[Dict[str, Any]]`, *optional*, defaults to `None`):
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
string.
ref_model_init_kwargs (`Optional[Dict[str, Any]]`, *optional*, defaults to `None`):
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
from a string.
model_adapter_name (`Optional[str]`, *optional*, defaults to `None`):
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
ref_adapter_name (`Optional[str]`, *optional*, defaults to `None`):
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
reference_free (`bool`, *optional*, defaults to `False`):
If `True`, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal
probability to all responses.
force_use_ref_model (`bool`, *optional*, defaults to `False`):
            If you pass a PEFT model as the active model and want to use a different model as the reference
            model (`ref_model`), set this flag to `True`.
f_divergence_type (`str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
Type of f-divergence regularization function to compute divergence between policy and reference model.
f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
α coefficient in the α-divergence u^-α regularization function for DPO loss.
sync_ref_model (`bool`, *optional*, defaults to `False`):
When set to `True`, the reference model is synchronized with the active model every `ref_model_sync_steps`
            steps, using the `ref_model_mixup_alpha` parameter. This synchronization originates from the
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
between the current policy and the previous reference policy during updates. The reference policy is
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`
To use this parameter, you must set `sync_ref_model=True`.
ref_model_sync_steps (`int`, *optional*, defaults to `64`):
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
set `sync_ref_model=True`.
        rpo_alpha (`Optional[float]`, *optional*, defaults to `None`):
α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
DPO loss. The paper recommends `rpo_alpha=1.0`.
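
    Example (a minimal sketch; the `trl` import path below is an assumption and may not match how this fork
    packages the class):

    ```python
    from trl import DPOConfig

    training_args = DPOConfig(
        output_dir="dpo-model",
        beta=0.1,  # strength of the constraint toward the reference model
        loss_type="sigmoid",  # original DPO loss
        max_length=1024,
        max_prompt_length=512,
        # TR-DPO-style reference-model syncing described above
        sync_ref_model=True,
        ref_model_mixup_alpha=0.9,
        ref_model_sync_steps=64,
    )
    ```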
"""
learning_rate: float = 1e-6
beta: float = 0.1
label_smoothing: float = 0.0
loss_type: Literal[
"sigmoid",
"hinge",
"ipo",
"exo_pair",
"nca_pair",
"robust",
"bco_pair",
"sppo_hard",
"aot",
"aot_pair",
"apo_zero",
"apo_down",
] = "sigmoid"
use_weighting: bool = False
label_pad_token_id: int = -100
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_target_length: Optional[int] = None # deprecated in favor of max_completion_length
max_completion_length: Optional[int] = None
is_encoder_decoder: Optional[bool] = None
disable_dropout: bool = True
generate_during_eval: bool = False
precompute_ref_log_probs: bool = False
dataset_num_proc: Optional[int] = None
model_init_kwargs: Optional[Dict[str, Any]] = None
ref_model_init_kwargs: Optional[Dict[str, Any]] = None
model_adapter_name: Optional[str] = None
ref_adapter_name: Optional[str] = None
reference_free: bool = False
force_use_ref_model: bool = False
f_divergence_type: FDivergenceType = FDivergenceType.REVERSE_KL
f_alpha_divergence_coef: float = 1.0
sync_ref_model: bool = False
ref_model_mixup_alpha: float = 0.9
ref_model_sync_steps: int = 64
rpo_alpha: Optional[float] = None
    # Additional options specific to this configuration, not covered by the docstring above
    use_flex_attn: bool = False
    prefix_sharing: bool = False
    enable_packing: bool = False
    max_train_samples: Optional[int] = None
    max_eval_samples: Optional[int] = None
    packing_length: Optional[int] = None

    def __post_init__(self):
if self.max_target_length is not None:
warnings.warn(
"The `max_target_length` argument is deprecated in favor of `max_completion_length` and will be removed in a future version.",
FutureWarning,
)
if self.max_completion_length is None:
self.max_completion_length = self.max_target_length
        return super().__post_init__()


@dataclass
class DPOScriptArguments:
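    r"""
    Arguments for the accompanying DPO training script: dataset selection, an optional config file, and a
    few training toggles. Typically parsed together with [`DPOConfig`] via
    [`~transformers.HfArgumentParser`] (see the usage sketch at the end of this file).
    """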
    dataset_name: Optional[str] = field(default=None, metadata={"help": "The dataset name"})
    keep_columns: Optional[List[str]] = field(
        default=None, metadata={"help": "If specified, only these dataset columns are kept"}
    )
dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to use for training"})
dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to use for evaluation"})
ignore_bias_buffers: bool = field(
default=False,
metadata={
"help": "debug argument for distributed training;"
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
    config: Optional[str] = field(default=None, metadata={"help": "Path to the optional config file"})
gradient_checkpointing_use_reentrant: bool = field(
default=False,
metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"},
)
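

# Usage sketch, illustrating the command-line parsing described in the `DPOConfig` docstring. This mirrors
# the TRL example scripts and is only illustrative; the flags in the comment below are assumptions, not
# flags this module requires. The `__main__` guard keeps the sketch from running on import.
if __name__ == "__main__":
    from transformers import HfArgumentParser

    # e.g. python config.py --output_dir dpo-model --dataset_name my_dataset --beta 0.1 --loss_type ipo
    parser = HfArgumentParser((DPOScriptArguments, DPOConfig))
    script_args, training_args = parser.parse_args_into_dataclasses()
    print(script_args)
    print(training_args)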