🐛 Describe the bug
Description:
Explicitly calling torch.xpu.empty_cache() in the training loop introduces a memory leak. Here I used Llama-3.1-8B as an example and printed the device memory reported by xpu-smi at each step.
[Rank 0] Epoch 0, Step 0, Loss: 13.1864, Memo:, [Epoch 0, Step 0] xpu-smi memory used (device 0): 21.30 GB
[Rank 1] Epoch 0, Step 0, Loss: 12.5021, Memo:, [Epoch 0, Step 0] xpu-smi memory used (device 0): 21.30 GB
[Rank 0] Epoch 0, Step 1, Loss: 12.8508, Memo:, [Epoch 0, Step 1] xpu-smi memory used (device 0): 27.11 GB
[Rank 1] Epoch 0, Step 1, Loss: 12.6161, Memo:, [Epoch 0, Step 1] xpu-smi memory used (device 0): 27.11 GB
[Rank 0] Epoch 0, Step 2, Loss: 13.1110, Memo:, [Epoch 0, Step 2] xpu-smi memory used (device 0): 32.41 GB
[Rank 1] Epoch 0, Step 2, Loss: 12.5568, Memo:, [Epoch 0, Step 2] xpu-smi memory used (device 0): 32.41 GB
[Rank 1] Epoch 0, Step 3, Loss: 12.4648, Memo:, [Epoch 0, Step 3] xpu-smi memory used (device 0): 38.21 GB
[Rank 0] Epoch 0, Step 3, Loss: 12.5307, Memo:, [Epoch 0, Step 3] xpu-smi memory used (device 0): 38.21 GB
[Rank 1] Epoch 0, Step 4, Loss: 12.0303, Memo:, [Epoch 0, Step 4] xpu-smi memory used (device 0): 43.51 GB
[Rank 0] Epoch 0, Step 4, Loss: 12.3255, Memo:, [Epoch 0, Step 4] xpu-smi memory used (device 0): 43.51 GB
[Rank 0] Epoch 0, Step 5, Loss: 11.9331, Memo:, [Epoch 0, Step 5] xpu-smi memory used (device 0): 49.31 GB
[Rank 1] Epoch 0, Step 5, Loss: 12.3114, Memo:, [Epoch 0, Step 5] xpu-smi memory used (device 0): 49.31 GB
[Rank 0] Epoch 0, Step 6, Loss: 12.0390, Memo:, [Epoch 0, Step 6] xpu-smi memory used (device 0): 54.61 GB
[Rank 1] Epoch 0, Step 6, Loss: 12.1755, Memo:, [Epoch 0, Step 6] xpu-smi memory used (device 0): 54.61 GB
[Rank 0] Epoch 0, Step 7, Loss: 11.9479, Memo:, [Epoch 0, Step 7] xpu-smi memory used (device 0): 60.41 GB
[Rank 1] Epoch 0, Step 7, Loss: 11.9087, Memo:, [Epoch 0, Step 7] xpu-smi memory used (device 0): 60.41 GB
[Rank 0] Epoch 0, Step 8, Loss: 11.5645, Memo:, [Epoch 0, Step 8] xpu-smi memory used (device 0): 65.72 GB
[Rank 1] Epoch 0, Step 8, Loss: 11.2949, Memo:, [Epoch 0, Step 8] xpu-smi memory used (device 0): 65.72 GB
[Rank 0] Epoch 0, Step 9, Loss: 10.8658, Memo:, [Epoch 0, Step 9] xpu-smi memory used (device 0): 71.51 GB
[Rank 1] Epoch 0, Step 9, Loss: 11.6554, Memo:, [Epoch 0, Step 9] xpu-smi memory used (device 0): 71.51 GB
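The xpu-smi number above is the device-level total. To see whether the growth is held by the PyTorch caching allocator or sits outside it, the allocator counters can be printed alongside it. A minimal sketch (the helper name is hypothetical, and it assumes torch.xpu exposes the same memory_allocated()/memory_reserved() counters as torch.cuda):

import torch

def report_allocator_memory(tag, device_id=0):
    # Hypothetical helper, not part of the reproducer below: reports the
    # caching allocator's view of XPU memory for comparison with xpu-smi.
    # Assumes torch.xpu mirrors torch.cuda's memory_allocated/memory_reserved.
    allocated_gb = torch.xpu.memory_allocated(device_id) / 1024**3
    reserved_gb = torch.xpu.memory_reserved(device_id) / 1024**3
    print(f"[{tag}] allocator allocated={allocated_gb:.2f} GB, reserved={reserved_gb:.2f} GB")

If reserved stays roughly flat while the xpu-smi value keeps climbing step over step, the extra memory is being held outside the caching allocator, which would point at the empty_cache() path rather than at cached blocks.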
Steps to reproduce:
- pip install peft transformers
- pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/xpu
- run this reproducer: torchrun --nproc_per_node=2 reproducer.py
import os
import gc
import sys
import json
import torch
import subprocess
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM, default_data_collator
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
from torch.distributed._composable.fsdp import fully_shard
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Auto device selection
if torch.xpu.is_available():
    torch_device = torch.xpu
    backend = "xpu:xccl"
    device_type = "xpu"
elif torch.cuda.is_available():
    torch_device = torch.cuda
    backend = "nccl"
    device_type = "cuda"
else:
    torch_device = torch.device("cpu")
    backend = "gloo"
    device_type = "cpu"

def clean_cache():
    """
    Clean up the XPU or CUDA cache to free up memory.
    """
    try:
        if torch.xpu.is_available():
            torch.xpu.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()
        else:
            print("No XPU or CUDA device available to clear cache.")
    except Exception as e:
        print(f"Failed to clear device cache: {e}")
    # try:
    #     collected = gc.collect()
    #     print(f"GC collected {collected} unreachable objects.")
    # except Exception as e:
    #     print(f"Failed to run garbage collection: {e}")

# Memory print (XPU)
def get_xpu_memory_used_from_xpu_smi(tag, device_id=0):
    if device_type != "xpu":
        return
    try:
        result = subprocess.run(
            ["xpu-smi", "stats", "-d", str(device_id), "-j"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
        stats = json.loads(result.stdout)
        tile_level = stats.get("tile_level", [])
        total_mem_mb = sum(
            metric["value"]
            for tile in tile_level
            for metric in tile.get("data_list", [])
            if metric["metrics_type"] == "XPUM_STATS_MEMORY_USED"
        )
        return f"[{tag}] xpu-smi memory used (device {device_id}): {total_mem_mb / 1024:.2f} GB"
    except Exception as e:
        return f"[{tag}] xpu-smi error (device {device_id}): {e}"

# Memory print (CUDA)
def get_cuda_memory_used_from_nvidia_smi(tag, device_id=0):
    if device_type != "cuda":
        return
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader", "-i", str(device_id)],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
        used_mem_mb = float(result.stdout.strip())
        return f"[{tag}] nvidia-smi memory used (device {device_id}): {used_mem_mb / 1024:.2f} GB"
    except Exception as e:
        return f"[{tag}] nvidia-smi error (device {device_id}): {e}"

def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    if device_type in ["xpu", "cuda"]:
        torch_device.set_device(rank)
        return rank, torch.device(f"{device_type}:{rank}")
    return rank, torch_device

def format_prompt(example):
    return f"""### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:\n{example['output']}"""

def tokenize_prompt(example, tokenizer, max_length=512):
    prompt = format_prompt(example)
    tokenized = tokenizer(prompt, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
    tokenized["labels"] = tokenized["input_ids"].clone()
    return {k: v.squeeze(0) for k, v in tokenized.items()}

def main():
    rank, device = setup_distributed()
    model_path = "/home/songhappy/models/Meta-Llama-3.1-8B-Instruct/"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    print(f"[Rank {rank}] Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map={"": rank}
    )
    lora_cfg = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_cfg)
    model.to(torch.bfloat16)
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, LlamaDecoderLayer):
            fully_shard(module)
    fully_shard(model)
    raw_dataset = load_dataset("tatsu-lab/alpaca", split="train[:20]")
    tokenized_dataset = raw_dataset.map(lambda x: tokenize_prompt(x, tokenizer))
    sampler = DistributedSampler(tokenized_dataset, num_replicas=dist.get_world_size(), rank=rank, shuffle=True)
    dataloader = DataLoader(tokenized_dataset, sampler=sampler, batch_size=1, collate_fn=default_data_collator)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    model.train()
    for epoch in range(1):
        sampler.set_epoch(epoch)
        for step, batch in enumerate(dataloader):
            inputs = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(**inputs)
            clean_cache()
            loss = outputs.loss
            loss.backward()
            clean_cache()
            optimizer.step()
            clean_cache()
            memo = (
                get_xpu_memory_used_from_xpu_smi(f"Epoch {epoch}, Step {step}")
                if device_type == "xpu"
                else get_cuda_memory_used_from_nvidia_smi(f"Epoch {epoch}, Step {step}")
            )
            print(f"[Rank {rank}] Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}, Memo:, {memo}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
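To help narrow this down, here is a much smaller sketch (hypothetical, untested) that exercises only allocation plus torch.xpu.empty_cache(), with no FSDP2/PEFT involved; it also assumes torch.xpu.memory_reserved() exists, mirroring torch.cuda. If the xpu-smi number also grows with this loop, the leak would be independent of the training setup above.

import torch

# Hypothetical minimal loop (untested): allocate and free a large tensor,
# call empty_cache() each iteration, and watch the allocator's reserved bytes
# while comparing against xpu-smi on the side.
assert torch.xpu.is_available()
for i in range(10):
    x = torch.randn(4096, 4096, device="xpu", dtype=torch.bfloat16)
    y = x @ x
    del x, y
    torch.xpu.synchronize()
    torch.xpu.empty_cache()
    print(f"iter {i}: reserved={torch.xpu.memory_reserved() / 1024**3:.2f} GB")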
Versions

| Component | Commit |
|---|---|
| pytorch | 4e19477196547eb2e8157d6d132689373ffcf0fa |
| xpu-ops | aa5b3dc |