
Memory Leakage with USP and Transformer Blocks #112

Open
baifanxxx opened this issue Dec 12, 2024 · 7 comments

@baifanxxx

baifanxxx commented Dec 12, 2024

Hi,

First of all, great work on the project! However, I’ve encountered an issue with memory release when using USP. Specifically, I’m using USP for end-to-end sequence parallelism outside the multi-layer Transformer blocks. After processing all Transformer blocks, the final output is gathered via all_gather. Here is a simplified version of the code:

def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Each rank keeps only its local shard of the sequence (split along dim 0).
    local_hidden_states = hidden_states.chunk(world_size, dim=0)[rank].detach().clone()
    local_hidden_states = self.patch_embed(local_hidden_states)
    rotary_pos_emb = self.rot_pos_emb(grid_thw)
    local_rotary_pos_emb = rotary_pos_emb.chunk(world_size, dim=0)[rank].detach().clone()

    # Run all Transformer blocks on the local shard; USP handles attention internally.
    for blk in self.blocks:
        local_hidden_states = blk(local_hidden_states, rotary_pos_emb=local_rotary_pos_emb)

    # Gather the full sequence back from all ranks.
    S, D = local_hidden_states.shape
    hidden_states_gather = torch.zeros(world_size * S, D, dtype=local_hidden_states.dtype, device=local_hidden_states.device)
    dist.all_gather_into_tensor(hidden_states_gather, local_hidden_states)

    return hidden_states_gather

However, I’ve noticed that the GPU memory usage keeps accumulating over time and isn’t properly released. I can provide a pickle file with memory statistics, which can be viewed on [PyTorch Memory Visualization](https://pytorch.ac.cn/memory_viz). The pickle file is
gpu_mem.zip
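A snapshot like this can be captured with PyTorch's built-in memory-history hooks; a minimal sketch (standard torch.cuda.memory APIs, PyTorch >= 2.1 assumed, not part of the forward code above):

import torch

torch.cuda.memory._record_memory_history(max_entries=100000)
# ... run the forward pass under investigation ...
torch.cuda.memory._dump_snapshot("gpu_mem.pickle")       # open this file on the memory_viz page
torch.cuda.memory._record_memory_history(enabled=None)   # stop recording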

Upon analysis, I observed that memory created by torch.empty on line 94 in ring/utils.py cannot be released properly. Additionally, several operations like tensor.to(dtype), all_to_all, and others also seem to have issues with memory not being freed. I suspect that this may be related to the use of USP, rather than being a problem with any single operation.

If you have any insights or suggestions that could help resolve this issue, I would greatly appreciate it!

Thanks!

@baifanxxx
Author

Issue Description:

I’ve been testing the model on an A100 40GB GPU with the following configuration:

  • Input sequence length: 114464
  • Hidden size: 1280
  • sp_ulysses_degree = 1
  • sp_ring_degree = 2

With this configuration I hit the memory-release issue and the GPU memory statistics described above. When I switch to sp_ulysses_degree = 2 and sp_ring_degree = 1, the memory issue disappears and everything works fine. The pickle file from that scenario is attached below:
114464_e2e_u2r1_mem.zip
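For context, a minimal sketch of how these two degrees map onto the process-group setup, using the set_seq_parallel_pg helper shown in this repository's README (treat the exact signature as an assumption):

import torch.distributed as dist
from yunchang import set_seq_parallel_pg

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()

sp_ulysses_degree = 1   # the configuration that leaks
sp_ring_degree = 2      # (ulysses=2, ring=1 works fine)
# ulysses_degree * ring_degree gives the sequence-parallel size (2 here)
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)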

However, regardless of whether I use Ulysses or ring parallelism, for very long input sequences like the one above (length = 114464) inference is significantly slower than without sequence parallelism. I would like to discuss this with the authors, as it seems related to how USP (Unified Sequence Parallelism) handles extremely long sequence inputs.

Have you tested inference performance with long sequences like this? It would be great to understand how this issue can be addressed or optimized for very large input sequences.

@feifeibear
Owner

feifeibear commented Dec 13, 2024

Thank you for your insightful analysis. Indeed, we have previously encountered similar memory leak issues, and this time I will attempt to improve how allocations such as torch.empty are handled.

As a temporary workaround, setting use_sync=True eliminates the memory leak. However, it has been reported to hurt performance on some GPUs.
https://github.com/feifeibear/long-context-attention/blob/main/yunchang/hybrid/attn_layer.py#L22
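A sketch of applying the workaround when constructing the attention layer (use_sync is the flag at the linked line; any other arguments would keep their defaults):

from yunchang import LongContextAttention

# use_sync=True adds an extra synchronization on the all-to-all path,
# which avoids the leak at the cost of some latency.
usp_attn = LongContextAttention(use_sync=True)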

feifeibear self-assigned this Dec 13, 2024
feifeibear added the bug label Dec 13, 2024
@baifanxxx
Author

Thank you for your thoughtful response. Setting use_sync=True does indeed temporarily address the memory leak issue; however, it introduces additional latency due to increased synchronization overhead. We hope to explore more efficient solutions to tackle the memory issue effectively. Once again, thank you for your attention and valuable feedback.

@feifeibear
Owner

> Thank you for your thoughtful response. Setting use_sync=True does indeed temporarily address the memory leak issue; however, it introduces additional latency due to increased synchronization overhead. We hope to explore more efficient solutions to tackle the memory issue effectively. Once again, thank you for your attention and valuable feedback.

I will spend some time on the memory issue. If you have any progress, feel free to continue the discussion in this issue, and you're also welcome to submit a PR.

@feifeibear
Owner

@baifanxxx Did you apply USP for training, or for inference only?

@baifanxxx
Author

Only in inference.

@feifeibear
Owner

Could you please try the following? I could hardly build a standalone test script that reproduces the memory leak; it may only show up when combined with other communication ops, for example the all_gather in your code:

TORCH_NCCL_AVOID_RECORD_STREAMS=1
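For example, the variable can be set before the process group is created (or simply exported in the launch shell); a minimal sketch:

import os

# Must be set before the NCCL process group is constructed.
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

import torch.distributed as dist
dist.init_process_group(backend="nccl")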
