
[Model] Add Llama-SwiftKV model #11023

Open · wants to merge 14 commits into main
Conversation

aurickq (Contributor) commented Dec 9, 2024

SwiftKV was recently announced at https://www.snowflake.com/engineering-blog/swiftkv-llm-compute-reduction/. This PR adds a SwiftKV version of Llama that can immediately be used to run models at https://huggingface.co/collections/Snowflake/swiftkv-models-674f7d7474eb789e185d31cb.

The model definition is somewhat unconventional due to the need to early-exit some tokens but not others after a specific number of layers, and we wanted to minimize the amount of changes to vLLM's core code. Specifically:

  1. sampling_params is passed into the model forward pass so it can identify which tokens to early-exit and which to propagate through all layers. Only tokens that need to be sampled from are propagated (see the first sketch after this list).
  2. SwiftKV captures and replays its own CUDA graph for the second half of the layers, which can have a small batch size even during prefill. We observe a 10-20% throughput gain from this, and it should be fully compatible with vLLM's existing CUDA graph, which only applies to decode-only batches (see the second sketch after this list).
  3. SwiftKV builds its own attention metadata for flash-attention and calls flash-attention directly, because the attention metadata for the second half of the layers differs from that of the first half. Additionally, to support CUDA graphs for the second half of the layers, we need a path that calls flash_attn_with_kvcache for mixed prefill-decode batches, which does not appear to be supported in vLLM's flash-attention path (only flash_attn_with_varlen, which is not CUDA-graphable, is called for any batch that contains prefill tokens).
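
To make point 1 concrete, here is a minimal sketch of the early-exit flow. The function and argument names (`swiftkv_forward`, `sample_indices`, etc.) are illustrative stand-ins rather than the actual code in llama_swiftkv.py, and in practice the index set would be derived from the sampling_params passed into the forward pass:

```python
import torch


def swiftkv_forward(hidden_states, positions, layers, sample_indices):
    """Illustrative only: run all tokens through the first half of the
    layers, then propagate only the tokens that will be sampled from."""
    half = len(layers) // 2

    # First half: every token runs as usual.
    for layer in layers[:half]:
        hidden_states = layer(hidden_states, positions)

    # Early exit: keep only the tokens that need to be sampled from
    # (derived from sampling_params in the real model); all other
    # tokens stop here.
    hidden_states = hidden_states[sample_indices]
    positions = positions[sample_indices]

    # Second half: a much smaller batch, even during prefill, which is
    # what makes the dedicated CUDA graph in point 2 worthwhile.
    for layer in layers[half:]:
        hidden_states = layer(hidden_states, positions)

    return hidden_states
```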
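
And a sketch of point 2, capturing and replaying a dedicated CUDA graph for the second-half forward pass using the standard torch.cuda.CUDAGraph API. Here `second_half` is a stand-in nn.Linear rather than the real layers, and the fixed-size padding scheme is an assumption about how such a graph would be fed:

```python
import torch

second_half = torch.nn.Linear(4096, 4096).cuda()  # stand-in for layers[half:]
max_tokens = 256                                  # fixed (padded) graph batch size
static_in = torch.zeros(max_tokens, 4096, device="cuda")

# Warm up on a side stream before capture, as PyTorch requires.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = second_half(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture the second-half forward pass once.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = second_half(static_in)

# Replay: copy the small early-exited batch into the static buffer,
# pad to max_tokens, replay, and slice out the real rows.
tokens = torch.randn(3, 4096, device="cuda")      # e.g. 3 tokens to sample from
static_in.zero_()
static_in[: tokens.shape[0]].copy_(tokens)
graph.replay()
result = static_out[: tokens.shape[0]]
```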

Current limitations:

  • Only compatible when chunked prefill is enabled.
  • Only compatible with flash-attention.

github-actions bot commented Dec 9, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of the following:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

jikunshang (Contributor)

Appreciate your great work! If I understand correctly, SwiftKV is a technique that could be applied to other models (though it requires retraining or finetuning for the new model structure), not just Llama, right?
If so, I think a better approach would be to define classes like SwiftKVModelRunner / SwiftKVAttentionOp to handle this logic, rather than putting it all in llama_swiftkv.py.

simon-mo (Collaborator)

We will likely pick this up after the refactoring of the memory management layer in V1 to support varying KV cache requirements. V1 also has piecewise CUDA graphs (no graph capture on attention), which should remove the need to self-manage CUDA graphs.
