
[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM #7651


Merged: 50 commits into vllm-project:main, Aug 28, 2024

Conversation

@mzusman (Contributor) commented Aug 19, 2024

In order to accelerate the development and optimization of these kernels, I started to migrate the relevant code from the mamba_ssm / causal_conv1d kernels into vLLM.

This is motivated by several failed attempts to push improvements to those upstream repos.

Relevant for Jamba/Mamba models: #7428, #4115, #3690, #6484


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which consists of a small but essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fastcheck build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@tlrmchlsmth (Collaborator)

This is probably the right approach. I'll fast follow with my selective_state_update and causal_conv1d_update improvements once it lands, and then we can simplify the Mamba cache.

@mzusman mzusman changed the title [Model][Do not merge] Migrate mamba_ssm and causal_conv1d kernels to vLLM [Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM Aug 22, 2024
@mzusman mzusman marked this pull request as ready for review August 22, 2024 15:11
@mzusman (Contributor, Author) commented Aug 22, 2024

@tlrmchlsmth Ready for review. I took only the code relevant for inference, removed the boilerplate parts, and added support for providing an initial state to mamba_ssm, though it is not used in the Jamba modeling file at the moment.
Plans for future PRs on my side:

  • Take the Mamba layer out of the Jamba modeling file and move it into a layer file under model_executor/layers
  • Add support for prefill chunking for Jamba/Mamba
  • Add support for varlen prefill batching for Jamba/Mamba
  • Further optimizations in the mamba ssm kernels

Thanks!

@mzusman (Contributor, Author) commented Aug 22, 2024

/ready

@github-actions bot added the ready label Aug 22, 2024
@bnellnm (Contributor) commented Aug 26, 2024

> register meta functions to the kernels

Yeah, it would be fine to do these as a follow up. Or now that there's code, I can paste them into my PR once this one lands. Also, the opcheck from #6917 is just a convenience wrapper around torch.library.opcheck and isn't strictly necessary.

@tlrmchlsmth (Collaborator) left a comment

Looks like this is still adding 20MB (after compression!) to the wheel size. I see a lot of cases being compiled for the selective_scan_fwd kernels, so it seems like we should still be able to bring that down.

Comment on lines 314 to 319
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
Collaborator

It looks like these BOOL_SWITCH macros are causing more combinatorial kernel blow-up. If reducing the number of type combinations doesn't bring the kernels down to a reasonable size, we could make some of these conditions dynamic, as long as they aren't checked in the innermost loop. We could also remove some of the BOOL_SWITCH statements where the condition is always true or always false. WDYT?
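For context, a minimal sketch of how a BOOL_SWITCH-style dispatch typically works (illustrative only; the actual macro lives in the kernel headers and its exact definition may differ). Each switch turns a runtime boolean into a compile-time constant by instantiating the body twice, so the five nested switches quoted above multiply the number of compiled kernel variants by 2^5 for every (input_t, weight_t) combination:

// Illustrative BOOL_SWITCH-style macro: the body is compiled once with the
// constant set to true and once with it set to false, so each nested switch
// doubles the number of template instantiations the compiler emits.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      constexpr bool CONST_NAME = true;         \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr bool CONST_NAME = false;        \
      return __VA_ARGS__();                     \
    }                                           \
  }()

// Hypothetical usage: one runtime flag selects between two compiled variants.
template <bool kHasZ>
void launch_selective_scan(/* params */) { /* ... */ }

void dispatch(bool has_z) {
  BOOL_SWITCH(has_z, kHasZ, [&] { launch_selective_scan<kHasZ>(); });
}

The trade-off is that a compile-time constant lets the compiler prune branches and registers inside the scan's inner loop, while a runtime check keeps binary size down; hard-coding a flag that is always true (as done below) removes the factor of two entirely.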

Contributor (Author)

Agree. I took a brief look at the Falcon Mamba modeling file in transformers, which is also based on Mamba 1, and it uses the same defaults as Jamba and as the mamba repo's mamba_simple.py: is_variable_B, is_variable_C, and has_z are all true.
Therefore, I've removed those BOOL_SWITCH terms from the kernel dispatch and just set them to true.

@mzusman (Contributor, Author) commented Aug 27, 2024

The wheel size of this PR is now 134 MB:

Wheel dist/vllm-0.5.5+cu124-cp38-abi3-linux_x86_64.whl is within the allowed size (134.5479211807251 MB).

Upstream is 130 MB:

Wheel dist/vllm-0.5.5+cu124-cp38-abi3-linux_x86_64.whl is within the allowed size (130.0757646560669 MB).

Collaborator

Nice!

kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
constexpr bool kHasSeqPosIdx = false;
Collaborator

Could you add a `TORCH_CHECK(params.seq_pos_idx_ptr == nullptr)`?

Collaborator

It would also be good to add a comment documenting that some kernel cases have been disabled to reduce binary size.

Contributor (Author)

Agree. BTW, this variable is used for varlen batching enablement, which is out of scope IMO for this PR. I've removed it completely and will open a separate follow-up PR for varlen batching.

Comment on lines +314 to +316
constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
Collaborator

Could you add torch checks to guard against cases where kIsVariableB, kIsVariableC, or kHasZ is false?

Contributor (Author)

Done
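For reference, a minimal sketch of what such guards might look like (variable names follow the quoted snippet; the exact call site and messages in the merged code may differ):

// Sketch only: reject configurations whose kernel variants were dropped for
// binary-size reasons, so callers get a clear error instead of a silent mismatch.
TORCH_CHECK(params.is_variable_B,
            "is_variable_B = False is disabled in favor of reduced binary size");
TORCH_CHECK(params.is_variable_C,
            "is_variable_C = False is disabled in favor of reduced binary size");
TORCH_CHECK(params.z_ptr != nullptr,
            "has_z = False is disabled in favor of reduced binary size");

constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;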

@tlrmchlsmth (Collaborator) left a comment

I thought the failing lora test might be bogus, so I restarted it but it's still failing. Do you think it could be an actual issue with the PR?

@mzusman (Contributor, Author) commented Aug 28, 2024

> I thought the failing lora test might be bogus, so I restarted it but it's still failing. Do you think it could be an actual issue with the PR?

It seems like upstream is also failing on the same test https://buildkite.com/vllm/ci-aws/builds/7715#01919725-4d5b-41ef-9392-f88017b2693b

@tlrmchlsmth (Collaborator) left a comment

LGTM once green, thank you!

@mzusman (Contributor, Author) commented Aug 28, 2024

CI failures seem to occur on the main branch as well: https://buildkite.com/vllm/ci-aws/builds/7791

@simon-mo simon-mo merged commit fdd9daa into vllm-project:main Aug 28, 2024
52 of 57 checks passed
tlrmchlsmth added a commit to neuralmagic/nm-vllm that referenced this pull request Aug 29, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Comment on lines +550 to +551
const bool has_z = z_.has_value();
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
Contributor

Hi @mzusman, we would like to pass z as None, and hence run into this error.
What do you think is the best way to support that?

Thanks a lot!

Collaborator

@congcongchen123 Could I ask what you're trying to do? Is this for supporting a new model in vLLM?

@congcongchen123 (Contributor) commented Nov 19, 2024

Thanks @tlrmchlsmth for the quick reply. Yes, we have a new model under development. It reuses the mamba kernel, but we would like to allow z to be None.

Collaborator

@congcongchen123 OK -- support for that case was removed during review to reduce the size of the compiled binaries. Should be pretty easy to restore. Take a look at this commit: abf02fa.

Then pay attention to how that affects the wheel size, once has_z support is restored!
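For anyone picking this up, a rough sketch of what restoring the has_z = False path could involve (names follow the snippets quoted earlier in this thread; the exact dispatch code is an assumption, and the referenced commit is the authoritative diff):

// Sketch only: restoring has_z = False support roughly means reverting the
// binary-size reduction discussed above.
// 1) Drop the guard:
//      TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
// 2) Turn the hard-coded constant back into a runtime dispatch, which roughly
//    doubles the number of compiled selective_scan_fwd variants:
BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
  // ... existing Ktraits instantiation and kernel launch, now templated on kHasZ ...
});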

Contributor

Thanks a lot @tlrmchlsmth !

@mergify mergify bot added the ci/build label Nov 19, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Labels: ci/build, ready
6 participants