Out of Memory Error: DPO Trainer #2452
Comments
You have two GPUs, but you only use 1 of them in your accelerate config. You could also use DeepSpeed to further decrease the memory footprint. Lastly, keep per_device_train_batch_size as low as possible and instead increase gradient_accumulation_steps.
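For example, something along these lines (a rough sketch; the value of 8 accumulation steps is just a placeholder):

from trl import DPOConfig

# Keep the per-device batch at 1 and let gradient accumulation provide the
# effective batch size (here 1 x 8 = 8 samples per optimizer step).
training_args = DPOConfig(
    output_dir="dpo_out",  # placeholder output directory
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
)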
It might come from your data. Do you have long sequences in your dataset?

DPOConfig(
    ...,
    max_prompt_length=128,
    max_completion_length=512,
)
@gp-1108 I faced similar issues. I would recommend checking the available modules in your cluster with a command like "module avail" and loading a CUDA installation with "module load"; of course this assumes you are in a Slurm environment. If you don't have CUDA among the available modules, perhaps you could ask the cluster admins to install it. I think you should be good after this.
Hi all, I have finally fixed all of the CUDA issues with the computing cluster 😮‍💨. I have tweaked both the script and the accelerate config, so I will leave them below (I hope everything is set up as it should be).
TRL ENV:
SCRIPT:
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftConfig, PeftModel, LoraConfig
from trl import DPOConfig, DPOTrainer
import utils as ut
import torch
from accelerate import Accelerator
import os
os.environ['WANDB_DISABLED'] = 'true'
#import wandb
def print_memory_usage(description="Memory Usage"):
    """
    Prints the current memory usage for all available GPU devices.

    Args:
        description (str): A short description for context.
    """
    if torch.cuda.is_available():
        print(f"{description}:")
        for i in range(torch.cuda.device_count()):
            device = f"cuda:{i}"
            free_mem, total_mem = torch.cuda.mem_get_info(device)
            used_mem = total_mem - free_mem
            total_mem_mb = total_mem / 1024**2  # Convert to MB
            free_mem_mb = free_mem / 1024**2  # Convert to MB
            used_mem_mb = used_mem / 1024**2  # Convert to MB
            print(f"  Device: {device}")
            print(f"  Total Memory: {total_mem_mb:.2f} MB")
            print(f"  Used Memory: {used_mem_mb:.2f} MB")
            print(f"  Free Memory: {free_mem_mb:.2f} MB")
    else:
        print("CUDA is not available on this system.")
def main(args):
    """
    wandb.init(
        # set the wandb project where this run will be logged
        project="my-awesome-project",
    )
    """
    accelerator = Accelerator(
        mixed_precision="no",
        gradient_accumulation_steps=args.gradient_acc,
    )
    print(args)
    print_memory_usage(description="Before anything")

    # Load dataset
    print("Loading dataset...")
    dataset = ut.load_dataset(args.dataset_path)
    dataset = dataset.train_test_split(test_size=args.test_split)

    # Load PEFT configuration
    print(f"Loading PEFT model configuration from {args.peft_model_id}...")
    config = PeftConfig.from_pretrained(args.peft_model_id)

    # Configure quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    # Load base model
    print(f"Loading base model from {config.base_model_name_or_path}...")
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        quantization_config=bnb_config,
        trust_remote_code=True,  # Hardcoded
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False
    model.enable_input_require_grads()  # To avoid error https://github.com/huggingface/trl/issues/731
    print_memory_usage(description="After model init")

    # Load tokenizer
    print(f"Loading tokenizer from {config.base_model_name_or_path}...")
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    tokenizer.eos_token = "<|eot_id|>"  # Hardcoded
    tokenizer.pad_token = "<|finetune_right_pad_id|>"  # Hardcoded

    # Load PEFT model
    print(f"Loading PEFT model from {args.peft_model_id}...")
    model = PeftModel.from_pretrained(
        model,
        args.peft_model_id,
        adapter_name="trainable",
        is_trainable=True,
    )
    model.load_adapter(args.peft_model_id, adapter_name="reference")  # Hardcoded
    print_memory_usage(description="After two adapters")
    tokenizer.chat_template = None

    # Configure training arguments
    training_args = DPOConfig(
        learning_rate=args.learning_rate,
        beta=args.beta,
        loss_type=args.loss_type,
        use_weighting=args.use_weighting,
        rpo_alpha=args.rpo_alpha,
        output_dir=args.output_dir,
        logging_steps=args.logging_steps,
        model_adapter_name="trainable",  # Hardcoded
        ref_adapter_name="reference",  # Hardcoded
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_acc,
    )

    # Configure Lora
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj', 'lm_head']
    )

    # Initialize DPO trainer
    print("Initializing DPO trainer...")
    dpo_trainer = DPOTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        peft_config=peft_config,
    )

    # Prepare everything for training
    model, tokenizer, train_dataset, eval_dataset = accelerator.prepare(
        model, tokenizer, dataset["train"], dataset["test"]
    )

    # Train the model
    print("Starting training...")
    dpo_trainer.train()
    print("Training complete.")
    dpo_trainer.save_model()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune a model using PEFT and DPOTrainer.")
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to the dataset file (JSONL).")
    parser.add_argument("--test_split", type=float, default=0.15, help="Proportion of dataset to use for testing.")
    parser.add_argument("--peft_model_id", type=str, required=True, help="Path to the PEFT model directory.")
    parser.add_argument("--load_in_8bit", action="store_true", help="Enable 8-bit quantization.")
    parser.add_argument("--output_dir", type=str, default="Llama31_DPO", help="Directory to save the trained model.")
    parser.add_argument("--logging_steps", type=int, default=1, help="Number of steps for logging during training.")
    parser.add_argument("--learning_rate", type=float, default=1e-6, help="Learning rate for the AdamW optimizer.")
    parser.add_argument("--beta", type=float, default=0.1, help="Parameter controlling deviation from the reference model.")
    parser.add_argument("--loss_type", type=str, default="sigmoid", help="Type of loss to use for training.")
    parser.add_argument("--use_weighting", action="store_true", help="Enable weighting of the loss.")
    parser.add_argument("--rpo_alpha", type=float, default=None, help="Alpha parameter for the RPO paper.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training per gpu.")
    parser.add_argument("--gradient_acc", type=int, default=1, help="Gradient accumulation steps.")
    args = parser.parse_args()
    main(args)

The script crashes after being called with the following parameters:
accelerate launch --num_processes=2 --num_machines=1 --mixed_precision=no --dynamo_backend=inductor dpo_finetuning.py \
--dataset_path ../dataset_generation/data/dpo_dialogues.jsonl \
--peft_model_id ../llama3.1_finetuning/output/llama3.1_SFT_from_Base/checkpoint-800 \
--output_dir ./tmp \
--logging_steps 1 \
--load_in_8bit \
--batch_size 1 \
--gradient_acc 1

The full traceback is this (sorry for the duplication, it is two processes):
STACK TRACE

TLDR: I think that @qgallouedec might be onto something, as my prompts and responses are quite lengthy. I have also noticed that the trainer adds a huge amount of padding tokens when pre-processing the dataset. NOTE: I cannot afford to truncate the samples' text, as it is sometimes critical to keep those lengthy prompt+answer pairs during training.
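For context, a rough sketch of how the tokenized lengths could be inspected (it assumes the usual DPO-style prompt/chosen/rejected columns and reuses config and dataset from the script above, so treat the names as placeholders):

from transformers import AutoTokenizer

# Rough sketch: measure per-sample token counts to confirm the long-sequence
# hypothesis. Assumes a DPO-style dataset with "prompt", "chosen" and
# "rejected" text columns (adapt the names if yours differ) and reuses
# `config` and `dataset` from the script above.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

def token_length(example):
    prompt_len = len(tokenizer(example["prompt"])["input_ids"])
    chosen_len = len(tokenizer(example["chosen"])["input_ids"])
    rejected_len = len(tokenizer(example["rejected"])["input_ids"])
    return {"length": prompt_len + max(chosen_len, rejected_len)}

lengths = sorted(dataset["train"].map(token_length)["length"])
print("max length:", lengths[-1], "| p95:", lengths[int(0.95 * len(lengths))])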
Hi, I have finally solved the issue and I am going to leave it here for posterity. The issue lay mainly in two things:
MANAGING SAMPLE LENGTH:
You can clearly see that in some cases we get up to 6k tokens of length, which is perhaps not ideal. Afterwards, the maximum length was 2k, which is manageable.
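In case it helps, a minimal sketch of the kind of length-based filtering I mean (the 2048 cut-off is illustrative, and it assumes the same tokenizer and prompt/chosen/rejected columns as in the sketch above):

MAX_TOKENS = 2048  # illustrative cut-off; pick whatever your VRAM allows

def total_tokens(example):
    # Prompt plus the longer of the two completions, tokenized.
    prompt_len = len(tokenizer(example["prompt"])["input_ids"])
    chosen_len = len(tokenizer(example["chosen"])["input_ids"])
    rejected_len = len(tokenizer(example["rejected"])["input_ids"])
    return prompt_len + max(chosen_len, rejected_len)

# Drop the handful of extreme samples instead of truncating everything.
dataset = dataset.filter(lambda ex: total_tokens(ex) <= MAX_TOKENS)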
PEFT CONFIGURATION:
Even though my PEFT configuration did not include the embedding layer in the targets:

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj', 'lm_head']
)

I resorted to the good old:

from peft import get_peft_model  # missing from the original imports

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj', 'lm_head']
)

model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    trust_remote_code=True,  # Hardcoded
    torch_dtype=torch.bfloat16,
)
model.config.use_cache = False
model.enable_input_require_grads()  # To avoid error https://github.com/huggingface/trl/issues/731

model = PeftModel.from_pretrained(
    model,
    args.peft_model_id,
    adapter_name="trainable",
    is_trainable=True,
)
model.load_adapter(args.peft_model_id, adapter_name="reference")
model = get_peft_model(model, peft_config)

Also avoiding the

OTHER IMPROVEMENTS:
System Info
MACHINE SETUP:
TRL ENV:
ACCELERATE SETUP:
Information
Tasks
Reproduction
outputs:
Please disregard the size of the dataset, as I am testing with a small subset of it.
Expected behavior
I am encountering some difficulties training a Llama 3.1 8B SFT checkpoint with LoRA.
Basically, I cannot increase the batch size beyond 2 samples per GPU, even though I have almost 100 GB of VRAM combined (48 GB on each A40).
What bugs me is that even if I use aggressive settings in the LoraConfig and narrow down the target modules, the outcome is the same: Out Of Memory. Even when I was not using LoRA I got the same results.
The only thing I accomplished was pushing it from 1 to 2 samples per device by using accelerate launch --num_processes=1, but the results are still far from desirable.
My question therefore is the following: is DPO just a really heavy kind of training? I didn't think it would differ greatly from SFT, but here I am throwing 4 times as much VRAM at it and getting nowhere close to the same batch size.
Also, am I configuring LoRA correctly? Changing the hyperparameters does not affect memory consumption at all (even removing the LoraConfig does not change anything).
I have even loaded two adapters on the same model to save some VRAM and I am starting to wonder if I am doing anything wrong at all.
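For what it's worth, this is the kind of sanity check I would run to see what is actually trainable (a rough sketch; it assumes model is the PEFT-wrapped model from the script above):

# Rough sketch: count trainable vs. total parameters on the wrapped model.
trainable, total = 0, 0
for _, param in model.named_parameters():
    total += param.numel()
    if param.requires_grad:
        trainable += param.numel()
print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

# PEFT models also expose a built-in helper with the same information.
model.print_trainable_parameters()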
Checklist