Remove deprecated args in trainers #2036

Draft · wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion tests/test_sft_trainer.py
@@ -418,7 +418,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
packing=True,
dataset_kwargs={"skip_prepare_dataset": True},
report_to="none",
)
trainer = SFTTrainer(
4 changes: 2 additions & 2 deletions tests/test_trainers_args.py
@@ -146,7 +146,7 @@ def test_dpo(self):
is_encoder_decoder=True,
disable_dropout=False,
# generate_during_eval=True, # ignore this one, it requires wandb
precompute_ref_log_probs=True,
# precompute_ref_log_probs=True, # can't be True if sync_ref_model is True, so we just test sync_ref_model
dataset_num_proc=4,
model_init_kwargs={"trust_remote_code": True},
ref_model_init_kwargs={"trust_remote_code": True},
@@ -176,7 +176,7 @@ def test_dpo(self):
self.assertEqual(trainer.args.is_encoder_decoder, True)
self.assertEqual(trainer.args.disable_dropout, False)
# self.assertEqual(trainer.args.generate_during_eval, True)
self.assertEqual(trainer.args.precompute_ref_log_probs, True)
# self.assertEqual(trainer.args.precompute_ref_log_probs, True)
self.assertEqual(trainer.args.dataset_num_proc, 4)
self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
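
The commented-out precompute_ref_log_probs lines above reflect a guard in DPOTrainer (visible further down in this PR's dpo_trainer.py diff): once sync_ref_model is enabled, precomputed reference log-probs are rejected, because TR-DPO keeps refreshing the reference model during training. A minimal sketch of the option combination the test now avoids, with placeholder values:

from trl import DPOConfig

# Hypothetical config showing the incompatible pair of options;
# DPOTrainer raises a ValueError if both are set to True.
config = DPOConfig(
    output_dir="dpo-args-test",      # placeholder output directory
    sync_ref_model=True,             # TR-DPO: periodically sync the reference model with the policy
    precompute_ref_log_probs=True,   # rejected by DPOTrainer when sync_ref_model=True
)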
163 changes: 4 additions & 159 deletions trl/trainer/dpo_trainer.py
@@ -29,7 +29,6 @@
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from datasets import Dataset
from huggingface_hub.utils._deprecation import _deprecate_arguments
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
@@ -392,80 +391,28 @@ class DPOTrainer(Trainer):

_tag_names = ["trl", "dpo"]

@_deprecate_arguments(
version="1.0.0",
Member

Out of curiosity, is the reason to remove these deprecated args that we've had this warning for several releases?

I wonder if it would be better to pin the removal to 2-3 versions in the future, to give people a concrete warning?

Member Author

For context, these arguments were deprecated in #1554, in v0.9.3.

Sounds good, let's wait a bit.
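
The deprecation message in this decorator points users at DPOConfig, so the migration this PR finalizes is to move these keyword arguments out of the DPOTrainer call and into the config. A minimal sketch of the new-style call, assuming the DPOConfig fields keep the same names as the deprecated keywords (the model name and dataset are placeholders):

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Placeholder dataset; DPO expects "prompt"/"chosen"/"rejected" columns.
train_dataset = load_dataset("your-org/your-dpo-dataset", split="train")

# Before (removed by this PR): DPOTrainer(model, beta=0.1, max_length=512, loss_type="sigmoid", ...)
# After: hyperparameters go through DPOConfig.
training_args = DPOConfig(
    output_dir="dpo-output",
    beta=0.1,
    loss_type="sigmoid",
    max_length=512,
    max_prompt_length=128,
)
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()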

deprecated_args=[
"beta",
"label_smoothing",
"loss_type",
"label_pad_token_id",
"padding_value",
"truncation_mode",
"max_length",
"max_prompt_length",
"max_target_length",
"is_encoder_decoder",
"disable_dropout",
"generate_during_eval",
"precompute_ref_log_probs",
"dataset_num_proc",
"model_init_kwargs",
"ref_model_init_kwargs",
"model_adapter_name",
"ref_adapter_name",
"reference_free",
"force_use_ref_model",
],
custom_message="Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.",
)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
label_smoothing: float = 0,
loss_type: Optional[str] = None,
args: Optional[DPOConfig] = None,
data_collator: Optional[DataCollator] = None,
label_pad_token_id: int = -100,
padding_value: Optional[int] = None,
truncation_mode: str = "keep_end",
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
max_length: Optional[int] = None,
max_prompt_length: Optional[int] = None,
max_target_length: Optional[int] = None,
peft_config: Optional[Dict] = None,
is_encoder_decoder: Optional[bool] = None,
disable_dropout: bool = True,
generate_during_eval: bool = False,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
precompute_ref_log_probs: bool = False,
dataset_num_proc: Optional[int] = None,
model_init_kwargs: Optional[Dict] = None,
ref_model_init_kwargs: Optional[Dict] = None,
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
reference_free: bool = False,
force_use_ref_model: bool = False,
):
if not isinstance(model, str) and ref_model is model:
raise ValueError(
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
"same as `model`, you must mass a copy of it, or `None` if you use peft."
)

if model_init_kwargs is not None:
warnings.warn(
"You passed `model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.model_init_kwargs = model_init_kwargs

if args.model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
@@ -485,12 +432,6 @@ def __init__(
)
model_init_kwargs["torch_dtype"] = torch_dtype

if ref_model_init_kwargs is not None:
warnings.warn(
"You passed `ref_model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.ref_model_init_kwargs = ref_model_init_kwargs

if args.ref_model_init_kwargs is None:
ref_model_init_kwargs = {}
elif not isinstance(ref_model, str):
@@ -528,12 +469,6 @@ def __init__(
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False

if force_use_ref_model:
warnings.warn(
"You passed `force_use_ref_model` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.force_use_ref_model = force_use_ref_model

if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
@@ -595,22 +530,12 @@ def make_inputs_require_grad(module, input, output):

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

if generate_during_eval:
warnings.warn(
"You passed `generate_during_eval` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.generate_during_eval = generate_during_eval
if args.generate_during_eval and not is_wandb_available():
raise ValueError(
"`generate_during_eval=True` requires Weights and Biases to be installed."
" Please install `wandb` to resolve."
)

if is_encoder_decoder is not None:
warnings.warn(
"You passed `is_encoder_decoder` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.is_encoder_decoder = is_encoder_decoder
if model is not None:
self.is_encoder_decoder = model.config.is_encoder_decoder
elif args.is_encoder_decoder is None:
@@ -635,33 +560,10 @@ def make_inputs_require_grad(module, input, output):
self.tokenizer = tokenizer

self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
if model_adapter_name is not None:
warnings.warn(
"You passed `model_adapter_name` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.model_adapter_name = model_adapter_name
self.model_adapter_name = args.model_adapter_name

if ref_adapter_name is not None:
warnings.warn(
"You passed `ref_adapter_name` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.ref_adapter_name = ref_adapter_name
self.ref_adapter_name = args.ref_adapter_name

if reference_free:
warnings.warn(
"You passed `reference_free` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.reference_free = reference_free
self.reference_free = args.reference_free

if precompute_ref_log_probs:
warnings.warn(
"You passed `precompute_ref_log_probs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.precompute_ref_log_probs = precompute_ref_log_probs

if ref_model:
self.ref_model = ref_model
elif self.is_peft_model or args.precompute_ref_log_probs:
@@ -673,11 +575,6 @@ def make_inputs_require_grad(module, input, output):
if tokenizer is None:
raise ValueError("tokenizer must be specified to tokenize a DPO dataset.")

if max_length is not None:
warnings.warn(
"You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.max_length = max_length
if args.max_length is None:
warnings.warn(
"`max_length` is not set in the DPOConfig's init"
@@ -686,11 +583,6 @@ def make_inputs_require_grad(module, input, output):
)
args.max_length = 512

if max_prompt_length is not None:
warnings.warn(
"You passed `max_prompt_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.max_prompt_length = max_prompt_length
if args.max_prompt_length is None:
warnings.warn(
"`max_prompt_length` is not set in the DPOConfig's init"
@@ -699,11 +591,6 @@ def make_inputs_require_grad(module, input, output):
)
args.max_prompt_length = 128

if max_target_length is not None:
warnings.warn(
"You passed `max_target_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.max_completion_length = max_target_length
if args.max_completion_length is None and self.is_encoder_decoder:
warnings.warn(
"When using an encoder decoder architecture, you should set `max_completion_length` in the DPOConfig's init"
@@ -712,11 +599,6 @@ def make_inputs_require_grad(module, input, output):
)
args.max_completion_length = 128

if label_pad_token_id != -100:
warnings.warn(
"You passed `label_pad_token_id` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.label_pad_token_id = label_pad_token_id
if data_collator is None:
data_collator = DPODataCollatorWithPadding(
pad_token_id=self.tokenizer.pad_token_id,
@@ -737,11 +619,6 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

if not disable_dropout:
warnings.warn(
"You passed `disable_dropout` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.disable_dropout = disable_dropout
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
@@ -750,18 +627,8 @@ def make_inputs_require_grad(module, input, output):
self.max_length = args.max_length
self.generate_during_eval = args.generate_during_eval
self.label_pad_token_id = args.label_pad_token_id
if padding_value is not None:
warnings.warn(
"You passed `padding_value` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.padding_value = padding_value
self.padding_value = args.padding_value if padding_value is not None else self.tokenizer.pad_token_id
self.padding_value = args.padding_value if args.padding_value is not None else self.tokenizer.pad_token_id
self.max_prompt_length = args.max_prompt_length
if truncation_mode != "keep_end":
warnings.warn(
"You passed `truncation_mode` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.truncation_mode = truncation_mode
self.truncation_mode = args.truncation_mode
self.max_completion_length = args.max_completion_length
self.precompute_ref_log_probs = args.precompute_ref_log_probs
@@ -771,16 +638,6 @@ def make_inputs_require_grad(module, input, output):
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False

if loss_type is not None:
warnings.warn(
"You passed `loss_type` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.loss_type = loss_type
if label_smoothing != 0:
warnings.warn(
"You passed `label_smoothing` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.label_smoothing = label_smoothing
if (
args.loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"]
and args.label_smoothing > 0
@@ -791,11 +648,6 @@ def make_inputs_require_grad(module, input, output):
if args.loss_type == "kto_pair":
raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.")

if beta != 0.1:
warnings.warn(
"You passed `beta` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.beta = beta
self.beta = args.beta
self.label_smoothing = args.label_smoothing
self.loss_type = args.loss_type
Expand All @@ -806,13 +658,6 @@ def make_inputs_require_grad(module, input, output):
self.f_divergence_type = args.f_divergence_type
self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}

if dataset_num_proc is not None:
warnings.warn(
"You passed `dataset_num_proc` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.dataset_num_proc = dataset_num_proc
self.dataset_num_proc = args.dataset_num_proc

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
@@ -838,7 +683,7 @@ def make_inputs_require_grad(module, input, output):
_tokenize,
fn_kwargs=fn_kwargs,
batched=True,
num_proc=self.dataset_num_proc,
num_proc=args.dataset_num_proc,
writer_batch_size=10,
desc="Tokenizing train dataset",
)
@@ -847,7 +692,7 @@ def make_inputs_require_grad(module, input, output):
_tokenize,
fn_kwargs=fn_kwargs,
batched=True,
num_proc=self.dataset_num_proc,
num_proc=args.dataset_num_proc,
writer_batch_size=10,
desc="Tokenizing eval dataset",
)
@@ -898,7 +743,7 @@ def make_inputs_require_grad(module, input, output):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

if args.sync_ref_model:
if precompute_ref_log_probs:
if args.precompute_ref_log_probs:
raise ValueError(
"You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
)