
Tensor Parallel performance is not better than eager mode. #36222

jiqing-feng (Contributor) opened this issue Feb 17, 2025 · 11 comments

System Info

- `transformers` version: 4.48.3
- Platform: Linux-4.18.0-425.3.1.el8.x86_64-x86_64-with-glibc2.39
- Python version: 3.12.3
- Huggingface_hub version: 0.28.1
- Safetensors version: 0.5.2
- Accelerate version: 1.3.0
- Accelerate config:    - compute_environment: LOCAL_MACHINE
        - distributed_type: MULTI_GPU
        - mixed_precision: bf16
        - use_cpu: False
        - debug: False
        - num_processes: 2
        - machine_rank: 0
        - num_machines: 1
        - gpu_ids: 5,6
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - enable_cpu_affinity: False
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []
- PyTorch version (GPU?): 2.6.0a0+ecf3bae40a.nv25.01 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA A100 80GB PCIe

docker image: nvcr.io/nvidia/pytorch:25.01-py3
Hardware: Nvidia A100

Who can help?

@SunMarc @ArthurZucker @kwen2501

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

CMD: CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc-per-node 4 run_tp_hf.py

import os
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Initialize distributed, only TP model needed.
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
print(rank)
print(device)
torch.distributed.init_process_group("nccl", device_id=device)

# Retrieve tensor parallel model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    tp_plan="auto",
    # device_map="cuda:0",
    torch_dtype=torch.float16
)
print(model.dtype)

# Prepare input tokens
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help" * 200
# Explicitly truncate the repeated prompt to a 512-token input
inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).input_ids.to(model.device)
print(f"input shape is {inputs.shape}")

model = torch.compile(model)
# warm-up
for i in range(100):
    outputs = model(inputs)


torch.cuda.synchronize(device)
# Distributed run
for i in range(50):
    start = time.time()
    torch.cuda.synchronize(device)
    outputs = model(inputs)
    torch.cuda.synchronize(device)
    end = time.time()
    print(f"time cost {(end-start)*1000} ms")

Expected behavior

Latency performance (ms), where tp_size equals world_size:

| tp_size | latency | memory per device |
|---------|---------|-------------------|
| 1       | 47 ms   | 21.5 GB           |
| 2       | 49 ms   | 27 GB             |
| 4       | 45 ms   | 27 GB             |

The speed-up claimed in the docs is not reproduced.

Related PR: 34184

@kwen2501 (Contributor)

Do you mean multi-GPU performance is not better than single-GPU?
In some cases that is possible, especially if the system has a slow interconnect, because TP introduces inter-GPU communication such as all-reduce. When this overhead is large, it can offset the computation speedup; a rough estimate is sketched below.
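
As a back-of-envelope illustration (a hypothetical sketch, not part of the original comment; it assumes Llama-3.1-8B's hidden_size of 4096 and 32 decoder layers, Megatron-style TP with two all-reduces per layer, fp16 activations, and the 512-token input from the script):

# Rough estimate of the activation volume all-reduced per forward pass under
# Megatron-style tensor parallelism (assumed numbers, see the note above).
hidden_size = 4096
num_layers = 32
seq_len = 512
batch_size = 1
bytes_per_elem = 2  # fp16 activations

per_allreduce = batch_size * seq_len * hidden_size * bytes_per_elem  # 4 MiB
per_forward = 2 * num_layers * per_allreduce                         # 256 MiB (two all-reduces per layer)

print(f"per all-reduce:   {per_allreduce / 2**20:.0f} MiB")
print(f"per forward pass: {per_forward / 2**20:.0f} MiB")

At PCIe-class bandwidths (tens of GB/s) moving that much data every step costs on the order of milliseconds, a sizable fraction of the measured per-step latency, whereas NVLink-connected GPUs move it roughly an order of magnitude faster.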

From your report, I see that your system uses PCIe instead of NVLink:

  • GPU type: NVIDIA A100 80GB PCIe

In case of a slow interconnect, I would recommend Pipeline Parallel (PP) instead of Tensor Parallel, because it is better at hiding communication latency. In the ideal case it still increases system throughput in proportion to the number of GPUs.

@jiqing-feng (Contributor, Author) commented Feb 17, 2025

Hi @kwen2501. Thanks for the clarification. I suspect that is the reason, but I want to make sure I am not missing anything, since I cannot reproduce the speed-up shown in your docs here. Is it only because of the hardware?

[Image: benchmark figure from the tensor parallelism docs]

@kwen2501 (Contributor) commented Feb 17, 2025

Yeah, the above benchmark is from an 8x H100 machine with fully connected NVLink.

Your benchmark script may not capture the actual time needed, though.

# Distributed run
for i in range(50):
    start = time.time()
    outputs = model(inputs)
    end = time.time()
    print(f"time cost {(end-start)*1000} ms")

The time between start and end measures only the CPU time needed to launch the CUDA kernels; it does not include the time for those kernels to finish the computation.

@kwen2501 (Contributor)

You'd need to

torch.cuda.synchronize(device)

before recording the end timestamp.
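
For reference, here is a minimal timing sketch using CUDA events rather than host wall-clock time (not from the thread; it reuses the model, inputs, and device variables from the script above):

# CUDA events time the GPU work directly instead of the host-side kernel launches.
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

for i in range(50):
    start_evt.record()
    outputs = model(inputs)
    end_evt.record()
    torch.cuda.synchronize(device)  # wait until both recorded events have completed
    print(f"time cost {start_evt.elapsed_time(end_evt):.2f} ms")  # elapsed_time() returns milliseconds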

@jiqing-feng (Contributor, Author) commented Feb 17, 2025

Hi @kwen2501. Thanks for the reminder. I have updated the script and the performance table, but I still don't see the acceleration. It would be great if you could share more details about the fully connected NVLinks you mentioned. Do I need any environment variables or extra code to enable fully connected NVLink?

@kwen2501 (Contributor) commented Feb 17, 2025

If you run nvidia-smi topo -m in your command line, you will see your machine's GPU topology.

For example:

[Image: example nvidia-smi topo -m output]

where NV# stands for NVLink.

Other legends may be:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

That's entirely determined by the hardware; it is not something you can control with environment variables.

@jiqing-feng (Contributor, Author)

[Image: nvidia-smi topo -m output for my machine, showing PIX connections]
Yeah, that's the main difference: my devices are connected via PIX while yours use NV#. Thanks for your help!

@ArthurZucker (Collaborator)

Thanks @kwen2501 for helping! 🤗 I believe we can close this now? 🤗

@jiqing-feng (Contributor, Author) commented Feb 19, 2025

Hi @kwen2501. I found a new issue: the model weight shape stays the same regardless of the TP size.

I added

shape = model.model.layers[0].self_attn.q_proj.weight.shape
print(f"weight shape is {shape}")

after loading the model in my script. The output is always weight shape is torch.Size([4096, 4096]) no matter the TP size. The memory usage per device reflects this as well.
My command: CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc-per-node 4 run_tp_hf.py.

I assumed each card would only hold a part of the model when running TP. Is this expected?
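
(A minimal sketch of how the per-rank allocation could be checked right after loading, hypothetical, using torch.cuda.memory_allocated:)

# Print how much GPU memory this rank has actually allocated.
allocated_gib = torch.cuda.memory_allocated(device) / 2**30
print(f"[rank {rank}] allocated {allocated_gib:.2f} GiB")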

@kwen2501 (Contributor) commented Feb 19, 2025

You are seeing the same shape as before because DTensor.shape is designed to return the original shape rather than the sharded shape. I agree that there may be some confusion here, and we should document it better.
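
As an illustration (a minimal sketch, not from the original comment, assuming the parameter is a DTensor from torch.distributed.tensor):

# DTensor reports the global (logical) shape; to_local() exposes this rank's shard.
weight = model.model.layers[0].self_attn.q_proj.weight
print(f"global shape: {weight.shape}")             # torch.Size([4096, 4096]) regardless of tp_size
print(f"local shard:  {weight.to_local().shape}")  # e.g. torch.Size([1024, 4096]) with tp_size=4 and colwise sharding
print(f"placements:   {weight.placements}")        # how the tensor is distributed over the device mesh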

@jiqing-feng (Contributor, Author) commented Feb 20, 2025

You are seeing the same shape as before because DTensor.shape is designed to return the original shape rather than the sharded shape. I agree that there may be some confusion here, and we should document it better.

Hi @kwen2501. Do you know how to check the sharded shape? I want to know which part of the model is on each GPU so I can avoid an unbalanced split.
