From 052f52746b3a8c532a0c05361bf9ad1f3c456e2a Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 3 Dec 2024 10:42:29 -0500 Subject: [PATCH 01/12] implement llama-swiftkv --- vllm/model_executor/models/llama_swiftkv.py | 880 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/config.py | 2 + vllm/transformers_utils/configs/__init__.py | 2 + .../configs/llama_swiftkv.py | 31 + vllm/worker/model_runner.py | 4 + 6 files changed, 920 insertions(+) create mode 100644 vllm/model_executor/models/llama_swiftkv.py create mode 100644 vllm/transformers_utils/configs/llama_swiftkv.py diff --git a/vllm/model_executor/models/llama_swiftkv.py b/vllm/model_executor/models/llama_swiftkv.py new file mode 100644 index 0000000000000..f19a0ae91b828 --- /dev/null +++ b/vllm/model_executor/models/llama_swiftkv.py @@ -0,0 +1,880 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.vllm_flash_attn import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_with_kvcache, +) +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.distributed.parallel_state import graph_capture +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaMLP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, is_pp_missing_parameter, maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import LlamaSwiftKVConfig + + +@dataclass +class SwiftKVMetadata: + use_varlen: bool + indices: Optional[torch.Tensor] + block_tables: Optional[torch.Tensor] + + # non-varlen args + seq_lens: Optional[torch.Tensor] = None + + # varlen args + query_start_loc: Optional[torch.Tensor] = None + seq_start_loc: Optional[torch.Tensor] = None + max_query_len: Optional[int] = None + max_seq_len: Optional[int] = None + + +class LlamaSwiftKVAttention(nn.Module): + + def __init__( + self, + config: LlamaSwiftKVConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.q_proj_swiftkv = ColumnParallelLinear( + input_size=hidden_size, + output_size=self.total_num_heads * self.head_dim, + bias=bias, + gather_output=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj_swiftkv", + ) + self.kv_proj_swiftkv = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=0, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj_swiftkv", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: SwiftKVMetadata, + ) -> torch.Tensor: + query, _ = self.q_proj_swiftkv(hidden_states) + query, _ = self.rotary_emb(positions, query, torch.empty_like(key)) + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_dim) + if (key is not None) and (value is not None): + key = key.view(-1, self.num_kv_heads, self.head_dim) + value = value.view(-1, self.num_kv_heads, self.head_dim) + + if attn_metadata.use_varlen: + if (kv_cache.numel() == 0 or attn_metadata.block_tables is None + or attn_metadata.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. 
+ attn_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_seq_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scaling, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + softcap=0, + ) + else: + # prefix-enabled attention + attn_output = flash_attn_varlen_func( # noqa + q=query, + k=kv_cache[0], + v=kv_cache[1], + cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scaling, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + block_table=attn_metadata.block_tables, + softcap=0, + ) + else: + assert attn_metadata.seq_lens.numel() == num_tokens + if kv_cache.numel(): + attn_output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=kv_cache[0], + v_cache=kv_cache[1], + block_table=attn_metadata.block_tables, + cache_seqlens=attn_metadata.seq_lens, + softmax_scale=self.scaling, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + softcap=0, + ).squeeze(1) + else: + attn_output = flash_attn_func( + q=query.unsqueeze(1), + k=key.unsqueeze(1), + v=value.unsqueeze(1), + softmax_scale=self.scaling, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + softcap=0, + ).squeeze(1) + output = attn_output.view(num_tokens, hidden_size) + output, _ = self.o_proj(output) + return output + + +class LlamaSwiftKVDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaSwiftKVConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = LlamaSwiftKVAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + k_states: torch.Tensor, + v_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: SwiftKVMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual 
is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + key=k_states, + value=v_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class LlamaSwiftKVModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.kv_cache_dtype = ( + cache_config.cache_dtype if cache_config is not None else "auto" + ) + + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.layers = torch.nn.ModuleList([ + LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}") + if idx < config.num_key_value_layers + else LlamaSwiftKVDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}") + for idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_swiftkv = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Cuda graph inputs/output tensors + num_kv_heads = self.layers[0].self_attn.num_kv_heads + head_dim = self.layers[0].self_attn.head_dim + kv_size = num_kv_heads * head_dim + self.cuda_graphs = {} + self.cuda_graph_max_batch_size = 256 + self.cuda_graph_max_num_blocks = 2048 + self.cuda_graph_tensors = { + "positions": torch.empty(self.cuda_graph_max_batch_size, + dtype=torch.long), + "hidden_states": torch.empty(self.cuda_graph_max_batch_size, + config.hidden_size), + "residual": torch.empty(self.cuda_graph_max_batch_size, + config.hidden_size), + "kv_states": { + layer_idx: (torch.empty(self.cuda_graph_max_batch_size, kv_size), + torch.empty(self.cuda_graph_max_batch_size, kv_size)) + for layer_idx in range(config.num_key_value_layers, + config.num_hidden_layers) + }, + "metadata": SwiftKVMetadata( + use_varlen=False, + indices=None, + seq_lens=torch.empty(self.cuda_graph_max_batch_size, + dtype=torch.int32), + block_tables=torch.empty(self.cuda_graph_max_batch_size, + self.cuda_graph_max_num_blocks, + dtype=torch.int32), + ), + } + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def _get_swiftkv_metadata( + self, + attn_metadata: FlashAttentionMetadata, + sampling_metadata: Optional[SamplingMetadata], + ) -> SwiftKVMetadata: + sampling_indices = sampling_metadata.selected_token_indices.tolist() + swiftkv_indices = [] + swiftkv_seq_ids = [] + swiftkv_query_lens = [] + swiftkv_seq_lens = [] + idx = 0 + query_start_loc = attn_metadata.query_start_loc.tolist() + for seq_id in range(len(query_start_loc) - 1): + seq_begin = query_start_loc[seq_id] + seq_end = 
query_start_loc[seq_id + 1] + while (idx < len(sampling_indices) and + sampling_indices[idx] < seq_begin): + idx += 1 + if idx >= len(sampling_indices): + break + if sampling_indices[idx] < seq_end: + indices = list(range(sampling_indices[idx], seq_end)) + swiftkv_indices.extend(indices) + swiftkv_seq_ids.append(seq_id) + swiftkv_query_lens.append(len(indices)) + swiftkv_seq_lens.append(attn_metadata.seq_lens[seq_id]) + device = attn_metadata.query_start_loc.device + max_query_len = max(swiftkv_query_lens, default=0) + max_seq_len = max(swiftkv_seq_lens, default=0) + if max_query_len <= 1: + assert len(swiftkv_indices) == len(swiftkv_seq_ids) + return SwiftKVMetadata( + use_varlen=False, + indices=torch.tensor(swiftkv_indices, device=device), + block_tables=attn_metadata.block_tables[swiftkv_seq_ids], + seq_lens=torch.tensor(swiftkv_seq_lens, device=device, + dtype=torch.int32), + ) + else: + return SwiftKVMetadata( + use_varlen=True, + indices=torch.tensor(swiftkv_indices, device=device), + block_tables=attn_metadata.block_tables[swiftkv_seq_ids], + query_start_loc=torch.tensor( + [0] + swiftkv_query_lens, device=device, + ).cumsum(dim=0).to(torch.int32), + seq_start_loc=torch.tensor( + [0] + swiftkv_seq_lens, device=device, + ).cumsum(dim=0).to(torch.int32), + max_query_len=max_query_len, + max_seq_len=max_seq_len, + ) + + def _get_swiftkv_metadata_for_cuda_graph( + self, + attn_metadata: FlashAttentionMetadata, + ) -> SwiftKVMetadata: + assert (attn_metadata.num_prefills == 0 and + attn_metadata.max_decode_query_len == 1) + return SwiftKVMetadata( + use_varlen=False, + indices=None, + block_tables=attn_metadata.block_tables, + seq_lens=attn_metadata.seq_lens_tensor, + ) + + def _prepare_cuda_graph_inputs( + self, + size: int, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + kv_states: Dict[int, Tuple[torch.Tensor, torch.Tensor]], + swiftkv_metadata: SwiftKVMetadata, + ): + self.cuda_graph_tensors["positions"][:size].copy_(positions) + self.cuda_graph_tensors["hidden_states"][:size].copy_(hidden_states) + self.cuda_graph_tensors["residual"][:size].copy_(residual) + cuda_graph_kv_states = self.cuda_graph_tensors["kv_states"] + for layer_idx, (k, v) in kv_states.items(): + cuda_graph_kv_states[layer_idx][0][:size].copy_(k) + cuda_graph_kv_states[layer_idx][1][:size].copy_(v) + cuda_graph_metadata = self.cuda_graph_tensors["metadata"] + cuda_graph_metadata.seq_lens[:size].copy_(swiftkv_metadata.seq_lens) + num_blocks = min(self.cuda_graph_max_num_blocks, + swiftkv_metadata.block_tables.size(1)) + cuda_graph_metadata.block_tables[:size, :num_blocks].copy_( + swiftkv_metadata.block_tables[:, :num_blocks]) + # Pad to next power of 2 + padded_size = 1 << (size - 1).bit_length() + positions = self.cuda_graph_tensors["positions"][:padded_size] + hidden_states = self.cuda_graph_tensors["hidden_states"][:padded_size] + residual = self.cuda_graph_tensors["residual"][:padded_size] + for layer_idx in kv_states: + kv_states[layer_idx] = ( + cuda_graph_kv_states[layer_idx][0][:padded_size], + cuda_graph_kv_states[layer_idx][1][:padded_size], + ) + swiftkv_metadata = SwiftKVMetadata( + use_varlen=swiftkv_metadata.use_varlen, + indices=swiftkv_metadata.indices, + seq_lens=cuda_graph_metadata.seq_lens[:padded_size], + block_tables=cuda_graph_metadata.block_tables[:padded_size], + ) + return (padded_size, positions, hidden_states, residual, kv_states, + swiftkv_metadata) + + def _run_swiftkv_layers( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + 
residual: torch.Tensor, + kv_states: Dict[int, Tuple[torch.Tensor, torch.Tensor]], + kv_caches: List[torch.Tensor], + swiftkv_metadata: SwiftKVMetadata, + ) -> torch.Tensor: + for layer_idx in range(self.config.num_key_value_layers, + self.config.num_hidden_layers): + layer = self.layers[layer_idx] + k_states, v_states = kv_states[layer_idx] + hidden_states, residual = layer( + positions, + hidden_states, + k_states, + v_states, + kv_caches[layer_idx], + swiftkv_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + sampling_metadata: Optional[SamplingMetadata] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + swiftkv_metadata = ( + self._get_swiftkv_metadata(attn_metadata, sampling_metadata) + if not attn_metadata.use_cuda_graph + else self._get_swiftkv_metadata_for_cuda_graph(attn_metadata) + ) + + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + for layer_idx in range(self.config.num_key_value_layers): + layer = self.layers[layer_idx] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[layer_idx], + attn_metadata, + residual, + ) + + # KV projection and cache of all the remaining layers + kv_states = {} + swiftkv_hidden_states = self.norm_swiftkv(hidden_states + residual) + for layer_idx in range(self.config.num_key_value_layers, + self.config.num_hidden_layers): + self_attn = self.layers[layer_idx].self_attn + kv, _ = self_attn.kv_proj_swiftkv(swiftkv_hidden_states) + k, v = kv.split(self_attn.kv_size, dim=-1) + q = torch.empty_like(hidden_states) # Just temporary buffer + _, k = self_attn.rotary_emb(positions, q, k) + kv_states[layer_idx] = (k, v) + if kv_caches[layer_idx].numel(): + torch.ops._C_cache_ops.reshape_and_cache_flash( + k.view(-1, self_attn.num_kv_heads, self_attn.head_dim), + v.view(-1, self_attn.num_kv_heads, self_attn.head_dim), + kv_caches[layer_idx][0], + kv_caches[layer_idx][1], + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + 1.0, 1.0, + ) + + if swiftkv_metadata.indices is not None: + if not swiftkv_metadata.indices.numel(): + return hidden_states # Early exit entire batch. + orig_hidden_states = hidden_states + hidden_states = hidden_states[swiftkv_metadata.indices] + residual = residual[swiftkv_metadata.indices] + positions = positions[swiftkv_metadata.indices] + kv_states = { + layer_idx: (k[swiftkv_metadata.indices], + v[swiftkv_metadata.indices]) + for layer_idx, (k, v) in kv_states.items() + } + + batch_size = hidden_states.size(0) + if (not attn_metadata.use_cuda_graph + and not swiftkv_metadata.use_varlen and kv_caches[0].numel() + and batch_size <= self.cuda_graph_max_batch_size + and swiftkv_metadata.block_tables.size(1) <= + self.cuda_graph_max_num_blocks + ): + # We implement our own (JIT-captured) cuda graph for the second + # half of the model (layers skipped for prefill tokens). 
+ ( + padded_size, + positions, + hidden_states, + residual, + kv_states, + swiftkv_metadata, + ) = self._prepare_cuda_graph_inputs( + batch_size, + positions, + hidden_states, + residual, + kv_states, + swiftkv_metadata, + ) + g = self.cuda_graphs.get(padded_size) + cuda_graph_hidden_states = self.cuda_graph_tensors["hidden_states"] + if g is None: + print("JIT-capture SwiftKV CUDA graph for batch size", + padded_size) + with graph_capture() as capture_context: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=capture_context.stream): + hidden_states = self._run_swiftkv_layers( + positions, + hidden_states, + residual, + kv_states, + kv_caches, + swiftkv_metadata, + ) + cuda_graph_hidden_states[:padded_size].copy_( + hidden_states) + self.cuda_graphs[padded_size] = g + else: + g.replay() + hidden_states = cuda_graph_hidden_states[:batch_size] + else: + hidden_states = self._run_swiftkv_layers( + positions, + hidden_states, + residual, + kv_states, + kv_caches, + swiftkv_metadata, + ) + if swiftkv_metadata.indices is None: + return hidden_states + orig_hidden_states[swiftkv_metadata.indices] = hidden_states + return orig_hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + for layer_idx in range(self.config.num_key_value_layers): + prefix = f".{layer_idx}.self_attn" + stacked_params_mapping.extend([ + (f"{prefix}.qkv_proj", f"{prefix}.q_proj", "q"), + (f"{prefix}.qkv_proj", f"{prefix}.k_proj", "k"), + (f"{prefix}.qkv_proj", f"{prefix}.v_proj", "v"), + ]) + for layer_idx in range(self.config.num_key_value_layers, + self.config.num_hidden_layers): + prefix = f".{layer_idx}.self_attn" + stacked_params_mapping.extend([ + (f"{prefix}.kv_proj_swiftkv", f"{prefix}.k_proj_swiftkv", "k"), + (f"{prefix}.kv_proj_swiftkv", f"{prefix}.v_proj_swiftkv", "v"), + ]) + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + orig_name = name + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + if name not in params_dict: + print(f"Skip loading {orig_name}") + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + if name not in params_dict: + print(f"Skip loading {orig_name}") + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, tp_rank, tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type): + if not isinstance(self.layers[layer_idx], nn.Identity): + layer_self_attn = self.layers[layer_idx].self_attn + + if current_platform.is_rocm(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") + + +class LlamaSwiftKVForCausalLM(nn.Module): + packed_modules_mapping = { + "kv_proj_swiftkv": ["k_proj_swiftkv", "v_proj_swiftkv"], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ".k_proj_swiftkv.", + ".v_proj_swiftkv.", + ] + + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [ + ".q_proj_swiftkv.", + ".down_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "k_proj_swiftkv": ("kv_proj_swiftkv", 1), + "v_proj_swiftkv": ("kv_proj_swiftkv", 2), + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.model = LlamaSwiftKVModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + 
self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + sampling_metadata: Optional[SamplingMetadata] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + sampling_metadata=sampling_metadata) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + loader.load_weights(weights) + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 22c2e328bfb65..48b5592e363ba 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -59,6 +59,7 @@ "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + "LlamaSwiftKVForCausalLM": ("llama_swiftkv", "LlamaSwiftKVForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 054845584c2ef..8abb36fe52e96 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -24,6 +24,7 @@ EAGLEConfig, ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, + LlamaSwiftKVConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, NVLM_D_Config, @@ -52,6 +53,7 @@ "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "jais": JAISConfig, + "llama_swiftkv": LlamaSwiftKVConfig, "mlp_speculator": MLPSpeculatorConfig, "medusa": MedusaConfig, "eagle": EAGLEConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index d1e19c9a33c24..23e41e750413a 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -9,6 +9,7 @@ from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig +from vllm.transformers_utils.configs.llama_swiftkv import LlamaSwiftKVConfig from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig @@ -26,6 +27,7 @@ "H2OVLChatConfig", "InternVLChatConfig", "JAISConfig", + "LlamaSwiftKVConfig", "MedusaConfig", "EAGLEConfig", "ExaoneConfig", diff --git a/vllm/transformers_utils/configs/llama_swiftkv.py 
b/vllm/transformers_utils/configs/llama_swiftkv.py new file mode 100644 index 0000000000000..ff02e1d7ef71e --- /dev/null +++ b/vllm/transformers_utils/configs/llama_swiftkv.py @@ -0,0 +1,31 @@ +from typing import Optional + +from transformers import LlamaConfig + + +class LlamaSwiftKVConfig(LlamaConfig): + """ + Args: + num_key_value_layers (int, optional): + The number of layers, from the first layer, that have keys and + values. If None, all layers have keys and values. + last_key_value_heads (int, optional): + The number of heads in the last layer that have keys and values. + If None, the number of heads in the last key-value layer is equal + to the number of heads in all the other key-value layers. + """ + + model_type = "llama_swiftkv" + + def __init__( + self, + swiftkv: bool = False, + num_key_value_layers: Optional[int] = None, + key_value_group_size: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.swiftkv = swiftkv + self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers + self.key_value_group_size = key_value_group_size or 1 + assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0 diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 042f9f07eace6..7aa3f7687fde3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1650,6 +1650,9 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() + swiftkv_kwargs = ({"sampling_metadata": model_input.sampling_metadata} + if "SwiftKV" in type(self.model).__name__ else {}) + with set_forward_context(model_input.attn_metadata): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, @@ -1657,6 +1660,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, + **swiftkv_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), **seqlen_agnostic_kwargs) From e01dcce0dd0a046555f44ab98cd8881cf67d9aef Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 3 Dec 2024 21:05:28 +0000 Subject: [PATCH 02/12] cuda graph config --- vllm/model_executor/models/llama_swiftkv.py | 85 ++++++++++++--------- 1 file changed, 50 insertions(+), 35 deletions(-) diff --git a/vllm/model_executor/models/llama_swiftkv.py b/vllm/model_executor/models/llama_swiftkv.py index f19a0ae91b828..1dbb2c84fb6d9 100644 --- a/vllm/model_executor/models/llama_swiftkv.py +++ b/vllm/model_executor/models/llama_swiftkv.py @@ -327,6 +327,13 @@ def forward( return hidden_states, residual +def _padded_size(size: int) -> int: + mult = (1 << (size - 1).bit_length()) // 4 + if mult < 1: + return size + return (size + mult - 1) // mult * mult + + class LlamaSwiftKVModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -368,35 +375,44 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.norm_swiftkv = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Cuda graph inputs/output tensors - num_kv_heads = self.layers[0].self_attn.num_kv_heads - head_dim = self.layers[0].self_attn.head_dim - kv_size = num_kv_heads * head_dim - self.cuda_graphs = {} - self.cuda_graph_max_batch_size = 256 - self.cuda_graph_max_num_blocks = 2048 - self.cuda_graph_tensors = { - "positions": torch.empty(self.cuda_graph_max_batch_size, - dtype=torch.long), - "hidden_states": torch.empty(self.cuda_graph_max_batch_size, - config.hidden_size), - "residual": 
torch.empty(self.cuda_graph_max_batch_size, - config.hidden_size), - "kv_states": { - layer_idx: (torch.empty(self.cuda_graph_max_batch_size, kv_size), - torch.empty(self.cuda_graph_max_batch_size, kv_size)) - for layer_idx in range(config.num_key_value_layers, - config.num_hidden_layers) - }, - "metadata": SwiftKVMetadata( - use_varlen=False, - indices=None, - seq_lens=torch.empty(self.cuda_graph_max_batch_size, - dtype=torch.int32), - block_tables=torch.empty(self.cuda_graph_max_batch_size, - self.cuda_graph_max_num_blocks, - dtype=torch.int32), - ), - } + if not vllm_config.model_config.enforce_eager: + self.use_inner_cuda_graph = True + num_kv_heads = self.layers[0].self_attn.num_kv_heads + head_dim = self.layers[0].self_attn.head_dim + kv_size = num_kv_heads * head_dim + self.cuda_graphs = {} + self.cuda_graph_max_batch_size = _padded_size( + vllm_config.scheduler_config.max_num_seqs) + self.cuda_graph_max_num_blocks = ( + vllm_config.model_config.max_seq_len_to_capture // + vllm_config.cache_config.block_size) + self.cuda_graph_tensors = { + "positions": torch.empty(self.cuda_graph_max_batch_size, + dtype=torch.long), + "hidden_states": torch.empty(self.cuda_graph_max_batch_size, + config.hidden_size), + "residual": torch.empty(self.cuda_graph_max_batch_size, + config.hidden_size), + "kv_states": { + layer_idx: ( + torch.empty(self.cuda_graph_max_batch_size, kv_size), + torch.empty(self.cuda_graph_max_batch_size, kv_size), + ) + for layer_idx in range(config.num_key_value_layers, + config.num_hidden_layers) + }, + "metadata": SwiftKVMetadata( + use_varlen=False, + indices=None, + seq_lens=torch.empty(self.cuda_graph_max_batch_size, + dtype=torch.int32), + block_tables=torch.empty(self.cuda_graph_max_batch_size, + self.cuda_graph_max_num_blocks, + dtype=torch.int32), + ), + } + else: + self.use_inner_cuda_graph = False def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -489,8 +505,8 @@ def _prepare_cuda_graph_inputs( swiftkv_metadata.block_tables.size(1)) cuda_graph_metadata.block_tables[:size, :num_blocks].copy_( swiftkv_metadata.block_tables[:, :num_blocks]) - # Pad to next power of 2 - padded_size = 1 << (size - 1).bit_length() + # Pad to next highest cuda graph batch size + padded_size = _padded_size(size) positions = self.cuda_graph_tensors["positions"][:padded_size] hidden_states = self.cuda_graph_tensors["hidden_states"][:padded_size] residual = self.cuda_graph_tensors["residual"][:padded_size] @@ -601,13 +617,13 @@ def forward( } batch_size = hidden_states.size(0) - if (not attn_metadata.use_cuda_graph + if (self.use_inner_cuda_graph and not attn_metadata.use_cuda_graph and not swiftkv_metadata.use_varlen and kv_caches[0].numel() and batch_size <= self.cuda_graph_max_batch_size and swiftkv_metadata.block_tables.size(1) <= self.cuda_graph_max_num_blocks ): - # We implement our own (JIT-captured) cuda graph for the second + # We implement our own (just-in-time) cuda graph for the second # half of the model (layers skipped for prefill tokens). 
( padded_size, @@ -627,8 +643,7 @@ def forward( g = self.cuda_graphs.get(padded_size) cuda_graph_hidden_states = self.cuda_graph_tensors["hidden_states"] if g is None: - print("JIT-capture SwiftKV CUDA graph for batch size", - padded_size) + print("Capture SwiftKV CUDA graph for batch size", padded_size) with graph_capture() as capture_context: g = torch.cuda.CUDAGraph() with torch.cuda.graph(g, stream=capture_context.stream): From 9087a7756db8f34b87748337063979a9d77e41e9 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 4 Dec 2024 13:53:54 -0800 Subject: [PATCH 03/12] Create README.md --- swiftkv/README.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 swiftkv/README.md diff --git a/swiftkv/README.md b/swiftkv/README.md new file mode 100644 index 0000000000000..bf5f428a3c577 --- /dev/null +++ b/swiftkv/README.md @@ -0,0 +1 @@ +# SwiftKV From 21eaea8674315fdec4ddd762414f6a5be8b8f8dc Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 4 Dec 2024 13:57:01 -0800 Subject: [PATCH 04/12] swiftkv readme --- examples/swiftkv/README.md | 5 +++++ swiftkv/README.md | 1 - 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 examples/swiftkv/README.md delete mode 100644 swiftkv/README.md diff --git a/examples/swiftkv/README.md b/examples/swiftkv/README.md new file mode 100644 index 0000000000000..4a7ec1877553d --- /dev/null +++ b/examples/swiftkv/README.md @@ -0,0 +1,5 @@ +# SwiftKV + +## Evaluation + +## Performance Benchmarks \ No newline at end of file diff --git a/swiftkv/README.md b/swiftkv/README.md deleted file mode 100644 index bf5f428a3c577..0000000000000 --- a/swiftkv/README.md +++ /dev/null @@ -1 +0,0 @@ -# SwiftKV From 31f8eb7d5ee5b662b84bb1a8187c252545b43af2 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 5 Dec 2024 01:35:13 +0000 Subject: [PATCH 05/12] update swiftkv examples --- examples/swiftkv/README.md | 83 ++++++++++++++++++- examples/swiftkv/offline_inference_swiftkv.py | 26 ++++++ examples/swiftkv/run_eval_405b_fp8.sh | 45 ++++++++++ examples/swiftkv/run_eval_8b.sh | 45 ++++++++++ 4 files changed, 197 insertions(+), 2 deletions(-) create mode 100644 examples/swiftkv/offline_inference_swiftkv.py create mode 100644 examples/swiftkv/run_eval_405b_fp8.sh create mode 100644 examples/swiftkv/run_eval_8b.sh diff --git a/examples/swiftkv/README.md b/examples/swiftkv/README.md index 4a7ec1877553d..6880f199ef074 100644 --- a/examples/swiftkv/README.md +++ b/examples/swiftkv/README.md @@ -1,5 +1,84 @@ # SwiftKV -## Evaluation +SwiftKV is a technique developed by Snowflake AI Research that reduces computational overhead during prompt processing by combining model rewiring and knowledge-preserving self-distillation. -## Performance Benchmarks \ No newline at end of file +For more details, see: + +- [Blog post](https://www.snowflake.com/engineering-blog/swiftkv-llm-compute-reduction) +- [Paper](https://arxiv.org/abs/2410.03960) +- [Huggingface](https://huggingface.co/collections/Snowflake/swiftkv-models-674f7d7474eb789e185d31cb) + +## Quickstart + +Install vLLM from Snowflake-Labs: +```console +$ pip install git+https://github.com/snowflake-labs/vllm.git@swiftkv +``` + +Run an example conversation using [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct): +```console +$ python examples/swiftkv/offline_inference_swiftkv.py + +... 
+ +The Importance of Higher Education + +Higher education is a vital component of an individual's life, providing numerous benefits that extend beyond the acquisition of knowledge and skills. It plays a significant role in shaping an individual's future, career prospects, and overall well-being. In this essay, we will explore the importance of higher education and its far-reaching implications on individuals, society, and the economy. + +... +``` + +## Running Accuracy Evaluations + +To evaluate the Llama-3.1-SwiftKV models, we use the [LM-Eval fork by NeuralMagic](https://github.com/neuralmagic/lm-evaluation-harness.git): + +```console +$ pip install git+https://github.com/neuralmagic/lm-evaluation-harness.git@llama_3.1_instruct +``` + +Run evaluation on Llama-3.1-SwiftKV-8B-Instruct: + +```console +$ bash examples/swiftkv/run_eval_8b.sh +``` + +Run evaluation on Llama-3.1-SwiftKV-405B-Instruct-FP8: + +```console +$ bash examples/swiftkv/run_eval_405b_fp8.sh +``` + +## Running Performance Benchmarks + +Llama-3.1-SwiftKV-8B-Instruct + +```console +$ python benchmarks/benchmark_throughput.py \ + --input-len 2000 --output-len 256 \ + --model Snowflake/Llama-3.1-SwiftKV-8B-Instruct \ + --gpu-memory-utilization 0.95 \ + --enable-chunked-prefill \ + --max-num-batched-tokens 2048 \ + --max-num-seqs 512 + +... + +Throughput: 11.36 requests/s, 25635.51 total tokens/s, 2908.99 output tokens/s +``` + +Llama-3.1-SwiftKV-405B-Instruct-FP8 + +```console +$ python benchmarks/benchmark_throughput.py \ + --input-len 2000 --output-len 256 \ + --model Snowflake/Llama-3.1-SwiftKV-405B-Instruct-FP8 \ + --gpu-memory-utilization 0.95 \ + --enable-chunked-prefill \ + --max-num-batched-tokens 2048 \ + --max-num-seqs 512 \ + --tensor-parallel-size 8 + +... + +Throughput: 3.21 requests/s, 7233.37 total tokens/s, 820.81 output tokens/s +``` diff --git a/examples/swiftkv/offline_inference_swiftkv.py b/examples/swiftkv/offline_inference_swiftkv.py new file mode 100644 index 0000000000000..f496675eabec0 --- /dev/null +++ b/examples/swiftkv/offline_inference_swiftkv.py @@ -0,0 +1,26 @@ +from vllm import LLM, SamplingParams + +llm = LLM(model="Snowflake/Llama-3.1-SwiftKV-8B-Instruct", enforce_eager=True) + +print("=" * 80) + +conversation = [ + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" 
+ }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] + +sampling_params = SamplingParams(temperature=0.1, max_tokens=800) + +outputs = llm.chat(conversation, sampling_params=sampling_params) + +print(outputs[0].outputs[0].text) diff --git a/examples/swiftkv/run_eval_405b_fp8.sh b/examples/swiftkv/run_eval_405b_fp8.sh new file mode 100644 index 0000000000000..5f24ea13fd6e1 --- /dev/null +++ b/examples/swiftkv/run_eval_405b_fp8.sh @@ -0,0 +1,45 @@ +MODEL=Snowflake/Llama-3.1-SwiftKV-405B-Instruct-FP8 + +EVAL_CMD=$(cat < Date: Thu, 5 Dec 2024 04:09:11 +0000 Subject: [PATCH 06/12] fix cuda graph --- examples/swiftkv/offline_inference_swiftkv.py | 2 +- vllm/model_executor/models/llama_swiftkv.py | 37 ++++++++++++------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/examples/swiftkv/offline_inference_swiftkv.py b/examples/swiftkv/offline_inference_swiftkv.py index f496675eabec0..cf29671ad90c1 100644 --- a/examples/swiftkv/offline_inference_swiftkv.py +++ b/examples/swiftkv/offline_inference_swiftkv.py @@ -1,6 +1,6 @@ from vllm import LLM, SamplingParams -llm = LLM(model="Snowflake/Llama-3.1-SwiftKV-8B-Instruct", enforce_eager=True) +llm = LLM(model="Snowflake/Llama-3.1-SwiftKV-8B-Instruct") print("=" * 80) diff --git a/vllm/model_executor/models/llama_swiftkv.py b/vllm/model_executor/models/llama_swiftkv.py index 1dbb2c84fb6d9..ec5fd64663752 100644 --- a/vllm/model_executor/models/llama_swiftkv.py +++ b/vllm/model_executor/models/llama_swiftkv.py @@ -643,20 +643,31 @@ def forward( g = self.cuda_graphs.get(padded_size) cuda_graph_hidden_states = self.cuda_graph_tensors["hidden_states"] if g is None: + g = torch.cuda.CUDAGraph() + # Run a few times first to ensure the captured graph does not + # include kernel launches for initial benchmarking (e.g., Triton + # autotune). Note that once is not enough for torch.jit.script. + for _ in range(2): + h = self._run_swiftkv_layers( + positions, + hidden_states, + residual, + kv_states, + kv_caches, + swiftkv_metadata, + ) + cuda_graph_hidden_states[:padded_size].copy_(h) print("Capture SwiftKV CUDA graph for batch size", padded_size) - with graph_capture() as capture_context: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g, stream=capture_context.stream): - hidden_states = self._run_swiftkv_layers( - positions, - hidden_states, - residual, - kv_states, - kv_caches, - swiftkv_metadata, - ) - cuda_graph_hidden_states[:padded_size].copy_( - hidden_states) + with graph_capture() as c, torch.cuda.graph(g, stream=c.stream): + hidden_states = self._run_swiftkv_layers( + positions, + hidden_states, + residual, + kv_states, + kv_caches, + swiftkv_metadata, + ) + cuda_graph_hidden_states[:padded_size].copy_(hidden_states) self.cuda_graphs[padded_size] = g else: g.replay() From 4f3d05a9215746d122e3d15e65a0c07e8deb9ad0 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 4 Dec 2024 23:24:33 -0500 Subject: [PATCH 07/12] Update README.md --- examples/swiftkv/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/swiftkv/README.md b/examples/swiftkv/README.md index 6880f199ef074..d39cd164b579a 100644 --- a/examples/swiftkv/README.md +++ b/examples/swiftkv/README.md @@ -1,4 +1,4 @@ -# SwiftKV +# SwiftKV on vLLM SwiftKV is a technique developed by Snowflake AI Research that reduces computational overhead during prompt processing by combining model rewiring and knowledge-preserving self-distillation. 
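The README change above summarizes SwiftKV at a high level; the "model rewiring" it refers to is what `LlamaSwiftKVModel.forward` in PATCH 01/12 implements: the first `num_key_value_layers` decoder layers run for every token, a single `norm_swiftkv` plus the per-layer `kv_proj_swiftkv` projections then produce the KV-cache entries for all remaining layers from that one hidden state, and only the tokens selected for sampling continue through the remaining SwiftKV decoder layers. The sketch below is an editorial toy illustration of that control flow only — the module names mirror the patch, but the layer internals, shapes, and the stand-in "attention" are placeholders, not the actual vLLM kernels or APIs.

```python
# Toy sketch of the SwiftKV rewiring (illustration only, not part of the patch).
import torch
from torch import nn


class ToySwiftKVModel(nn.Module):

    def __init__(self, hidden: int = 64, num_layers: int = 4,
                 num_key_value_layers: int = 2):
        super().__init__()
        num_swiftkv = num_layers - num_key_value_layers
        # First half: ordinary decoder layers (stand-ins here).
        self.early = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(num_key_value_layers))
        self.norm_swiftkv = nn.LayerNorm(hidden)
        # One KV projection per remaining layer, all fed by the same hidden state.
        self.kv_proj_swiftkv = nn.ModuleList(
            nn.Linear(hidden, 2 * hidden) for _ in range(num_swiftkv))
        # Second half: layers that only the sampled tokens pass through.
        self.late = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(num_swiftkv))

    def forward(self, x: torch.Tensor, sampled: torch.Tensor) -> torch.Tensor:
        # 1) Run the first num_key_value_layers for every token (prefill + decode).
        for layer in self.early:
            x = torch.relu(layer(x))
        # 2) Project K/V for *all* remaining layers from one normalized state,
        #    so their KV cache can be filled without running those layers.
        h = self.norm_swiftkv(x)
        kv_states = [proj(h).chunk(2, dim=-1) for proj in self.kv_proj_swiftkv]
        # 3) Only tokens that will be sampled continue; pure prefill tokens stop here.
        y = x[sampled]
        for layer, (k, v) in zip(self.late, kv_states):
            y = torch.relu(layer(y + k[sampled] + v[sampled]))  # stand-in for attention
        return y


tokens = torch.randn(10, 64)     # 10 tokens in the current batch
sampled = torch.tensor([9])      # only the last token needs logits
print(ToySwiftKVModel()(tokens, sampled).shape)  # torch.Size([1, 64])
```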
From bb43a8bd97ab247a65f82b3022a79aa365c8b29d Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sat, 7 Dec 2024 17:14:13 -0500 Subject: [PATCH 08/12] add test and fix a few bugs --- tests/swiftkv/__init__.py | 0 tests/swiftkv/test_llama_fp8.py | 43 +++++ vllm/model_executor/models/llama_swiftkv.py | 200 +++++++++++--------- 3 files changed, 149 insertions(+), 94 deletions(-) create mode 100644 tests/swiftkv/__init__.py create mode 100644 tests/swiftkv/test_llama_fp8.py diff --git a/tests/swiftkv/__init__.py b/tests/swiftkv/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/swiftkv/test_llama_fp8.py b/tests/swiftkv/test_llama_fp8.py new file mode 100644 index 0000000000000..84d5463d50219 --- /dev/null +++ b/tests/swiftkv/test_llama_fp8.py @@ -0,0 +1,43 @@ +import pytest + +import vllm +from tests.utils import multi_gpu_test +from vllm.sampling_params import SamplingParams + +MODELS = ["Snowflake/Llama-3.1-SwiftKV-8B-Instruct-FP8"] +CONVERSATIONS = [ + [{"role": "user", "content": "Hello!"}], + [{"role": "user", "content": "Who is the president of the United States?"}], + [{"role": "user", "content": "What is the capital of France?"}], + [{"role": "user", "content": "What is the future of AI?"}], +] +EXPECTED_OUTPUTS = [ + "Hello! How can I assist you today?", + "As of my cut-off knowledge in December 2023, the President of the United " + "States is Joe", + "The capital of France is Paris.", + "The future of AI is vast and rapidly evolving, with numerous potential " + "developments and applications on the horizon.", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("tensor_parallel_size", [1, 2]) +@multi_gpu_test(num_gpus=2) +def test_model(model, enforce_eager, tensor_parallel_size) -> None: + llm = vllm.LLM( + model, + enforce_eager=enforce_eager, + enable_chunked_prefill=True, + tensor_parallel_size=tensor_parallel_size, + ) + sampling_params = SamplingParams(temperature=0.0, max_tokens=20) + + for idx, conversation in enumerate(CONVERSATIONS): + outputs = llm.chat( + conversation, + sampling_params=sampling_params, + use_tqdm=False, + ) + assert outputs[0].outputs[0].text == EXPECTED_OUTPUTS[idx] diff --git a/vllm/model_executor/models/llama_swiftkv.py b/vllm/model_executor/models/llama_swiftkv.py index ec5fd64663752..1c57b91ab3402 100644 --- a/vllm/model_executor/models/llama_swiftkv.py +++ b/vllm/model_executor/models/llama_swiftkv.py @@ -58,6 +58,7 @@ from vllm.model_executor.models.utils import ( AutoWeightsLoader, is_pp_missing_parameter, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import LlamaSwiftKVConfig @@ -177,45 +178,27 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_dim) if attn_metadata.use_varlen: - if (kv_cache.numel() == 0 or attn_metadata.block_tables is None - or attn_metadata.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. 
- attn_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_seq_len, - max_seqlen_k=attn_metadata.max_seq_len, - softmax_scale=self.scaling, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - softcap=0, - ) - else: - # prefix-enabled attention - attn_output = flash_attn_varlen_func( # noqa - q=query, - k=kv_cache[0], - v=kv_cache[1], - cu_seqlens_q=attn_metadata.query_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_query_len, - max_seqlen_k=attn_metadata.max_seq_len, - softmax_scale=self.scaling, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - block_table=attn_metadata.block_tables, - softcap=0, - ) + # Should be neither capture nor profile run. + assert kv_cache.numel() and attn_metadata.block_tables.numel() + attn_output = flash_attn_varlen_func( # noqa + q=query, + k=kv_cache[0], + v=kv_cache[1], + cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scaling, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + block_table=attn_metadata.block_tables, + softcap=0, + ) else: assert attn_metadata.seq_lens.numel() == num_tokens if kv_cache.numel(): + assert attn_metadata.block_tables.numel() attn_output = flash_attn_with_kvcache( q=query.unsqueeze(1), k_cache=kv_cache[0], @@ -229,6 +212,8 @@ def forward( softcap=0, ).squeeze(1) else: + # For profile run, we don't have kv_cache and block_tables. + assert not attn_metadata.block_tables.numel() attn_output = flash_attn_func( q=query.unsqueeze(1), k=key.unsqueeze(1), @@ -337,6 +322,9 @@ def _padded_size(size: int) -> int: class LlamaSwiftKVModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + if not vllm_config.scheduler_config.chunked_prefill_enabled: + raise ValueError("SwiftKV requires chunked prefill to be enabled") + super().__init__() config = vllm_config.model_config.hf_config @@ -383,9 +371,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.cuda_graphs = {} self.cuda_graph_max_batch_size = _padded_size( vllm_config.scheduler_config.max_num_seqs) + max_seq_len = vllm_config.model_config.max_seq_len_to_capture + block_size = vllm_config.cache_config.block_size self.cuda_graph_max_num_blocks = ( - vllm_config.model_config.max_seq_len_to_capture // - vllm_config.cache_config.block_size) + (max_seq_len + block_size - 1) // block_size) self.cuda_graph_tensors = { "positions": torch.empty(self.cuda_graph_max_batch_size, dtype=torch.long), @@ -411,6 +400,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): dtype=torch.int32), ), } + self.cuda_graph_pool = None else: self.use_inner_cuda_graph = False @@ -483,22 +473,21 @@ def _get_swiftkv_metadata_for_cuda_graph( seq_lens=attn_metadata.seq_lens_tensor, ) - def _prepare_cuda_graph_inputs( + def _prepare_cuda_graph( self, - size: int, positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor, kv_states: Dict[int, Tuple[torch.Tensor, torch.Tensor]], swiftkv_metadata: SwiftKVMetadata, ): + size = hidden_states.size(0) self.cuda_graph_tensors["positions"][:size].copy_(positions) self.cuda_graph_tensors["hidden_states"][:size].copy_(hidden_states) self.cuda_graph_tensors["residual"][:size].copy_(residual) - cuda_graph_kv_states = self.cuda_graph_tensors["kv_states"] 
- for layer_idx, (k, v) in kv_states.items(): - cuda_graph_kv_states[layer_idx][0][:size].copy_(k) - cuda_graph_kv_states[layer_idx][1][:size].copy_(v) + for idx, (k, v) in kv_states.items(): + self.cuda_graph_tensors["kv_states"][idx][0][:size].copy_(k) + self.cuda_graph_tensors["kv_states"][idx][1][:size].copy_(v) cuda_graph_metadata = self.cuda_graph_tensors["metadata"] cuda_graph_metadata.seq_lens[:size].copy_(swiftkv_metadata.seq_lens) num_blocks = min(self.cuda_graph_max_num_blocks, @@ -510,19 +499,17 @@ def _prepare_cuda_graph_inputs( positions = self.cuda_graph_tensors["positions"][:padded_size] hidden_states = self.cuda_graph_tensors["hidden_states"][:padded_size] residual = self.cuda_graph_tensors["residual"][:padded_size] - for layer_idx in kv_states: - kv_states[layer_idx] = ( - cuda_graph_kv_states[layer_idx][0][:padded_size], - cuda_graph_kv_states[layer_idx][1][:padded_size], - ) + kv_states = { + idx: (k[:padded_size], v[:padded_size]) + for idx, (k, v) in self.cuda_graph_tensors["kv_states"].items() + } swiftkv_metadata = SwiftKVMetadata( use_varlen=swiftkv_metadata.use_varlen, indices=swiftkv_metadata.indices, seq_lens=cuda_graph_metadata.seq_lens[:padded_size], block_tables=cuda_graph_metadata.block_tables[:padded_size], ) - return (padded_size, positions, hidden_states, residual, kv_states, - swiftkv_metadata) + return positions, hidden_states, residual, kv_states, swiftkv_metadata def _run_swiftkv_layers( self, @@ -549,6 +536,57 @@ def _run_swiftkv_layers( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def _capture_cuda_graph( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + kv_states: Dict[int, Tuple[torch.Tensor, torch.Tensor]], + kv_caches: List[torch.Tensor], + swiftkv_metadata: SwiftKVMetadata, + ) -> torch.cuda.graph: + positions, hidden_states, residual, kv_states, swiftkv_metadata = ( + self._prepare_cuda_graph( + positions, + hidden_states, + residual, + kv_states, + swiftkv_metadata, + ) + ) + padded_size = _padded_size(hidden_states.size(0)) + cuda_graph_hidden_states = self.cuda_graph_tensors["hidden_states"] + with graph_capture() as ctx, torch.cuda.stream(ctx.stream): + graph = torch.cuda.CUDAGraph() + # Run a few times first to ensure the captured graph does not + # include kernel launches for initial benchmarking (e.g., Triton + # autotune). Note that once is not enough for torch.jit.script. 
+ for _ in range(2): + cuda_graph_hidden_states[:padded_size].copy_( + self._run_swiftkv_layers( + positions, + hidden_states, + residual, + kv_states, + kv_caches, + swiftkv_metadata, + ) + ) + ctx.stream.synchronize() + with torch.cuda.graph(graph, stream=ctx.stream): + cuda_graph_hidden_states[:padded_size].copy_( + self._run_swiftkv_layers( + positions, + hidden_states, + residual, + kv_states, + kv_caches, + swiftkv_metadata, + ) + ) + self.cuda_graph_pool = graph.pool() + return graph + def forward( self, input_ids: Optional[torch.Tensor], @@ -616,62 +654,36 @@ def forward( for layer_idx, (k, v) in kv_states.items() } - batch_size = hidden_states.size(0) + size = hidden_states.size(0) if (self.use_inner_cuda_graph and not attn_metadata.use_cuda_graph and not swiftkv_metadata.use_varlen and kv_caches[0].numel() - and batch_size <= self.cuda_graph_max_batch_size + and size <= self.cuda_graph_max_batch_size + and swiftkv_metadata.block_tables.numel() and swiftkv_metadata.block_tables.size(1) <= self.cuda_graph_max_num_blocks ): # We implement our own (just-in-time) cuda graph for the second # half of the model (layers skipped for prefill tokens). - ( - padded_size, - positions, - hidden_states, - residual, - kv_states, - swiftkv_metadata, - ) = self._prepare_cuda_graph_inputs( - batch_size, + padded_size = _padded_size(size) + if padded_size not in self.cuda_graphs: + print("Capture SwiftKV CUDA graph for batch size", padded_size) + self.cuda_graphs[padded_size] = self._capture_cuda_graph( + positions, + hidden_states, + residual, + kv_states, + kv_caches, + swiftkv_metadata, + ) + self._prepare_cuda_graph( positions, hidden_states, residual, kv_states, swiftkv_metadata, ) - g = self.cuda_graphs.get(padded_size) - cuda_graph_hidden_states = self.cuda_graph_tensors["hidden_states"] - if g is None: - g = torch.cuda.CUDAGraph() - # Run a few times first to ensure the captured graph does not - # include kernel launches for initial benchmarking (e.g., Triton - # autotune). Note that once is not enough for torch.jit.script. 
-                for _ in range(2):
-                    h = self._run_swiftkv_layers(
-                        positions,
-                        hidden_states,
-                        residual,
-                        kv_states,
-                        kv_caches,
-                        swiftkv_metadata,
-                    )
-                    cuda_graph_hidden_states[:padded_size].copy_(h)
-                print("Capture SwiftKV CUDA graph for batch size", padded_size)
-                with graph_capture() as c, torch.cuda.graph(g, stream=c.stream):
-                    hidden_states = self._run_swiftkv_layers(
-                        positions,
-                        hidden_states,
-                        residual,
-                        kv_states,
-                        kv_caches,
-                        swiftkv_metadata,
-                    )
-                    cuda_graph_hidden_states[:padded_size].copy_(hidden_states)
-                self.cuda_graphs[padded_size] = g
-            else:
-                g.replay()
-            hidden_states = cuda_graph_hidden_states[:batch_size]
+            self.cuda_graphs[padded_size].replay()
+            hidden_states.copy_(self.cuda_graph_tensors["hidden_states"][:size])
         else:
             hidden_states = self._run_swiftkv_layers(
                 positions,

From 95a264d181e2bbc57a77e18563ec182b5d5c6d56 Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Mon, 9 Dec 2024 11:27:48 -0500
Subject: [PATCH 09/12] lint

---
 examples/swiftkv/run_eval_405b_fp8.sh       |   2 +
 examples/swiftkv/run_eval_8b.sh             |   2 +
 tests/swiftkv/test_llama_fp8.py             |  20 ++-
 vllm/model_executor/models/llama_swiftkv.py | 146 ++++++++----------
 .../configs/llama_swiftkv.py                |  10 +-
 vllm/worker/model_runner.py                 |   5 +-
 6 files changed, 89 insertions(+), 96 deletions(-)

diff --git a/examples/swiftkv/run_eval_405b_fp8.sh b/examples/swiftkv/run_eval_405b_fp8.sh
index 5f24ea13fd6e1..3f87b2fe682d4 100644
--- a/examples/swiftkv/run_eval_405b_fp8.sh
+++ b/examples/swiftkv/run_eval_405b_fp8.sh
@@ -1,3 +1,5 @@
+#/usr/bin/env bash
+
 MODEL=Snowflake/Llama-3.1-SwiftKV-405B-Instruct-FP8

 EVAL_CMD=$(cat <<EOF
[... heredoc body missing; the diffs for examples/swiftkv/run_eval_8b.sh and
tests/swiftkv/test_llama_fp8.py, and the header of the
vllm/model_executor/models/llama_swiftkv.py diff, are also missing from this
copy of the patch ...]
 def _padded_size(size: int) -> int:
     mult = (1 << (size - 1).bit_length()) // 4
     if mult < 1:
         return size
-    return (size + mult - 1) // mult * mult
+    return (size + mult - 1) // mult * mult


 class LlamaSwiftKVModel(nn.Module):

@@ -331,9 +308,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
-        self.kv_cache_dtype = (
-            cache_config.cache_dtype if cache_config is not None else "auto"
-        )
+        self.kv_cache_dtype = (cache_config.cache_dtype
+                               if cache_config is not None else "auto")

         self.config = config
         self.padding_idx = config.pad_token_id
@@ -352,15 +328,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                               cache_config=cache_config,
                               quant_config=quant_config,
                               prefix=f"{prefix}.layers.{idx}")
-            if idx < config.num_key_value_layers
-            else LlamaSwiftKVDecoderLayer(config=config,
-                                          cache_config=cache_config,
-                                          quant_config=quant_config,
-                                          prefix=f"{prefix}.layers.{idx}")
+            if idx < config.num_key_value_layers else LlamaSwiftKVDecoderLayer(
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.layers.{idx}")
             for idx in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.norm_swiftkv = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm_swiftkv = RMSNorm(config.hidden_size,
+                                    eps=config.rms_norm_eps)

         # Cuda graph inputs/output tensors
         if not vllm_config.model_config.enforce_eager:
@@ -373,31 +350,34 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 vllm_config.scheduler_config.max_num_seqs)
             max_seq_len = vllm_config.model_config.max_seq_len_to_capture
             block_size = vllm_config.cache_config.block_size
-            self.cuda_graph_max_num_blocks = (
-                (max_seq_len + block_size - 1) // block_size)
+            self.cuda_graph_max_num_blocks = ((max_seq_len + block_size - 1) //
+                                              block_size)
self.cuda_graph_tensors = { - "positions": torch.empty(self.cuda_graph_max_batch_size, - dtype=torch.long), - "hidden_states": torch.empty(self.cuda_graph_max_batch_size, - config.hidden_size), - "residual": torch.empty(self.cuda_graph_max_batch_size, - config.hidden_size), + "positions": + torch.empty(self.cuda_graph_max_batch_size, dtype=torch.long), + "hidden_states": + torch.empty(self.cuda_graph_max_batch_size, + config.hidden_size), + "residual": + torch.empty(self.cuda_graph_max_batch_size, + config.hidden_size), "kv_states": { layer_idx: ( torch.empty(self.cuda_graph_max_batch_size, kv_size), torch.empty(self.cuda_graph_max_batch_size, kv_size), ) for layer_idx in range(config.num_key_value_layers, - config.num_hidden_layers) + config.num_hidden_layers) }, - "metadata": SwiftKVMetadata( + "metadata": + SwiftKVMetadata( use_varlen=False, indices=None, seq_lens=torch.empty(self.cuda_graph_max_batch_size, - dtype=torch.int32), + dtype=torch.int32), block_tables=torch.empty(self.cuda_graph_max_batch_size, - self.cuda_graph_max_num_blocks, - dtype=torch.int32), + self.cuda_graph_max_num_blocks, + dtype=torch.int32), ), } self.cuda_graph_pool = None @@ -422,8 +402,8 @@ def _get_swiftkv_metadata( for seq_id in range(len(query_start_loc) - 1): seq_begin = query_start_loc[seq_id] seq_end = query_start_loc[seq_id + 1] - while (idx < len(sampling_indices) and - sampling_indices[idx] < seq_begin): + while (idx < len(sampling_indices) + and sampling_indices[idx] < seq_begin): idx += 1 if idx >= len(sampling_indices): break @@ -442,8 +422,9 @@ def _get_swiftkv_metadata( use_varlen=False, indices=torch.tensor(swiftkv_indices, device=device), block_tables=attn_metadata.block_tables[swiftkv_seq_ids], - seq_lens=torch.tensor(swiftkv_seq_lens, device=device, - dtype=torch.int32), + seq_lens=torch.tensor(swiftkv_seq_lens, + device=device, + dtype=torch.int32), ) else: return SwiftKVMetadata( @@ -451,10 +432,12 @@ def _get_swiftkv_metadata( indices=torch.tensor(swiftkv_indices, device=device), block_tables=attn_metadata.block_tables[swiftkv_seq_ids], query_start_loc=torch.tensor( - [0] + swiftkv_query_lens, device=device, + [0] + swiftkv_query_lens, + device=device, ).cumsum(dim=0).to(torch.int32), seq_start_loc=torch.tensor( - [0] + swiftkv_seq_lens, device=device, + [0] + swiftkv_seq_lens, + device=device, ).cumsum(dim=0).to(torch.int32), max_query_len=max_query_len, max_seq_len=max_seq_len, @@ -464,8 +447,8 @@ def _get_swiftkv_metadata_for_cuda_graph( self, attn_metadata: FlashAttentionMetadata, ) -> SwiftKVMetadata: - assert (attn_metadata.num_prefills == 0 and - attn_metadata.max_decode_query_len == 1) + assert (attn_metadata.num_prefills == 0 + and attn_metadata.max_decode_query_len == 1) return SwiftKVMetadata( use_varlen=False, indices=None, @@ -552,8 +535,7 @@ def _capture_cuda_graph( residual, kv_states, swiftkv_metadata, - ) - ) + )) padded_size = _padded_size(hidden_states.size(0)) cuda_graph_hidden_states = self.cuda_graph_tensors["hidden_states"] with graph_capture() as ctx, torch.cuda.stream(ctx.stream): @@ -570,8 +552,7 @@ def _capture_cuda_graph( kv_states, kv_caches, swiftkv_metadata, - ) - ) + )) ctx.stream.synchronize() with torch.cuda.graph(graph, stream=ctx.stream): cuda_graph_hidden_states[:padded_size].copy_( @@ -582,8 +563,7 @@ def _capture_cuda_graph( kv_states, kv_caches, swiftkv_metadata, - ) - ) + )) self.cuda_graph_pool = graph.pool() return graph @@ -599,9 +579,8 @@ def forward( ) -> Union[torch.Tensor, IntermediateTensors]: swiftkv_metadata = ( 
self._get_swiftkv_metadata(attn_metadata, sampling_metadata) - if not attn_metadata.use_cuda_graph - else self._get_swiftkv_metadata_for_cuda_graph(attn_metadata) - ) + if not attn_metadata.use_cuda_graph else + self._get_swiftkv_metadata_for_cuda_graph(attn_metadata)) if inputs_embeds is not None: hidden_states = inputs_embeds @@ -638,7 +617,8 @@ def forward( kv_caches[layer_idx][1], attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, - 1.0, 1.0, + 1.0, + 1.0, ) if swiftkv_metadata.indices is not None: @@ -649,19 +629,18 @@ def forward( residual = residual[swiftkv_metadata.indices] positions = positions[swiftkv_metadata.indices] kv_states = { - layer_idx: (k[swiftkv_metadata.indices], - v[swiftkv_metadata.indices]) + layer_idx: + (k[swiftkv_metadata.indices], v[swiftkv_metadata.indices]) for layer_idx, (k, v) in kv_states.items() } size = hidden_states.size(0) if (self.use_inner_cuda_graph and not attn_metadata.use_cuda_graph - and not swiftkv_metadata.use_varlen and kv_caches[0].numel() - and size <= self.cuda_graph_max_batch_size - and swiftkv_metadata.block_tables.numel() - and swiftkv_metadata.block_tables.size(1) <= - self.cuda_graph_max_num_blocks - ): + and not swiftkv_metadata.use_varlen and kv_caches[0].numel() + and size <= self.cuda_graph_max_batch_size + and swiftkv_metadata.block_tables.numel() + and swiftkv_metadata.block_tables.size(1) <= + self.cuda_graph_max_num_blocks): # We implement our own (just-in-time) cuda graph for the second # half of the model (layers skipped for prefill tokens). padded_size = _padded_size(size) @@ -683,7 +662,8 @@ def forward( swiftkv_metadata, ) self.cuda_graphs[padded_size].replay() - hidden_states.copy_(self.cuda_graph_tensors["hidden_states"][:size]) + hidden_states.copy_( + self.cuda_graph_tensors["hidden_states"][:size]) else: hidden_states = self._run_swiftkv_layers( positions, @@ -871,8 +851,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + config.vocab_size, logit_scale) self.sampler = Sampler() def forward( @@ -884,8 +863,11 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, sampling_metadata: Optional[SamplingMetadata] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + model_output = self.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, sampling_metadata=sampling_metadata) return model_output diff --git a/vllm/transformers_utils/configs/llama_swiftkv.py b/vllm/transformers_utils/configs/llama_swiftkv.py index ff02e1d7ef71e..5290e2d1a05ec 100644 --- a/vllm/transformers_utils/configs/llama_swiftkv.py +++ b/vllm/transformers_utils/configs/llama_swiftkv.py @@ -9,10 +9,6 @@ class LlamaSwiftKVConfig(LlamaConfig): num_key_value_layers (int, optional): The number of layers, from the first layer, that have keys and values. If None, all layers have keys and values. - last_key_value_heads (int, optional): - The number of heads in the last layer that have keys and values. - If None, the number of heads in the last key-value layer is equal - to the number of heads in all the other key-value layers. 
""" model_type = "llama_swiftkv" @@ -21,11 +17,9 @@ def __init__( self, swiftkv: bool = False, num_key_value_layers: Optional[int] = None, - key_value_group_size: Optional[int] = None, **kwargs, ): super().__init__(**kwargs) self.swiftkv = swiftkv - self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers - self.key_value_group_size = key_value_group_size or 1 - assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0 + self.num_key_value_layers = (num_key_value_layers or + self.num_hidden_layers) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 381f1e67712df..ea54b3adfd265 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1673,8 +1673,9 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - swiftkv_kwargs = ({"sampling_metadata": model_input.sampling_metadata} - if "SwiftKV" in type(self.model).__name__ else {}) + swiftkv_kwargs = ({ + "sampling_metadata": model_input.sampling_metadata + } if "SwiftKV" in type(self.model).__name__ else {}) if not bypass_model_exec: with set_forward_context(model_input.attn_metadata, From a68890a038083a71b816af9d1531e4f46738074d Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 9 Dec 2024 11:33:01 -0500 Subject: [PATCH 10/12] lint --- examples/swiftkv/run_eval_405b_fp8.sh | 2 +- examples/swiftkv/run_eval_8b.sh | 2 +- vllm/model_executor/models/llama_swiftkv.py | 7 ++----- vllm/transformers_utils/config.py | 13 ++++++------- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/examples/swiftkv/run_eval_405b_fp8.sh b/examples/swiftkv/run_eval_405b_fp8.sh index 3f87b2fe682d4..e87d5e7654a6a 100644 --- a/examples/swiftkv/run_eval_405b_fp8.sh +++ b/examples/swiftkv/run_eval_405b_fp8.sh @@ -1,4 +1,4 @@ -#/usr/bin/env bash +#!/usr/bin/env bash MODEL=Snowflake/Llama-3.1-SwiftKV-405B-Instruct-FP8 diff --git a/examples/swiftkv/run_eval_8b.sh b/examples/swiftkv/run_eval_8b.sh index 2124215116b08..cb15e606b8fee 100644 --- a/examples/swiftkv/run_eval_8b.sh +++ b/examples/swiftkv/run_eval_8b.sh @@ -1,4 +1,4 @@ -#/usr/bin/env bash +#!/usr/bin/env bash MODEL=Snowflake/Llama-3.1-SwiftKV-8B-Instruct diff --git a/vllm/model_executor/models/llama_swiftkv.py b/vllm/model_executor/models/llama_swiftkv.py index ed8a580cf14eb..bf3a949013775 100644 --- a/vllm/model_executor/models/llama_swiftkv.py +++ b/vllm/model_executor/models/llama_swiftkv.py @@ -6,11 +6,6 @@ from vllm.attention import AttentionMetadata from vllm.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.vllm_flash_attn import ( - flash_attn_func, - flash_attn_varlen_func, - flash_attn_with_kvcache, -) from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -38,6 +33,8 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import LlamaSwiftKVConfig +from vllm.vllm_flash_attn import (flash_attn_func, flash_attn_varlen_func, + flash_attn_with_kvcache) @dataclass diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8483d59aaf3c0..b680aea9deef8 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -25,13 +25,12 @@ EAGLEConfig, ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, - LlamaSwiftKVConfig, - MedusaConfig, MllamaConfig, - MLPSpeculatorConfig, MPTConfig, - 
NemotronConfig, NVLM_D_Config, - Olmo2Config, RWConfig, - SolarConfig, Telechat2Config, - UltravoxConfig) + LlamaSwiftKVConfig, MedusaConfig, + MllamaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + NVLM_D_Config, Olmo2Config, + RWConfig, SolarConfig, + Telechat2Config, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import resolve_obj_by_qualname From eac1d45ddb8e238b6355fa37aa6238ad48839401 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 9 Dec 2024 11:35:29 -0500 Subject: [PATCH 11/12] lint --- vllm/transformers_utils/configs/llama_swiftkv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/transformers_utils/configs/llama_swiftkv.py b/vllm/transformers_utils/configs/llama_swiftkv.py index 5290e2d1a05ec..939e1e9b714e2 100644 --- a/vllm/transformers_utils/configs/llama_swiftkv.py +++ b/vllm/transformers_utils/configs/llama_swiftkv.py @@ -21,5 +21,5 @@ def __init__( ): super().__init__(**kwargs) self.swiftkv = swiftkv - self.num_key_value_layers = (num_key_value_layers or - self.num_hidden_layers) + self.num_key_value_layers = (num_key_value_layers + or self.num_hidden_layers) From e7cbfc453c016337f9e3389b50110cb8da072d5d Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 9 Dec 2024 11:36:56 -0500 Subject: [PATCH 12/12] Update README.md --- examples/swiftkv/README.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/swiftkv/README.md b/examples/swiftkv/README.md index d39cd164b579a..b8c88417b772d 100644 --- a/examples/swiftkv/README.md +++ b/examples/swiftkv/README.md @@ -10,11 +10,6 @@ For more details, see: ## Quickstart -Install vLLM from Snowflake-Labs: -```console -$ pip install git+https://github.com/snowflake-labs/vllm.git@swiftkv -``` - Run an example conversation using [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct): ```console $ python examples/swiftkv/offline_inference_swiftkv.py
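# The lines below are an illustrative sketch only and are not part of the
# SwiftKV patch set: they assume the stock vLLM OpenAI-compatible server
# ("vllm serve" and the /v1/chat/completions route) works unchanged for
# SwiftKV checkpoints, which the patches above do not state explicitly.
$ vllm serve Snowflake/Llama-3.1-SwiftKV-8B-Instruct
$ curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "Snowflake/Llama-3.1-SwiftKV-8B-Instruct",
         "messages": [{"role": "user", "content": "Summarize SwiftKV in one sentence."}]}'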