wip MoE refactor #2600
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2600
Note: Links to docs will display an error until the docs builds have been completed.
❌ 11 New Failures, 1 Cancelled Job as of commit d41d7b9 with merge base 0e00df3.
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This one replaces #2325, right? I'm struggling to run the … Would you mind finding …? As a side note, it seems that … |
Can you run it with batch_size 1? I'll try the fix. Yeah, I haven't done the quantization dispatch stuff yet. |
"""Configuration for applying quantization to MoE | ||
Args: | ||
`base_config`: normal AO Config | ||
class DummyModule(torch.nn.Module): |
I think a better solution is to make torchao APIs work on parameters. The current workaround is fine for prototype, but we'd want proper support once this moves out of prototype.
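For context, here is a rough sketch of the kind of workaround being discussed, assuming torchao's module-based `quantize_` entry point; the `quantize_param` helper is illustrative, not the PR's exact code. The idea is to wrap a bare parameter in a dummy module so a module-based config can be applied, then pull the converted weight back out:

```python
import torch
from torchao.quantization import quantize_  # module-based torchao entry point

class DummyModule(torch.nn.Module):
    """Temporarily exposes a bare parameter as `weight` so module-based configs apply."""
    def __init__(self, weight: torch.nn.Parameter):
        super().__init__()
        self.weight = weight

def quantize_param(param: torch.nn.Parameter, base_config):
    # Illustrative helper: wrap the parameter, quantize the wrapper in place,
    # then unwrap the converted weight.
    holder = DummyModule(param)
    quantize_(holder, base_config, filter_fn=lambda mod, fqn: isinstance(mod, DummyModule))
    return holder.weight
```

Making `quantize_` accept parameters directly would remove the need for the wrapper module entirely.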
@@ -310,7 +310,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
        # T'(e) tokens for expert e
-class MOEFeedForwardAOQuantizable(nn.Module):
+class MoEFeedForwardAOQuantizable(nn.Module):
It seems unlikely that people are going to swap their MoE module to AO's version. Can we just target `torch._grouped_mm` calls directly without requiring a module swap?
What would it mean to "target" it specifically? If the model is compiled, the compiled version of this operator will be used anyway; I'm not sure what else torchao could do about it...
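For what it's worth, one hedged sketch of what "targeting" the op could look like without a module swap: a weight tensor subclass that hooks `torch._grouped_mm` via `__torch_function__` (assuming the op goes through that protocol like other `torch.*` functions). Purely illustrative; torchao's real dispatch machinery may differ.

```python
import torch

class GroupedMMWeight(torch.Tensor):
    """Illustrative subclass; a real torchao subclass would carry quantized data."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch._grouped_mm:
            # A real implementation would dispatch to a quantized grouped kernel or
            # a per-expert linear decomposition here; this sketch just unwraps the
            # subclass and delegates to the plain op to show where the hook lives.
            plain = [a.as_subclass(torch.Tensor) if isinstance(a, cls) else a for a in args]
            return torch._grouped_mm(*plain, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
```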
Nope, with both batch_size 1 and 8, it runs out of memory. |
Summary:
Now that the PyTorch grouped_mm kernels no longer require padding, this refactors the MoE implementation to use them instead of the previous approach.
DONE
- implement MoE with grouped_mm [x] (see the sketch after this list)
- add handling for generic module swap to AOQuantizable (MoEMapping) [x]
- refactor MoEQuantConfig to swap generic modules [x]
TODO
- add dispatch from grouped_mm to a linear decomposition of the quantized kernel
- compare linear decomposition vs new linear decomposition vs grouped_mm for eager, compile, and autotuned compile of the linear decomposition
- compare linear decomposition vs new linear decomposition for quantized kernels
- add scaled_group_gemm and the fbgemm kernel (probably in a new PR)
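To make the first DONE item concrete, a minimal sketch of an MoE feed-forward on top of the grouped kernel. Assumptions (not the PR's actual module): top-1 routing, a single SiLU activation, and a `torch._grouped_mm(a, b, offs=...)` signature taking 2D activations, 3D stacked expert weights, and cumulative per-expert offsets; check the signature in your PyTorch build.

```python
import torch
import torch.nn.functional as F

def moe_ffn_grouped(x, w1, w2, expert_ids):
    # x:          (T, dim)           routed tokens
    # w1:         (E, dim, hidden)   stacked up-projection expert weights
    # w2:         (E, hidden, dim)   stacked down-projection expert weights
    # expert_ids: (T,)               chosen expert per token (top-1 routing)
    num_experts = w1.shape[0]
    order = torch.argsort(expert_ids)                    # group tokens by expert
    x_sorted = x[order]
    counts = torch.bincount(expert_ids, minlength=num_experts)
    offs = torch.cumsum(counts, dim=0).to(torch.int32)   # cumulative T'(e) per expert
    h = torch._grouped_mm(x_sorted, w1, offs=offs)       # per-expert up-projection, no padding
    h = F.silu(h)
    y_sorted = torch._grouped_mm(h, w2, offs=offs)       # per-expert down-projection
    y = torch.empty_like(y_sorted)
    y[order] = y_sorted                                  # restore original token order
    return y
```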
ISSUE:
The autotuned grouped_mm kernels don't give correct output, yet the same code works in eager and in compile with reduce-overhead. Why?
See the new_run.log output: the first 2 runs are fine, line 144 is nonsense.
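One minimal way to reproduce that comparison (an assumed debugging harness, not the PR's run.sh): check the eager output against the max-autotune compiled output of the same block.

```python
import torch

def check_autotune_mismatch(moe_block, x):
    ref = moe_block(x)                                        # eager reference
    compiled = torch.compile(moe_block, mode="max-autotune")  # autotuned kernels
    out = compiled(x)
    # Loose tolerances for bf16; a mismatch here reproduces the issue above.
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```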
Test Plan:
sh run.sh
Reviewers:
Subscribers:
Tasks:
Tags: