torch.xpu.empty_cache() introduces memory leak #1745

@songhappy

🐛 Describe the bug

Description:

Explicitly calling torch.xpu.empty_cache() in the training loop introduces a memory leak. Here I used llama3.1-8B as an example and printed the memory reported by xpu-smi.

[Rank 0] Epoch 0, Step 0, Loss: 13.1864, Memo:, [Epoch 0, Step 0] xpu-smi memory used (device 0): 21.30 GB
[Rank 1] Epoch 0, Step 0, Loss: 12.5021, Memo:, [Epoch 0, Step 0] xpu-smi memory used (device 0): 21.30 GB
[Rank 0] Epoch 0, Step 1, Loss: 12.8508, Memo:, [Epoch 0, Step 1] xpu-smi memory used (device 0): 27.11 GB
[Rank 1] Epoch 0, Step 1, Loss: 12.6161, Memo:, [Epoch 0, Step 1] xpu-smi memory used (device 0): 27.11 GB
[Rank 0] Epoch 0, Step 2, Loss: 13.1110, Memo:, [Epoch 0, Step 2] xpu-smi memory used (device 0): 32.41 GB
[Rank 1] Epoch 0, Step 2, Loss: 12.5568, Memo:, [Epoch 0, Step 2] xpu-smi memory used (device 0): 32.41 GB
[Rank 1] Epoch 0, Step 3, Loss: 12.4648, Memo:, [Epoch 0, Step 3] xpu-smi memory used (device 0): 38.21 GB
[Rank 0] Epoch 0, Step 3, Loss: 12.5307, Memo:, [Epoch 0, Step 3] xpu-smi memory used (device 0): 38.21 GB
[Rank 1] Epoch 0, Step 4, Loss: 12.0303, Memo:, [Epoch 0, Step 4] xpu-smi memory used (device 0): 43.51 GB
[Rank 0] Epoch 0, Step 4, Loss: 12.3255, Memo:, [Epoch 0, Step 4] xpu-smi memory used (device 0): 43.51 GB
[Rank 0] Epoch 0, Step 5, Loss: 11.9331, Memo:, [Epoch 0, Step 5] xpu-smi memory used (device 0): 49.31 GB
[Rank 1] Epoch 0, Step 5, Loss: 12.3114, Memo:, [Epoch 0, Step 5] xpu-smi memory used (device 0): 49.31 GB
[Rank 0] Epoch 0, Step 6, Loss: 12.0390, Memo:, [Epoch 0, Step 6] xpu-smi memory used (device 0): 54.61 GB
[Rank 1] Epoch 0, Step 6, Loss: 12.1755, Memo:, [Epoch 0, Step 6] xpu-smi memory used (device 0): 54.61 GB
[Rank 0] Epoch 0, Step 7, Loss: 11.9479, Memo:, [Epoch 0, Step 7] xpu-smi memory used (device 0): 60.41 GB
[Rank 1] Epoch 0, Step 7, Loss: 11.9087, Memo:, [Epoch 0, Step 7] xpu-smi memory used (device 0): 60.41 GB
[Rank 0] Epoch 0, Step 8, Loss: 11.5645, Memo:, [Epoch 0, Step 8] xpu-smi memory used (device 0): 65.72 GB
[Rank 1] Epoch 0, Step 8, Loss: 11.2949, Memo:, [Epoch 0, Step 8] xpu-smi memory used (device 0): 65.72 GB
[Rank 0] Epoch 0, Step 9, Loss: 10.8658, Memo:, [Epoch 0, Step 9] xpu-smi memory used (device 0): 71.51 GB
[Rank 1] Epoch 0, Step 9, Loss: 11.6554, Memo:, [Epoch 0, Step 9] xpu-smi memory used (device 0): 71.51 GB
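Note that the xpu-smi numbers above are device-level, so they include whatever the caching allocator holds. A minimal sketch of how one could log PyTorch's own allocator counters alongside xpu-smi (assuming torch.xpu.memory_allocated() and torch.xpu.memory_reserved() are available in this nightly) to tell whether the growth happens inside or outside the caching allocator:

import torch

def log_allocator_stats(tag, device=0):
    # Allocator-side view (assumes these torch.xpu stat APIs exist in the
    # nightly build used here); xpu-smi reports the device-level total.
    if not torch.xpu.is_available():
        return
    allocated = torch.xpu.memory_allocated(device) / 1024**3
    reserved = torch.xpu.memory_reserved(device) / 1024**3
    print(f"[{tag}] allocated: {allocated:.2f} GB, reserved: {reserved:.2f} GB")

# Hypothetical usage inside the training loop of the reproducer:
#   log_allocator_stats(f"Epoch {epoch}, Step {step}")
# If 'reserved' stays flat while xpu-smi keeps climbing, the growth is
# happening outside PyTorch's caching allocator.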

Steps to reproduce:

  1. pip install peft transformers
  2. pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/xpu
  3. Run the reproducer below:
     torchrun --nproc_per_node=2 reproducer.py
import os
import gc
import sys
import json
import torch
import subprocess
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM, default_data_collator
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
from torch.distributed._composable.fsdp import fully_shard
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


# Auto device selection
if torch.xpu.is_available():
    torch_device = torch.xpu
    backend = "xpu:xccl"
    device_type = "xpu"
elif torch.cuda.is_available():
    torch_device = torch.cuda
    backend = "nccl"
    device_type = "cuda"
else:
    torch_device = torch.device("cpu")
    backend = "gloo"
    device_type = "cpu"

def clean_cache():
    """
    Clean up the XPU or CUDA cache to free up memory.
    """
    try:
        if torch.xpu.is_available():
            torch.xpu.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()
        else:
            print("No XPU or CUDA device available to clear cache.")
    except Exception as e:
        print(f"Failed to clear device cache: {e}")

    # try:
    #     collected = gc.collect()
    #     print(f"GC collected {collected} unreachable objects.")
    # except Exception as e:
    #     print(f"Failed to run garbage collection: {e}")

# Memory print (XPU)
def get_xpu_memory_used_from_xpu_smi(tag, device_id=0):
    if device_type != "xpu":
        return
    try:
        result = subprocess.run(
            ["xpu-smi", "stats", "-d", str(device_id), "-j"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
        stats = json.loads(result.stdout)
        tile_level = stats.get("tile_level", [])
        total_mem_mb = sum(metric["value"] for tile in tile_level for metric in tile.get("data_list", []) if metric["metrics_type"] == "XPUM_STATS_MEMORY_USED")
        return (f"[{tag}] xpu-smi memory used (device {device_id}): {total_mem_mb / 1024:.2f} GB")
    except Exception as e:
        return (f"[{tag}] xpu-smi error (device {device_id}): {e}")

# Memory print (CUDA)
def get_cuda_memory_used_from_nvidia_smi(tag, device_id=0):
    if device_type != "cuda":
        return
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader", "-i", str(device_id)],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
        used_mem_mb = float(result.stdout.strip())
        return(f"[{tag}] nvidia-smi memory used (device {device_id}): {used_mem_mb / 1024:.2f} GB")
    except Exception as e:
        return(f"[{tag}] nvidia-smi error (device {device_id}): {e}")


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    if device_type in ["xpu", "cuda"]:
        torch_device.set_device(rank)
        return rank, torch.device(f"{device_type}:{rank}")
    return rank, torch_device


def format_prompt(example):
    return f"""### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:\n{example['output']}"""


def tokenize_prompt(example, tokenizer, max_length=512):
    prompt = format_prompt(example)
    tokenized = tokenizer(prompt, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
    tokenized["labels"] = tokenized["input_ids"].clone()
    return {k: v.squeeze(0) for k, v in tokenized.items()}


def main():
    rank, device = setup_distributed()

    model_path = "/home/songhappy/models/Meta-Llama-3.1-8B-Instruct/"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token

    print(f"[Rank {rank}] Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map={"": rank}
    )

    lora_cfg = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_cfg)
    model.to(torch.bfloat16)


    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, LlamaDecoderLayer):
            fully_shard(module)
    fully_shard(model)
   

    raw_dataset = load_dataset("tatsu-lab/alpaca", split="train[:20]")
    tokenized_dataset = raw_dataset.map(lambda x: tokenize_prompt(x, tokenizer))

    sampler = DistributedSampler(tokenized_dataset, num_replicas=dist.get_world_size(), rank=rank, shuffle=True)
    dataloader = DataLoader(tokenized_dataset, sampler=sampler, batch_size=1, collate_fn=default_data_collator)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

    model.train()
    for epoch in range(1):
        sampler.set_epoch(epoch)
        for step, batch in enumerate(dataloader):
            inputs = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(**inputs)
            clean_cache()  # explicit empty_cache() after forward
            loss = outputs.loss
            loss.backward()
            clean_cache()  # explicit empty_cache() after backward
            optimizer.step()
            clean_cache()  # explicit empty_cache() after optimizer step

            memo = (
                get_xpu_memory_used_from_xpu_smi(f"Epoch {epoch}, Step {step}")
                if device_type == "xpu"
                else get_cuda_memory_used_from_nvidia_smi(f"Epoch {epoch}, Step {step}")
            )
            print(f"[Rank {rank}] Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}, Memo:, {memo}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
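To help narrow this down, here is a much smaller single-process sketch of the same call pattern (forward, empty_cache, backward, empty_cache, step, empty_cache) using a plain nn.Linear stack so transformers/PEFT/FSDP are out of the picture. I have not verified whether it reproduces the growth; it is only meant to isolate the empty_cache() calls from the rest of the stack.

import torch
import torch.nn as nn

# Toy model, same empty_cache() placement as the reproducer above.
device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(10):
    x = torch.randn(64, 4096, device=device)
    optimizer.zero_grad()
    loss = model(x).pow(2).mean()
    if device.type == "xpu":
        torch.xpu.empty_cache()  # after forward, as in the reproducer
    loss.backward()
    if device.type == "xpu":
        torch.xpu.empty_cache()  # after backward
    optimizer.step()
    if device.type == "xpu":
        torch.xpu.empty_cache()  # after optimizer step
        reserved = torch.xpu.memory_reserved() / 1024**3
        print(f"Step {step}: reserved {reserved:.2f} GB")  # compare against xpu-smi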

Versions

pytorch 4e19477196547eb2e8157d6d132689373ffcf0fa  
xpu-ops aa5b3dc  
