Commit: lint

aurickq committed Dec 9, 2024
1 parent 16d45f2 commit 95a264d
Showing 6 changed files with 89 additions and 96 deletions.
2 changes: 2 additions & 0 deletions examples/swiftkv/run_eval_405b_fp8.sh
@@ -1,3 +1,5 @@
#!/usr/bin/env bash

MODEL=Snowflake/Llama-3.1-SwiftKV-405B-Instruct-FP8

EVAL_CMD=$(cat <<EOF
2 changes: 2 additions & 0 deletions examples/swiftkv/run_eval_8b.sh
@@ -1,3 +1,5 @@
#!/usr/bin/env bash

MODEL=Snowflake/Llama-3.1-SwiftKV-8B-Instruct

EVAL_CMD=$(cat <<EOF
20 changes: 16 additions & 4 deletions tests/swiftkv/test_llama_fp8.py
@@ -6,10 +6,22 @@

MODELS = ["Snowflake/Llama-3.1-SwiftKV-8B-Instruct-FP8"]
CONVERSATIONS = [
[{"role": "user", "content": "Hello!"}],
[{"role": "user", "content": "Who is the president of the United States?"}],
[{"role": "user", "content": "What is the capital of France?"}],
[{"role": "user", "content": "What is the future of AI?"}],
[{
"role": "user",
"content": "Hello!"
}],
[{
"role": "user",
"content": "Who is the president of the United States?"
}],
[{
"role": "user",
"content": "What is the capital of France?"
}],
[{
"role": "user",
"content": "What is the future of AI?"
}],
]
EXPECTED_OUTPUTS = [
"Hello! How can I assist you today?",
146 changes: 64 additions & 82 deletions vllm/model_executor/models/llama_swiftkv.py
@@ -1,44 +1,20 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.attention import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.vllm_flash_attn import (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
@@ -55,8 +31,9 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaMLP
from vllm.model_executor.models.utils import (
AutoWeightsLoader, is_pp_missing_parameter, maybe_prefix)
from vllm.model_executor.models.utils import (AutoWeightsLoader,
is_pp_missing_parameter,
maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@@ -316,7 +293,7 @@ def _padded_size(size: int) -> int:
mult = (1 << (size - 1).bit_length()) // 4
if mult < 1:
return size
return (size + mult - 1) // mult * mult
return (size + mult - 1) // mult * mult


class LlamaSwiftKVModel(nn.Module):
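A side note on the _padded_size helper in the hunk above: it rounds a batch size up to a multiple of one quarter of the next power of two, which bounds how many distinct padded sizes the CUDA graphs below must handle. A minimal, self-contained sketch with a few worked values (illustrative only, not part of the diff):

def _padded_size(size: int) -> int:
    # One quarter of the next power of two >= size; tiny sizes pass through unchanged.
    mult = (1 << (size - 1).bit_length()) // 4
    if mult < 1:
        return size
    # Round up to the nearest multiple of `mult`.
    return (size + mult - 1) // mult * mult

# Worked examples:
#   _padded_size(3)  == 3    (next power of two is 4,  mult = 1)
#   _padded_size(5)  == 6    (next power of two is 8,  mult = 2)
#   _padded_size(9)  == 12   (next power of two is 16, mult = 4)
#   _padded_size(33) == 48   (next power of two is 64, mult = 16)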
@@ -331,9 +308,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.kv_cache_dtype = (
cache_config.cache_dtype if cache_config is not None else "auto"
)
self.kv_cache_dtype = (cache_config.cache_dtype
if cache_config is not None else "auto")

self.config = config
self.padding_idx = config.pad_token_id
@@ -352,15 +328,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}")
if idx < config.num_key_value_layers
else LlamaSwiftKVDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}")
if idx < config.num_key_value_layers else LlamaSwiftKVDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}")
for idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm_swiftkv = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm_swiftkv = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

# Cuda graph inputs/output tensors
if not vllm_config.model_config.enforce_eager:
@@ -373,31 +350,34 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.scheduler_config.max_num_seqs)
max_seq_len = vllm_config.model_config.max_seq_len_to_capture
block_size = vllm_config.cache_config.block_size
self.cuda_graph_max_num_blocks = (
(max_seq_len + block_size - 1) // block_size)
self.cuda_graph_max_num_blocks = ((max_seq_len + block_size - 1) //
block_size)
self.cuda_graph_tensors = {
"positions": torch.empty(self.cuda_graph_max_batch_size,
dtype=torch.long),
"hidden_states": torch.empty(self.cuda_graph_max_batch_size,
config.hidden_size),
"residual": torch.empty(self.cuda_graph_max_batch_size,
config.hidden_size),
"positions":
torch.empty(self.cuda_graph_max_batch_size, dtype=torch.long),
"hidden_states":
torch.empty(self.cuda_graph_max_batch_size,
config.hidden_size),
"residual":
torch.empty(self.cuda_graph_max_batch_size,
config.hidden_size),
"kv_states": {
layer_idx: (
torch.empty(self.cuda_graph_max_batch_size, kv_size),
torch.empty(self.cuda_graph_max_batch_size, kv_size),
)
for layer_idx in range(config.num_key_value_layers,
config.num_hidden_layers)
config.num_hidden_layers)
},
"metadata": SwiftKVMetadata(
"metadata":
SwiftKVMetadata(
use_varlen=False,
indices=None,
seq_lens=torch.empty(self.cuda_graph_max_batch_size,
dtype=torch.int32),
dtype=torch.int32),
block_tables=torch.empty(self.cuda_graph_max_batch_size,
self.cuda_graph_max_num_blocks,
dtype=torch.int32),
self.cuda_graph_max_num_blocks,
dtype=torch.int32),
),
}
self.cuda_graph_pool = None
@@ -422,8 +402,8 @@ def _get_swiftkv_metadata(
for seq_id in range(len(query_start_loc) - 1):
seq_begin = query_start_loc[seq_id]
seq_end = query_start_loc[seq_id + 1]
while (idx < len(sampling_indices) and
sampling_indices[idx] < seq_begin):
while (idx < len(sampling_indices)
and sampling_indices[idx] < seq_begin):
idx += 1
if idx >= len(sampling_indices):
break
@@ -442,19 +422,22 @@ def _get_swiftkv_metadata(
use_varlen=False,
indices=torch.tensor(swiftkv_indices, device=device),
block_tables=attn_metadata.block_tables[swiftkv_seq_ids],
seq_lens=torch.tensor(swiftkv_seq_lens, device=device,
dtype=torch.int32),
seq_lens=torch.tensor(swiftkv_seq_lens,
device=device,
dtype=torch.int32),
)
else:
return SwiftKVMetadata(
use_varlen=True,
indices=torch.tensor(swiftkv_indices, device=device),
block_tables=attn_metadata.block_tables[swiftkv_seq_ids],
query_start_loc=torch.tensor(
[0] + swiftkv_query_lens, device=device,
[0] + swiftkv_query_lens,
device=device,
).cumsum(dim=0).to(torch.int32),
seq_start_loc=torch.tensor(
[0] + swiftkv_seq_lens, device=device,
[0] + swiftkv_seq_lens,
device=device,
).cumsum(dim=0).to(torch.int32),
max_query_len=max_query_len,
max_seq_len=max_seq_len,
@@ -464,8 +447,8 @@ def _get_swiftkv_metadata_for_cuda_graph(
self,
attn_metadata: FlashAttentionMetadata,
) -> SwiftKVMetadata:
assert (attn_metadata.num_prefills == 0 and
attn_metadata.max_decode_query_len == 1)
assert (attn_metadata.num_prefills == 0
and attn_metadata.max_decode_query_len == 1)
return SwiftKVMetadata(
use_varlen=False,
indices=None,
@@ -552,8 +535,7 @@ def _capture_cuda_graph(
residual,
kv_states,
swiftkv_metadata,
)
)
))
padded_size = _padded_size(hidden_states.size(0))
cuda_graph_hidden_states = self.cuda_graph_tensors["hidden_states"]
with graph_capture() as ctx, torch.cuda.stream(ctx.stream):
@@ -570,8 +552,7 @@ def _capture_cuda_graph(
kv_states,
kv_caches,
swiftkv_metadata,
)
)
))
ctx.stream.synchronize()
with torch.cuda.graph(graph, stream=ctx.stream):
cuda_graph_hidden_states[:padded_size].copy_(
@@ -582,8 +563,7 @@ def _capture_cuda_graph(
kv_states,
kv_caches,
swiftkv_metadata,
)
)
))
self.cuda_graph_pool = graph.pool()
return graph

@@ -599,9 +579,8 @@ def forward(
) -> Union[torch.Tensor, IntermediateTensors]:
swiftkv_metadata = (
self._get_swiftkv_metadata(attn_metadata, sampling_metadata)
if not attn_metadata.use_cuda_graph
else self._get_swiftkv_metadata_for_cuda_graph(attn_metadata)
)
if not attn_metadata.use_cuda_graph else
self._get_swiftkv_metadata_for_cuda_graph(attn_metadata))

if inputs_embeds is not None:
hidden_states = inputs_embeds
@@ -638,7 +617,8 @@ def forward(
kv_caches[layer_idx][1],
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
1.0, 1.0,
1.0,
1.0,
)

if swiftkv_metadata.indices is not None:
@@ -649,19 +629,18 @@ def forward(
residual = residual[swiftkv_metadata.indices]
positions = positions[swiftkv_metadata.indices]
kv_states = {
layer_idx: (k[swiftkv_metadata.indices],
v[swiftkv_metadata.indices])
layer_idx:
(k[swiftkv_metadata.indices], v[swiftkv_metadata.indices])
for layer_idx, (k, v) in kv_states.items()
}

size = hidden_states.size(0)
if (self.use_inner_cuda_graph and not attn_metadata.use_cuda_graph
and not swiftkv_metadata.use_varlen and kv_caches[0].numel()
and size <= self.cuda_graph_max_batch_size
and swiftkv_metadata.block_tables.numel()
and swiftkv_metadata.block_tables.size(1) <=
self.cuda_graph_max_num_blocks
):
and not swiftkv_metadata.use_varlen and kv_caches[0].numel()
and size <= self.cuda_graph_max_batch_size
and swiftkv_metadata.block_tables.numel()
and swiftkv_metadata.block_tables.size(1) <=
self.cuda_graph_max_num_blocks):
# We implement our own (just-in-time) cuda graph for the second
# half of the model (layers skipped for prefill tokens).
padded_size = _padded_size(size)
@@ -683,7 +662,8 @@ def forward(
swiftkv_metadata,
)
self.cuda_graphs[padded_size].replay()
hidden_states.copy_(self.cuda_graph_tensors["hidden_states"][:size])
hidden_states.copy_(
self.cuda_graph_tensors["hidden_states"][:size])
else:
hidden_states = self._run_swiftkv_layers(
positions,
@@ -871,8 +851,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
config.vocab_size, logit_scale)
self.sampler = Sampler()

def forward(
@@ -884,8 +863,11 @@ def forward(
intermediate_tensors: Optional[IntermediateTensors] = None,
sampling_metadata: Optional[SamplingMetadata] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
model_output = self.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
sampling_metadata=sampling_metadata)
return model_output

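Most of the churn in this file is yapf-style re-wrapping, but the reformatted forward() path is also where the just-in-time CUDA graph capture lives: a graph is captured lazily for each padded batch size and replayed afterwards. A minimal sketch of that pattern, assuming a single static input/output buffer and a hypothetical graph-safe callable fn (illustrative, not the vLLM implementation):

import torch


class JitGraphRunner:
    """Capture a CUDA graph lazily per padded batch size, then replay it."""

    def __init__(self, fn, max_batch_size: int, hidden_size: int):
        self.fn = fn              # hypothetical per-batch computation
        self.graphs = {}          # padded batch size -> torch.cuda.CUDAGraph
        self.pool = None          # memory pool shared across captures
        # Static buffers reused by every capture and replay.
        self.inp = torch.empty(max_batch_size, hidden_size, device="cuda")
        self.out = torch.empty(max_batch_size, hidden_size, device="cuda")

    def run(self, x: torch.Tensor) -> torch.Tensor:
        size = x.size(0)
        padded = 1 << (size - 1).bit_length()  # the diff uses _padded_size here
        self.inp[:size].copy_(x)
        if padded not in self.graphs:
            graph = torch.cuda.CUDAGraph()
            # Warm-up run outside the graph so lazy initialization happens eagerly.
            self.out[:padded].copy_(self.fn(self.inp[:padded]))
            torch.cuda.synchronize()
            with torch.cuda.graph(graph, pool=self.pool):
                self.out[:padded].copy_(self.fn(self.inp[:padded]))
            self.pool = graph.pool()
            self.graphs[padded] = graph
        self.graphs[padded].replay()
        return self.out[:size].clone()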
10 changes: 2 additions & 8 deletions vllm/transformers_utils/configs/llama_swiftkv.py
@@ -9,10 +9,6 @@ class LlamaSwiftKVConfig(LlamaConfig):
num_key_value_layers (int, optional):
The number of layers, from the first layer, that have keys and
values. If None, all layers have keys and values.
last_key_value_heads (int, optional):
The number of heads in the last layer that have keys and values.
If None, the number of heads in the last key-value layer is equal
to the number of heads in all the other key-value layers.
"""

model_type = "llama_swiftkv"
@@ -21,11 +17,9 @@ def __init__(
self,
swiftkv: bool = False,
num_key_value_layers: Optional[int] = None,
key_value_group_size: Optional[int] = None,
**kwargs,
):
super().__init__(**kwargs)
self.swiftkv = swiftkv
self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers
self.key_value_group_size = key_value_group_size or 1
assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0
self.num_key_value_layers = (num_key_value_layers or
self.num_hidden_layers)
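After this cleanup the config keeps only swiftkv and num_key_value_layers; the key_value_group_size knob and its assertion are gone. A small usage sketch with hypothetical values:

from vllm.transformers_utils.configs.llama_swiftkv import LlamaSwiftKVConfig

# Hypothetical 32-layer model: the first 16 layers compute their own KV,
# the remaining layers reuse the SwiftKV-projected KV states.
config = LlamaSwiftKVConfig(
    swiftkv=True,
    num_key_value_layers=16,
    num_hidden_layers=32,
)
assert config.num_key_value_layers == 16

# When num_key_value_layers is omitted it falls back to num_hidden_layers,
# i.e. every layer keeps its own KV projections.
default_config = LlamaSwiftKVConfig(num_hidden_layers=32)
assert default_config.num_key_value_layers == 32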
5 changes: 3 additions & 2 deletions vllm/worker/model_runner.py
@@ -1673,8 +1673,9 @@ def execute_model(
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()

swiftkv_kwargs = ({"sampling_metadata": model_input.sampling_metadata}
if "SwiftKV" in type(self.model).__name__ else {})
swiftkv_kwargs = ({
"sampling_metadata": model_input.sampling_metadata
} if "SwiftKV" in type(self.model).__name__ else {})

if not bypass_model_exec:
with set_forward_context(model_input.attn_metadata,
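The hunk above only re-wraps the dict, but the pattern it touches is worth spelling out: the runner adds a sampling_metadata keyword argument only when the loaded model is a SwiftKV variant, since a generic model's forward() may not accept it. A self-contained sketch with hypothetical model classes:

import torch
from torch import nn


class PlainModel(nn.Module):

    def forward(self, x):
        return x


class SwiftKVModel(nn.Module):

    # A real SwiftKV model uses sampling_metadata to decide which tokens may
    # skip the remaining decoder layers; here it is only accepted and ignored.
    def forward(self, x, sampling_metadata=None):
        return x


def execute(model: nn.Module, x: torch.Tensor, sampling_metadata=None):
    kwargs = ({"sampling_metadata": sampling_metadata}
              if "SwiftKV" in type(model).__name__ else {})
    return model(x, **kwargs)


execute(PlainModel(), torch.zeros(2, 4))
execute(SwiftKVModel(), torch.zeros(2, 4), sampling_metadata=object())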
