
[BUG] Incorrect results from flux.AGKernel for some problem shapes #17

Closed
tlrmchlsmth opened this issue Jul 5, 2024 · 15 comments
Labels: bug (Something isn't working)

@tlrmchlsmth (Contributor)

Describe the bug
I am seeing different results from torch.distributed.all_gather_into_tensor followed by torch.mm compared to the flux fused AllGather GEMM. The problem seems to be limited to certain problem shapes.

To Reproduce
Here is a short script to repro the issue. For this repro I am running on 2 A100s with torch 2.3 and CUDA 12.5; Flux is at 66e2716, built with bash build.sh --arch 80 --package and installed from the wheel.

import multiprocessing as mp
import torch

import flux

def test(tp_group, m, n, k, transpose_weight, local_copy, dtype):
    world_size = torch.distributed.get_world_size(tp_group)

    a = (-2 * torch.rand(m, k, dtype=dtype, device="cuda") + 1) / .1
    b = ((-2 * torch.rand(k, n, dtype=dtype, device="cuda") + 1) / .1)

    # Massage b depending on transpose_weight
    b_ag_gemm = b
    if not transpose_weight:
        b_ag_gemm = b.t().contiguous()

    # Run fused AllGather Gemm
    ag_gemm_op = flux.AGKernel(tp_group,
                               1,
                               8192,
                               n,
                               k,
                               dtype,
                               dtype,
                               transpose_weight=transpose_weight,
                               local_copy=local_copy)

    torch.distributed.barrier()
    ag_gemm_output = ag_gemm_op.forward(a, b_ag_gemm)

    # Run a torch AllGather followed by a GEMM

    a_gathered = torch.zeros(m * world_size, k, dtype=dtype, device="cuda")
    torch.distributed.all_gather_into_tensor(a_gathered, a, 0)
    torch_output = torch.mm(a_gathered, b)
    torch.distributed.barrier()

    if not torch.allclose(torch_output, ag_gemm_output, atol=1e-1, rtol=1e-1):
        difference = (torch_output - ag_gemm_output).to(dtype=torch.float32)
        print(f"""Error:  Max diff.
        Process: {torch.distributed.get_rank(tp_group)}.
        Arguments: {m}, {n}, {k}, {local_copy}, {transpose_weight}
        Torch output norm: {torch.norm(torch_output.to(dtype=torch.float32))}.
        AGGemm output norm: {torch.norm(ag_gemm_output.to(dtype=torch.float32))},
        Norm of difference: {torch.norm(difference)}""")


@torch.no_grad()
def initialize_process(rank, world_size):
    # Assign GPU to this process
    torch.cuda.set_device(rank)

    # Create a torch communicator
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:12345',
        world_size=world_size,
        rank=rank
    )
    tp_group = torch.distributed.new_group(ranks=list(range(world_size)), backend="nccl")

    # Initialize pynvshmem using the torch communicator
    flux.init_flux_shm(tp_group)
    torch.cuda.synchronize()

    # These are OK
    test(tp_group, 16, 4096, 4096, False, False, torch.float16)
    test(tp_group, 16, 6144, 4096, False, False, torch.float16)

    # This produces a wrong answer
    test(tp_group, 16, 3072, 4096, False, False, torch.float16)

    # Clean up
    torch.distributed.destroy_process_group()

def main():
    torch.set_printoptions(precision=4)
    torch.set_printoptions(sci_mode=True)

    world_size = 2  # Number of processes to create
    processes = []

    for rank in range(world_size):
        p = mp.Process(target=initialize_process, args=(rank, world_size))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

if __name__ == "__main__":
    main()
zheng-ningxin added the bug (Something isn't working) label on Jul 6, 2024
zheng-ningxin self-assigned this on Jul 6, 2024
@wenlei-bao (Collaborator)

@tlrmchlsmth I cannot reproduce any numerical issue on our A100 machine with TP=2/8 and other similar settings.
My setup is torch 2.1 and CUDA 11.8. cc @zheng-ningxin

@tlrmchlsmth (Contributor, Author)

I'm still running into this with CUDA 11.8 and torch 2.1.

If I use the following Dockerfile (saved as ./Dockerfile)

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

WORKDIR /app

RUN apt-get update \
    && apt-get install -y python3-pip python3-venv git

# Check out a branch with the script from this issue 
# and one change needed for it to run without nvshmem 
RUN git clone https://github.com/tlrmchlsmth/flux
WORKDIR /app/flux
RUN git checkout tms/repro_17

RUN python3 -m venv venv
ENV VIRTUAL_ENV=venv/bin/activate
RUN pip3 install torch==2.1.0 numpy==1.26 cmake ninja wheel

RUN git submodule update --init --recursive
RUN bash build.sh --arch 80 --package
RUN pip install dist/flux-1.0.0+cu118-cp310-cp310-linux_x86_64.whl

ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/app/flux/venv/lib"

CMD ["python3", "run.py"]

and then build and run it with the following (I'm on an 8 A100 system):

docker build -t cuda_11.8.0 .
docker run --gpus 4,5,6,7 -it cuda_11.8.0

I see:

Error:  Max diff.
        Process: 0.
        Arguments: 16, 3072, 4096, False, False
        Torch output norm: 668120.4375.
        AGGemm output norm: 668120.6875,
        Norm of difference: 223.62152099609375
Error:  Max diff.
        Process: 1.
        Arguments: 16, 3072, 4096, False, False
        Torch output norm: 666105.1875.
        AGGemm output norm: 666106.5,
        Norm of difference: 222.33139038085938

@zheng-ningxin (Collaborator)

Thank you very much for your valuable feedback. I was able to reproduce the numerical difference you pointed out.

I discovered that the discrepancy was due to a missing reset_signal when calling ag_gemm_op. After adding the reset_signal, the issue was resolved.


def test(tp_group, m, n, k, transpose_weight, local_copy, dtype):
    world_size = torch.distributed.get_world_size(tp_group)

    # a = (-2 * torch.rand(m, k, dtype=dtype, device="cuda") + 1) / .1
    # b = ((-2 * torch.rand(k, n, dtype=dtype, device="cuda") + 1) / .1)
    a = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
    b = torch.rand(k, n, dtype=dtype, device="cuda") - 0.5
    # Massage b depending on transpose_weight
    b_ag_gemm = b
    if not transpose_weight:
        b_ag_gemm = b.t().contiguous()

    # Run fused AllGather Gemm
    ag_gemm_op = flux.AGKernel(tp_group,
                               1,
                               8192,
                               n,
                               k,
                               dtype,
                               dtype,
                               transpose_weight=transpose_weight,
                               local_copy=local_copy)

    torch.distributed.barrier()
    ag_gemm_op.reset_signals() # reset here
    ag_gemm_output = ag_gemm_op.forward(a, b_ag_gemm)

    # Run a torch AllGather followed by a GEMM

    a_gathered = torch.zeros(m * world_size, k, dtype=dtype, device="cuda")
    torch.distributed.all_gather_into_tensor(a_gathered, a, 0)
    torch_output = torch.mm(a_gathered, b)
    torch.distributed.barrier()

    if not torch.allclose(torch_output, ag_gemm_output, atol=1e-1, rtol=1e-1):
        difference = (torch_output - ag_gemm_output).to(dtype=torch.float32)
        print("\n")
        print(torch_output)
        print("\n")
        print(ag_gemm_output)
        print("\n")
        print(difference)
        print(f"""Error:  Max diff.
        Process: {torch.distributed.get_rank(tp_group)}.
        Arguments: {m}, {n}, {k}, {local_copy}, {transpose_weight}
        Torch output norm: {torch.norm(torch_output.to(dtype=torch.float32))}.
        AGGemm output norm: {torch.norm(ag_gemm_output.to(dtype=torch.float32))},
        Norm of difference: {torch.norm(difference)}""")

The interface for AG_gemm is indeed confusing. We are considering whether to embed the reset_signal directly within the C++ implementation rather than exposing it to users as part of the interface. While this would simplify the interface, it may come with a performance trade-off, since users would no longer be able to hide the reset_signal overhead behind other operations.
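For reference, here is a minimal sketch of the two options, reusing the names from the test scripts above (the embedded variant is hypothetical until a fix lands):

# Option 1 (current): the user resets the signals explicitly, which makes it
# possible to schedule the reset wherever it can be hidden behind other work.
ag_gemm_op.reset_signals()
ag_gemm_output = ag_gemm_op.forward(a, b_ag_gemm)

# Option 2 (proposed, hypothetical here): forward() resets the signals
# internally, so this single call is all the user writes, at the cost of the
# reset sitting on the critical path.
ag_gemm_output = ag_gemm_op.forward(a, b_ag_gemm)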

@tlrmchlsmth
Copy link
Contributor Author

I see. Thanks for the explanation!

I'm assuming that reset_signal basically does a memset. Is that right? How large are the signal buffers? I think this would be a tricky operation to overlap with others during inference as most operations end up being close to bottlenecked by memory bandwidth and they would be using the same resources. Have you seen success overlapping reset_signal in this way?

@tlrmchlsmth
Copy link
Contributor Author

Would it be better for each worker to keep track of the current epoch (number of times the barrier has been used) and use that for the barrier state instead of making sure the barrier is zeroed out every time? That way you wouldn't have to reset the signal buffer.
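To illustrate the idea, here is a purely host-side PyTorch sketch (not flux's API, just the concept): each call bumps an epoch counter, and a tile counts as ready once its flag reaches the current epoch, so stale flag values from earlier calls never need to be cleared.

import torch

class EpochSignals:
    """Hypothetical epoch-based readiness flags (illustration only)."""

    def __init__(self, num_tiles: int, device: str = "cuda"):
        self.flags = torch.zeros(num_tiles, dtype=torch.int32, device=device)
        self.epoch = 0

    def start_call(self):
        # No memset between calls: just advance the expected value.
        self.epoch += 1

    def signal_ready(self, tile: int):
        # Producer side (the all-gather copy) stamps the tile with the epoch.
        self.flags[tile] = self.epoch

    def is_ready(self, tile: int) -> bool:
        # Consumer side (the GEMM) treats flag >= current epoch as ready.
        return int(self.flags[tile]) >= self.epoch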

@wenlei-bao
Copy link
Collaborator

I see. Thanks for the explanation!

I'm assuming that reset_signal basically does a memset. Is that right? How large are the signal buffers? I think this would be a tricky operation to overlap with others during inference as most operations end up being close to bottlenecked by memory bandwidth and they would be using the same resources. Have you seen success overlapping reset_signal in this way?

Right, it is basically a memset. The size depends on the M dimension and the threadblock size, since tiles along N share the same signal, but it is quite small. In our experience this operation can be overlapped by opening another stream and putting the reset after the AG GEMM; it shouldn't take much memory bandwidth.

Moreover, if you create the flux.AGKernel op every time, as in the test script, there is no need to reset the signal at all. The reset op is only needed when the same op is run multiple times.
Hope this makes things clear.
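A minimal sketch of that overlap, continuing from the test() body above and assuming reset_signals() respects the current CUDA stream (an assumption on my part):

main_stream = torch.cuda.current_stream()
reset_stream = torch.cuda.Stream()

ag_gemm_output = ag_gemm_op.forward(a, b_ag_gemm)  # launched on main_stream

# Order the reset after the AG GEMM, but issue it on the side stream so it
# runs concurrently with whatever main_stream does next.
reset_stream.wait_stream(main_stream)
with torch.cuda.stream(reset_stream):
    ag_gemm_op.reset_signals()

# ... other work on main_stream ...

# Before the next forward() on this op, make sure the reset has finished.
main_stream.wait_stream(reset_stream)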

@wenlei-bao
Copy link
Collaborator

Would it be better for each worker to keep track of the current epoch (number of times the barrier has been used) and use that for the barrier state instead of making sure the barrier is zeroed out every time? That way you wouldn't have to reset the signal buffer.

Do you mean a counter? It is possible, but would require some changes to the code and the kernel as well, IIRC. We tried that but ended up using the reset method for simplicity.
BTW, there are also other tricks to hide this reset operation, such as double buffering.
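To make the double-buffering trick concrete, here is a hypothetical sketch (not a flux-provided pattern): keep two AGKernel ops and alternate between them, so the just-used op can have its signals reset on a side stream while the other op serves the next call.

reset_stream = torch.cuda.Stream()
reset_done = [torch.cuda.Event(), torch.cuda.Event()]
ops = [
    flux.AGKernel(tp_group, 1, 8192, n, k, dtype, dtype,
                  transpose_weight=transpose_weight, local_copy=local_copy)
    for _ in range(2)
]
step = 0

def fused_ag_gemm(a, b_ag_gemm):
    global step
    i = step % 2
    op = ops[i]
    # Wait for this op's previous reset (issued two calls ago) to complete.
    torch.cuda.current_stream().wait_event(reset_done[i])
    out = op.forward(a, b_ag_gemm)
    # Reset this op's signals on the side stream; by the time this op is
    # reused, the reset has long since finished and stays off the critical path.
    reset_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(reset_stream):
        op.reset_signals()
        reset_done[i].record(reset_stream)
    step += 1
    return out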

@wenlei-bao (Collaborator)

Does the fix work for you, @tlrmchlsmth? If so, shall we consider this one fixed? cc @zheng-ningxin

@zheng-ningxin (Collaborator)

I submitted a PR that moves the reset signal to the critical path of the forward function. Users no longer need to manually reset the signal. cc @wenlei-bao @tlrmchlsmth
#19

@tlrmchlsmth (Contributor, Author)

Hey @wenlei-bao and @zheng-ningxin, it does work when I manually reset the signal. However, I see a pretty significant slowdown when doing so. This may be because I'm creating the AGKernel object to support the maximum context length, which in this case is 8192, while m is typically much, much smaller.

Do you think it would be reasonable, performance-wise, to create the flux.AGKernel op each time?

@zheng-ningxin (Collaborator)

Creating an op each time won’t help with performance. The slowdown is due to an unnecessary CUDA stream sync in the reset signal; I’ve removed it in the new MR (#19). Are there still correctness issues after adding the reset_signal? Also, if possible, could you please push/update the integration codebase for both AG+GEMM and GEMM+RS? I’ll work on resolving the correctness/performance issues on my side as well. @tlrmchlsmth Thanks a lot!

@tlrmchlsmth (Contributor, Author)

Adding the reset signal fixes the correctness issues. My integration PR is up to date and fuses both GEMM+RS and AG+GEMM: vllm-project/vllm#5917.

@wenlei-bao (Collaborator) commented Jul 15, 2024

Hey @tlrmchlsmth, thanks for the update. Do the perf numbers in #17 (comment) include @zheng-ningxin's fix?

@tlrmchlsmth (Contributor, Author)

@wenlei-bao No, those numbers are stale. I updated the PR's description to note that they are stale; I'll rerun the numbers after trying out the latest fix.

@tlrmchlsmth (Contributor, Author)

This is good to close now, thanks!
