-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
multi-gpu fused_moe tuning support #143
base: main
Are you sure you want to change the base?
Conversation
- todo: add fp8 support - todo: add comments and documentation
cd68e07
to
37fc500
Compare
retest |
# For now we see better perf with num_stages=0 for all gemm configs we care | ||
# But keep this explicit so that we do not forget we may need to set it to | ||
# other values in the future | ||
num_stage_range = [0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this PR comes at a tricky time unfortunately.
See triton-lang/triton#4845 (review)
We are changing the sw pipelining such that it is more aligned with the nvidia side. After the above PR is merged (imminently), num_stages = 0 will actually be num_stages=2. If it is kept at 0, it will fail with an error as mentioned in the above link.
So we have two options
-
Hold this PR, if possible, until the Triton PR is submitted.
-
Submit this and then submit another once things break. As mentioned, the error message should be clear and should say what to do.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@shajrawi FYI
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you! |
tuning_utils.py
torchrun benchmark_mixtral_moe_rocm.py --model 8x7B --modelTP 8 --numGPU 8