Evaluation with OnlineDPO #2464

Open
7 of 9 tasks
MohamedAliRashad opened this issue Dec 11, 2024 · 1 comment · May be fixed by #2476
Labels
🐛 bug Something isn't working 🏋 Online DPO Related to Online DPO

Comments

@MohamedAliRashad

System Info

Copy-paste the following information when reporting an issue:

  • Platform: Linux-5.15.0-126-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • PyTorch version: 2.5.1
  • CUDA device(s): NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3
  • Transformers version: 4.46.0
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • Datasets version: 2.21.0
  • HF Hub version: 0.26.3
  • TRL version: 0.12.1
  • bitsandbytes version: not installed
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: 0.0.2
  • OpenAI version: 1.57.1
  • PEFT version: not installed

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from pathlib import Path

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, EarlyStoppingCallback
from trl import OnlineDPOConfig, OnlineDPOTrainer

# Load model to be trained
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# Load reward model
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")

# Load dataset
ds = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
ds = ds.train_test_split(test_size=0.1, seed=42)

training_args = OnlineDPOConfig(
    output_dir=str(Path(__file__).parent / "online_dpo_checkpoints"),
    num_train_epochs=10,
    overwrite_output_dir=True,
    eval_strategy="steps",
    save_strategy="steps",
    report_to="none",
    save_steps=5,
    eval_steps=5,
    # weight_decay=0.01,   # Add weight decay
    # warmup_steps=2,    # Add warmup
    gradient_checkpointing=True,  # Great for memory saving
    gradient_checkpointing_kwargs={"use_reentrant": False},  # To remove the warning
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=2,
    save_total_limit=1,
    bf16=True,
    logging_steps=1,
    dataloader_num_workers=8,  # Use multiple processes for data loading
    dataloader_pin_memory=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    ddp_find_unused_parameters=False,  # Skip DDP's unused-parameter search; fine when every parameter gets a gradient
    remove_unused_columns=False,
    max_new_tokens=128,
    missing_eos_penalty=1.0,
    max_grad_norm=1.0,
    eval_delay=0.1,
    optim="adamw_torch_fused",
)

trainer = OnlineDPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_model=reward_model,
    reward_processing_class=reward_tokenizer,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    # data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# Start training; the error is raised when the first evaluation runs
trainer.train()

outputs:

    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
ValueError: You must specify exactly one of input_ids or inputs_embeds

This error is raised as soon as training reaches the evaluation step.
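One possible reading of the error, offered as a sketch rather than a confirmed diagnosis: the dataset contains only a prompt column, so if the evaluation loop forwards the collated batch straight to the model, no input_ids ever reach the forward pass. The `toy_forward` function below is a hypothetical, simplified stand-in for a causal LM forward method (it is not the transformers implementation), and both batches are made up for illustration.

```python
def toy_forward(input_ids=None, inputs_embeds=None, **unused_kwargs):
    """Simplified stand-in for a causal LM forward pass (assumption, not library code)."""
    # Exactly one of input_ids / inputs_embeds must be given.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    sequences = input_ids if input_ids is not None else inputs_embeds
    return {"num_sequences": len(sequences)}


def try_forward(batch):
    """Unpack a batch into the forward call, the way model(**inputs) would."""
    try:
        toy_forward(**batch)
        return "ok"
    except ValueError as err:
        return f"ValueError: {err}"


# A prompt-only batch: raw text, nothing tokenized (hypothetical example data).
prompt_only_batch = {"prompt": ["What is 2 + 2?", "Name a prime number."]}
# A tokenized batch with made-up token ids.
tokenized_batch = {"input_ids": [[1, 2, 3], [4, 5, 6]]}

print(try_forward(prompt_only_batch))
print(try_forward(tokenized_batch))
```

Under this assumption the prompt-only batch triggers the same message as the one reported here, while the tokenized batch passes the check.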

Expected behavior

Evaluation should run without raising an error.
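Until evaluation runs, a possible stopgap is to train with the evaluation loop switched off. The sketch below only rewrites a dictionary of config keyword arguments; `without_evaluation` is a hypothetical helper name, not part of TRL or Transformers, and the sample dictionary copies a few values from the reproduction above. It assumes that options depending on evaluation metrics (best-model loading, the best-model metric) have to be dropped together with the evaluation strategy.

```python
def without_evaluation(config_kwargs):
    """Return a copy of the config kwargs with evaluation-dependent options disabled.

    Hypothetical helper for illustration; the input dict is left unchanged.
    """
    adjusted = dict(config_kwargs)
    adjusted["eval_strategy"] = "no"            # never enter the evaluation loop
    adjusted["load_best_model_at_end"] = False  # needs evaluation metrics to pick a checkpoint
    # These options only matter when evaluation runs.
    for key in ("metric_for_best_model", "eval_steps", "eval_delay"):
        adjusted.pop(key, None)
    return adjusted


# A subset of the arguments from the reproduction script.
original_kwargs = {
    "eval_strategy": "steps",
    "save_strategy": "steps",
    "save_steps": 5,
    "eval_steps": 5,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
    "eval_delay": 0.1,
}

print(without_evaluation(original_kwargs))
```

The result could be passed as `OnlineDPOConfig(**without_evaluation(original_kwargs))`. With this workaround the `EarlyStoppingCallback` would also need to be removed from the trainer, since it relies on evaluation metrics.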

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
qgallouedec added the 🐛 bug and 🏋 Online DPO labels on Dec 13, 2024
qgallouedec linked a pull request on Dec 13, 2024 that will close this issue
@haimianxing

I also encountered this error.
200/8754 [1:54:41<68:19:46, 28.76s/it]
Traceback (most recent call last):
  File "/mnt/data2/zcz/infer/utils/./accelerate_dpo_ol.py", line 340, in
    trainer.train()
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/accelerate/utils/memory.py", line 153, in decorator
    return function(batch_size, *args, **kwargs)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/trl/trainer/online_dpo_trainer.py", line 639, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 2958, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 3975, in evaluate
    output = eval_loop(
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 4169, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 4395, in prediction_step
    outputs = model(**inputs)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1164, in forward
    outputs = self.model(
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 830, in forward
    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
ValueError: You must specify exactly one of input_ids or inputs_embeds
[rank0]: Traceback (most recent call last):
[rank0]: File "/mnt/data2/zcz/infer/utils/./accelerate_dpo_ol.py", line 340, in
[rank0]: trainer.train()
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 2123, in train
[rank0]: return inner_training_loop(
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/accelerate/utils/memory.py", line 153, in decorator
[rank0]: return function(batch_size, *args, **kwargs)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
[rank0]: self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/trl/trainer/online_dpo_trainer.py", line 639, in _maybe_log_save_evaluate
[rank0]: metrics = self._evaluate(trial, ignore_keys_for_eval)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 2958, in _evaluate
[rank0]: metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 3975, in evaluate
[rank0]: output = eval_loop(
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 4169, in evaluation_loop
[rank0]: losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/trainer.py", line 4395, in prediction_step
[rank0]: outputs = model(**inputs)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1164, in forward
[rank0]: outputs = self.model(
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/mnt/data2/zcz/.miniconda3_14/envs/_torch_env/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 830, in forward
[rank0]: raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
[rank0]: ValueError: You must specify exactly one of input_ids or inputs_embeds
wandb: 🚀 View run /mnt/data2/zcz/my_code/keras3/models/Qwen2-7B-DPO_OL at: https://wandb.ai/2474842866/huggingface/runs/2j0wkeig
wandb: Find logs at: wandb/run-20250115_150223-2j0wkeig/logs
[rank0]:[W115 16:58:07.464844890 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
W0115 16:58:12.551361 93286 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 93503 closing signal SIGTERM
W0115 16:58:12.552577 93286 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 93505 closing signal SIGTERM
W0115 16:58:12.552794 93286 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 93506 closing signal SIGTERM
