[BUG] Incorrect results from flux.AGKernel for some problem shapes #17
Comments
@tlrmchlsmth It seems I cannot reproduce any numerical issue on our A100 machine with TP=2/8 and other similar settings.
I'm still running into this with CUDA 11.8 and torch 2.1. If I use the following Dockerfile (saved as …)
and then build and run it with the following (I'm on an 8× A100 system):
I see:
Thank you very much for your valuable feedback. I was able to reproduce the numerical difference you pointed out. The discrepancy was due to a missing reset_signal when calling ag_gemm_op; after adding the reset_signal, the issue was resolved.
The interface for AG_gemm is indeed confusing. We are considering embedding the reset_signal directly in the C++ implementation rather than exposing it to users. While this would simplify the interface, it may cost some performance, since users would no longer be able to hide the reset_signal overhead behind other operations.
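For illustration, here is a minimal sketch of how a caller could hide an explicit reset on a side stream, assuming the signal buffer is exposed to Python as a CUDA int32 tensor; `ag_gemm_op` and `signal_buf` are placeholder names, not the actual flux API:

```python
import torch

# Sketch under assumptions: `signal_buf` is the tile-ready flag buffer exposed
# as a CUDA int32 tensor, and `ag_gemm_op` is the fused AllGather+GEMM callable;
# both names are placeholders, not the actual flux API.
reset_stream = torch.cuda.Stream()
reset_done = torch.cuda.Event()

def fused_call(a_local, b, ag_gemm_op, signal_buf):
    # Make sure the previous reset has finished before the flags are reused.
    torch.cuda.current_stream().wait_event(reset_done)
    out = ag_gemm_op(a_local, b)
    # Clear the flags on a side stream so the memset overlaps with whatever
    # independent work gets enqueued on the compute stream afterwards.
    reset_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(reset_stream):
        signal_buf.zero_()
        reset_done.record(reset_stream)
    return out
```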
I see. Thanks for the explanation! I'm assuming that …
Would it be better for each worker to keep track of the current epoch (the number of times the barrier has been used) and use that for the barrier state, instead of making sure the barrier is zeroed out every time? That way you wouldn't have to reset the signal buffer.
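As a toy host-side analogue of that idea (not flux code): the producer writes a monotonically increasing epoch into the flag, and the consumer waits for `flag >= epoch`, so the flag never needs to be zeroed between rounds.

```python
import threading
import time

# Toy host-side analogue (not flux code): the producer stores the current
# epoch in the flag instead of a 0/1 value, so no reset is needed between
# rounds; the consumer simply waits for `flag >= epoch`.
flag = 0

def producer(epoch: int) -> None:
    global flag
    # ... work that produces the data for this epoch would go here ...
    flag = epoch                  # publish "ready" for this epoch

def consumer(epoch: int) -> None:
    while flag < epoch:           # wait for this (or a later) epoch
        time.sleep(0)             # spin; a kernel would poll a device-side flag

for epoch in range(1, 4):         # three rounds with no reset in between
    t = threading.Thread(target=producer, args=(epoch,))
    t.start()
    consumer(epoch)
    t.join()
```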
Right, it is basically a memset. The size depends on the M dimension and the threadblock size, because tiles along N share the same signal, but it is quite small. In our experience this operation can be overlapped by opening another stream and putting the reset after the AG GEMM; it shouldn't take much memory bandwidth. Moreover, if you create the flux.AGKernel op every time, as the test script does, then there is no need to reset the signal at all.
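For a sense of scale, the flag count only grows with ceil(M / BLOCK_M), since tiles along N share one flag. With an assumed 128-row tile and 4-byte flags (both assumptions, not flux's actual constants), the reset touches only a few hundred bytes:

```python
import math

# Rough sizing sketch; the 128-row tile and 4-byte flag are assumptions, not
# flux's actual constants. Tiles along N share one flag, so only
# ceil(M / BLOCK_M) flags need to be cleared.
def signal_bytes(m: int, block_m: int = 128, bytes_per_flag: int = 4) -> int:
    return math.ceil(m / block_m) * bytes_per_flag

print(signal_bytes(8192))  # 64 flags -> 256 bytes to memset per call
```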
Do you mean a counter? It is possible, but it would require some changes to the code and the kernel as well, IIRC. We tried that but ended up using the reset method for simplicity.
Does the fix work for you @tlrmchlsmth? If so, shall we consider this one fixed? cc @zheng-ningxin
I submitted a PR that moves the signal reset into the critical path of the forward function, so users no longer need to reset the signal manually. cc @wenlei-bao @tlrmchlsmth
Hey @wenlei-bao and @zheng-ningxin, it does work when manually resetting the signal. However, I see a pretty significant slowdown when doing so. This may be because I'm creating the flux.AGKernel op on every call. Do you think it would be reasonable, for performance, to create the op once and reuse it?
Creating an op each time won't help with performance. The slowdown is due to an unnecessary CUDA stream sync in the reset-signal path; I've removed it in the new MR (#19). Are there still correctness issues after adding the reset_signal? Also, if possible, could you please push/update the integration codebase for both AG+GEMM and GEMM+RS? I'll work on resolving the correctness/performance issues on my side as well. @tlrmchlsmth Thanks a lot!
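To illustrate the kind of difference at play (a generic sketch, not the actual PR): an async `zero_()` enqueued on the compute stream only adds a tiny kernel, while a host-side synchronize in the same path stalls the CPU until all queued GPU work has drained.

```python
import torch

# Generic sketch, not the actual PR. `signal_buf` is a stand-in flag buffer.
signal_buf = torch.zeros(64, dtype=torch.int32, device="cuda")

def reset_async():
    signal_buf.zero_()  # ordered before the next launch on the same stream; no host stall

def reset_with_sync():
    signal_buf.zero_()
    torch.cuda.current_stream().synchronize()  # blocks the host; serializes launches
```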
Adding the reset signal fixes the correctness issues. My integration PR is up to date and fuses both GEMM+RS and AG+GEMM: vllm-project/vllm#5917.
Hey @tlrmchlsmth, thanks for the update. Do the perf numbers in #17 (comment) include @zheng-ningxin's fix?
@wenlei-bao No, those numbers are stale. I updated the PR's description to note that they are stale. I'll rerun the numbers after trying out the latest fix.
This is good to close now, thanks!
Describe the bug
I am seeing different results when using torch.distributed.all_gather_into_tensor followed by torch.mm, as compared to the flux fused AllGather GEMM. The problem seems to be limited to certain problem shapes.
To Reproduce
Here is my short script to repro the issue. For this repro, I am running on 2 A100s, torch 2.3, CUDA 12.5, Flux at 66e2716, and I'm building Flux with
bash build.sh --arch 80 --package
and installing the wheel.