[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM #7651
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
This is probably the right approach. I'll fast-follow with my selective_state_update and causal_conv1d_update improvements once it lands, and then we can simplify the Mamba cache.
@tlrmchlsmth Ready for review. I took only the code relevant for inference, removed the boilerplate parts, and added support for providing an initial state to mamba_ssm, though it is not used in the Jamba modeling file at the moment.
Thanks!
/ready
Yeah, it would be fine to do these as a follow-up. Or now that there's code, I can paste them into my PR once this one lands. Also, the …
Looks like this is still adding 20 MB (after compression!) to the wheel size. I see a lot of cases being compiled for the selective_scan_fwd kernels, so it seems like we should still be able to bring that down.
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
    BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
        BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
            BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
                BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] {
                    using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
It looks like these BOOL_SWITCH macros are causing combinatorial kernel blow-up. If reducing the number of type combinations doesn't bring the kernels down to a reasonable size, we could make some of these conditions dynamic, as long as the condition isn't checked in the innermost loop. We could also remove some of the BOOL_SWITCH statements where the condition is always true or always false. WDYT?
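For context on the blow-up: BOOL_SWITCH (from the mamba repo's static_switch.h) turns a runtime boolean into a compile-time template parameter, so each nested switch doubles the number of kernel instantiations; the five switches in the hunk above multiply the per-dtype kernel count by 2^5 = 32. A minimal runnable sketch of the mechanism (illustrative names, not the actual kernel code):

```cpp
#include <cstdio>

// Dispatch a runtime bool to a compile-time constant, following the
// mamba repo's static_switch.h. Each nested use doubles the number of
// template instantiations the compiler emits.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            static constexpr bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            static constexpr bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Stand-in for a templated kernel: one copy is compiled per
// (kHasZ, kUseIndex) combination, i.e. 4 copies from 2 switches.
template <bool kHasZ, bool kUseIndex>
void fake_kernel() {
    std::printf("instantiated with kHasZ=%d kUseIndex=%d\n", kHasZ, kUseIndex);
}

void launch(bool has_z, bool use_index) {
    BOOL_SWITCH(has_z, kHasZ, [&] {
        BOOL_SWITCH(use_index, kUseIndex, [&] {
            fake_kernel<kHasZ, kUseIndex>();  // both flags are now constexpr
        });
    });
}

int main() {
    launch(true, false);  // selects the <true, false> instantiation at runtime
    return 0;
}
```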
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. I took a brief look at the Falcon Mamba modeling file in transformers, which is also based on Mamba-1, and it uses the same default behaviour as Jamba, as also defined in the mamba repo's mamba_simple.py, where is_variable_B, is_variable_C, and hasZ are all true. Therefore, I've removed those BOOL_SWITCH boolean terms from the kernel creation loop and just set them to true.
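To illustrate the effect of hard-coding, here is a minimal runnable sketch (illustrative names, not the actual vLLM code): with is_variable_B, is_variable_C, and hasZ pinned to true, only kIsEvenLen still varies, so two instantiations are compiled per dtype combination instead of sixteen.

```cpp
#include <cstdio>

// Stand-in for the templated selective-scan kernel.
template <bool kIsEvenLen, bool kIsVariableB, bool kIsVariableC, bool kHasZ>
void fake_selective_scan() {
    std::printf("kIsEvenLen=%d kIsVariableB=%d kIsVariableC=%d kHasZ=%d\n",
                kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ);
}

void launch(int seqlen, int chunk) {
    // Hard-coded: each of these was previously a BOOL_SWITCH, so pinning
    // them cuts the instantiation count by a factor of 2^3 = 8.
    constexpr bool kIsVariableB = true;
    constexpr bool kIsVariableC = true;
    constexpr bool kHasZ = true;
    if (seqlen % chunk == 0) {
        fake_selective_scan<true, kIsVariableB, kIsVariableC, kHasZ>();
    } else {
        fake_selective_scan<false, kIsVariableB, kIsVariableC, kHasZ>();
    }
}

int main() {
    launch(128, 32);  // even length -> kIsEvenLen = true instantiation
    return 0;
}
```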
Wheel size of this PR is now 134 MB:
Wheel dist/vllm-0.5.5+cu124-cp38-abi3-linux_x86_64.whl is within the allowed size (134.5479211807251 MB).
Upstream is 130 MB:
Wheel dist/vllm-0.5.5+cu124-cp38-abi3-linux_x86_64.whl is within the allowed size (130.0757646560669 MB).
Nice!
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
constexpr bool kHasSeqPosIdx = false;
Could you add a `TORCH_CHECK(params.seq_pos_idx_ptr == nullptr)`?
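For reference, a minimal sketch of the requested guard (the wrapper function, the stand-in params struct, and the message wording are assumptions; only the field name comes from this thread):

```cpp
#include <torch/torch.h>  // provides TORCH_CHECK

// Illustrative stand-in for the kernel's params struct; only the field
// referenced in the review comment is shown.
struct Params {
    const void *seq_pos_idx_ptr = nullptr;
};

// Hedged sketch: fail loudly if a caller hits the disabled code path.
void check_seq_pos_idx_disabled(const Params &params) {
    TORCH_CHECK(params.seq_pos_idx_ptr == nullptr,
                "seq_pos_idx is disabled in favor of reduced binary size");
}
```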
A comment noting that some kernel cases have been disabled to reduce binary size would be good to add for documentation as well.
Agree. By the way, this variable is used for varlen batching support, which is out of scope for this PR IMO, so I've removed it completely and will open a separate follow-up PR for varlen batching.
constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
Could you add torch checks to guard against cases where kIsVariableB, kIsVariableC, or kHasZ would be false?
Done
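Presumably along these lines; a sketch rather than the exact PR code, with parameter names assumed, though the has_z message matches a diff hunk quoted later in this thread:

```cpp
#include <torch/torch.h>  // provides TORCH_CHECK

// Hedged sketch: guard the host wrapper against the kernel variants that
// were compiled out to reduce binary size. Parameter names are assumptions.
void guard_disabled_variants(bool is_variable_B, bool is_variable_C, bool has_z) {
    TORCH_CHECK(is_variable_B,
                "is_variable_B = False is disabled in favor of reduced binary size");
    TORCH_CHECK(is_variable_C,
                "is_variable_C = False is disabled in favor of reduced binary size");
    TORCH_CHECK(has_z,
                "has_z = False is disabled in favor of reduced binary size");
}
```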
I thought the failing LoRA test might be bogus, so I restarted it, but it's still failing. Do you think it could be an actual issue with the PR?
It seems like upstream is also failing on the same test: https://buildkite.com/vllm/ci-aws/builds/7715#01919725-4d5b-41ef-9392-f88017b2693b
LGTM once green, thank you!
CI failures seem to occur on the main branch as well: https://buildkite.com/vllm/ci-aws/builds/7791
const bool has_z = z_.has_value();
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size");
Hi @mzusman, we would like to pass z as None, and hence run into this error.
What do you think is the best way to support that?
Thanks a lot!
@congcongchen123 Could I ask what you're trying to do? Is this for supporting a new model in vLLM?
Thanks @tlrmchlsmth for the quick reply. Yes, we have a new model under development. It reuses the mamba kernel, but we would like to allow z to be None.
@congcongchen123 OK -- support for that case was removed during review to reduce the size of the compiled binaries. It should be pretty easy to restore. Take a look at this commit: abf02fa. Then pay attention to how that affects the wheel size once has_z support is restored!
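Conceptually, restoring the z == None path means swapping the hard-coded constant back to a compile-time switch. A minimal runnable sketch of that change (illustrative names; the macro follows the mamba repo's static_switch.h, and the commit reference is to abf02fa above):

```cpp
#include <cstdio>

// Compile-time dispatch macro, as in the mamba repo's static_switch.h.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            static constexpr bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            static constexpr bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Stand-in for the templated kernel: restoring the switch brings back the
// kHasZ = false instantiation (and its binary-size cost).
template <bool kHasZ>
void fake_kernel() { std::printf("dispatched with kHasZ=%d\n", kHasZ); }

void launch(const float *z_ptr) {
    // Before: constexpr bool kHasZ = true; plus a TORCH_CHECK rejecting z == None.
    // After: dispatch on whether z was actually provided.
    BOOL_SWITCH(z_ptr != nullptr, kHasZ, [&] { fake_kernel<kHasZ>(); });
}

int main() {
    launch(nullptr);  // z passed as None -> kHasZ = false path
    return 0;
}
```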
Thanks a lot @tlrmchlsmth!
In order to accelerate the development/optimization of those kernels, I started migrating the relevant code from the mamba_ssm/causal_conv1d kernels to vLLM.
This is due to several failed attempts to push improvements to those repos, such as the following (relevant for Jamba/Mamba models):
#7428 #4115 #3690 #6484