Commit

add test and fix a few bugs
aurickq committed Dec 7, 2024
1 parent 4f3d05a commit bb43a8b
Showing 3 changed files with 149 additions and 94 deletions.
Empty file added tests/swiftkv/__init__.py
43 changes: 43 additions & 0 deletions tests/swiftkv/test_llama_fp8.py
@@ -0,0 +1,43 @@
import pytest

import vllm
from tests.utils import multi_gpu_test
from vllm.sampling_params import SamplingParams

MODELS = ["Snowflake/Llama-3.1-SwiftKV-8B-Instruct-FP8"]
CONVERSATIONS = [
[{"role": "user", "content": "Hello!"}],
[{"role": "user", "content": "Who is the president of the United States?"}],
[{"role": "user", "content": "What is the capital of France?"}],
[{"role": "user", "content": "What is the future of AI?"}],
]
EXPECTED_OUTPUTS = [
"Hello! How can I assist you today?",
"As of my cut-off knowledge in December 2023, the President of the United "
"States is Joe",
"The capital of France is Paris.",
"The future of AI is vast and rapidly evolving, with numerous potential "
"developments and applications on the horizon.",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
@multi_gpu_test(num_gpus=2)
def test_model(model, enforce_eager, tensor_parallel_size) -> None:
llm = vllm.LLM(
model,
enforce_eager=enforce_eager,
enable_chunked_prefill=True,
tensor_parallel_size=tensor_parallel_size,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=20)
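# temperature=0.0 selects greedy decoding, so the completions are
# deterministic and can be compared verbatim against EXPECTED_OUTPUTS.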

for idx, conversation in enumerate(CONVERSATIONS):
outputs = llm.chat(
conversation,
sampling_params=sampling_params,
use_tqdm=False,
)
assert outputs[0].outputs[0].text == EXPECTED_OUTPUTS[idx]
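
Not part of this commit, but closely related: the model changes below make LlamaSwiftKVModel raise a ValueError when chunked prefill is disabled, which is why every LLM in the test above is constructed with enable_chunked_prefill=True. A minimal sketch of a negative test for that guard follows; it assumes the ValueError propagates out of the vllm.LLM constructor (depending on the executor it may surface wrapped in another exception), so treat it as illustrative rather than as part of the test suite.

@pytest.mark.parametrize("model", MODELS)
def test_requires_chunked_prefill(model) -> None:
    # Hypothetical negative test: constructing a SwiftKV model without
    # chunked prefill should trip the new guard in LlamaSwiftKVModel.__init__.
    with pytest.raises(Exception, match="chunked prefill"):
        vllm.LLM(model, enforce_eager=True, enable_chunked_prefill=False)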
200 changes: 106 additions & 94 deletions vllm/model_executor/models/llama_swiftkv.py
@@ -58,6 +58,7 @@
from vllm.model_executor.models.utils import (
AutoWeightsLoader, is_pp_missing_parameter, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import LlamaSwiftKVConfig

@@ -177,45 +178,27 @@ def forward(
value = value.view(-1, self.num_kv_heads, self.head_dim)

if attn_metadata.use_varlen:
if (kv_cache.numel() == 0 or attn_metadata.block_tables is None
or attn_metadata.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
attn_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=attn_metadata.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_seq_len,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scaling,
causal=True,
window_size=(-1, -1),
alibi_slopes=None,
softcap=0,
)
else:
# prefix-enabled attention
attn_output = flash_attn_varlen_func( # noqa
q=query,
k=kv_cache[0],
v=kv_cache[1],
cu_seqlens_q=attn_metadata.query_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scaling,
causal=True,
window_size=(-1, -1),
alibi_slopes=None,
block_table=attn_metadata.block_tables,
softcap=0,
)
# This should be neither a CUDA graph capture run nor a profile run.
assert kv_cache.numel() and attn_metadata.block_tables.numel()
attn_output = flash_attn_varlen_func( # noqa
q=query,
k=kv_cache[0],
v=kv_cache[1],
cu_seqlens_q=attn_metadata.query_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scaling,
causal=True,
window_size=(-1, -1),
alibi_slopes=None,
block_table=attn_metadata.block_tables,
softcap=0,
)
else:
assert attn_metadata.seq_lens.numel() == num_tokens
if kv_cache.numel():
assert attn_metadata.block_tables.numel()
attn_output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=kv_cache[0],
@@ -229,6 +212,8 @@ def forward(
softcap=0,
).squeeze(1)
else:
# During a profile run, neither kv_cache nor block_tables is available.
assert not attn_metadata.block_tables.numel()
attn_output = flash_attn_func(
q=query.unsqueeze(1),
k=key.unsqueeze(1),
@@ -337,6 +322,9 @@ def _padded_size(size: int) -> int:
class LlamaSwiftKVModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
if not vllm_config.scheduler_config.chunked_prefill_enabled:
raise ValueError("SwiftKV requires chunked prefill to be enabled")

super().__init__()

config = vllm_config.model_config.hf_config
@@ -383,9 +371,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.cuda_graphs = {}
self.cuda_graph_max_batch_size = _padded_size(
vllm_config.scheduler_config.max_num_seqs)
max_seq_len = vllm_config.model_config.max_seq_len_to_capture
block_size = vllm_config.cache_config.block_size
self.cuda_graph_max_num_blocks = (
vllm_config.model_config.max_seq_len_to_capture //
vllm_config.cache_config.block_size)
(max_seq_len + block_size - 1) // block_size)
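# Round the block count up so a max_seq_len_to_capture that is not a
# multiple of block_size still gets enough blocks (e.g. 100 tokens with
# block_size=16 need 7 blocks, not 6).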
self.cuda_graph_tensors = {
"positions": torch.empty(self.cuda_graph_max_batch_size,
dtype=torch.long),
@@ -411,6 +400,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
dtype=torch.int32),
),
}
self.cuda_graph_pool = None
else:
self.use_inner_cuda_graph = False

@@ -483,22 +473,21 @@ def _get_swiftkv_metadata_for_cuda_graph(
seq_lens=attn_metadata.seq_lens_tensor,
)

def _prepare_cuda_graph_inputs(
def _prepare_cuda_graph(
self,
size: int,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor,
kv_states: Dict[int, Tuple[torch.Tensor, torch.Tensor]],
swiftkv_metadata: SwiftKVMetadata,
):
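# Copy the live inputs into the persistent CUDA-graph buffers; the returned
# values are padded views of those buffers (plus matching attention
# metadata), used when capturing the graph.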
size = hidden_states.size(0)
self.cuda_graph_tensors["positions"][:size].copy_(positions)
self.cuda_graph_tensors["hidden_states"][:size].copy_(hidden_states)
self.cuda_graph_tensors["residual"][:size].copy_(residual)
cuda_graph_kv_states = self.cuda_graph_tensors["kv_states"]
for layer_idx, (k, v) in kv_states.items():
cuda_graph_kv_states[layer_idx][0][:size].copy_(k)
cuda_graph_kv_states[layer_idx][1][:size].copy_(v)
for idx, (k, v) in kv_states.items():
self.cuda_graph_tensors["kv_states"][idx][0][:size].copy_(k)
self.cuda_graph_tensors["kv_states"][idx][1][:size].copy_(v)
cuda_graph_metadata = self.cuda_graph_tensors["metadata"]
cuda_graph_metadata.seq_lens[:size].copy_(swiftkv_metadata.seq_lens)
num_blocks = min(self.cuda_graph_max_num_blocks,
@@ -510,19 +499,17 @@ def _prepare_cuda_graph_inputs(
positions = self.cuda_graph_tensors["positions"][:padded_size]
hidden_states = self.cuda_graph_tensors["hidden_states"][:padded_size]
residual = self.cuda_graph_tensors["residual"][:padded_size]
for layer_idx in kv_states:
kv_states[layer_idx] = (
cuda_graph_kv_states[layer_idx][0][:padded_size],
cuda_graph_kv_states[layer_idx][1][:padded_size],
)
kv_states = {
idx: (k[:padded_size], v[:padded_size])
for idx, (k, v) in self.cuda_graph_tensors["kv_states"].items()
}
swiftkv_metadata = SwiftKVMetadata(
use_varlen=swiftkv_metadata.use_varlen,
indices=swiftkv_metadata.indices,
seq_lens=cuda_graph_metadata.seq_lens[:padded_size],
block_tables=cuda_graph_metadata.block_tables[:padded_size],
)
return (padded_size, positions, hidden_states, residual, kv_states,
swiftkv_metadata)
return positions, hidden_states, residual, kv_states, swiftkv_metadata

def _run_swiftkv_layers(
self,
@@ -549,6 +536,57 @@ def _run_swiftkv_layers(
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

def _capture_cuda_graph(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor,
kv_states: Dict[int, Tuple[torch.Tensor, torch.Tensor]],
kv_caches: List[torch.Tensor],
swiftkv_metadata: SwiftKVMetadata,
) -> torch.cuda.CUDAGraph:
positions, hidden_states, residual, kv_states, swiftkv_metadata = (
self._prepare_cuda_graph(
positions,
hidden_states,
residual,
kv_states,
swiftkv_metadata,
)
)
padded_size = _padded_size(hidden_states.size(0))
cuda_graph_hidden_states = self.cuda_graph_tensors["hidden_states"]
with graph_capture() as ctx, torch.cuda.stream(ctx.stream):
graph = torch.cuda.CUDAGraph()
# Warm up by running a few times first so that the captured graph does
# not include one-time kernel launches for initial benchmarking (e.g.,
# Triton autotune). Note that a single warm-up run is not enough for
# torch.jit.script.
for _ in range(2):
cuda_graph_hidden_states[:padded_size].copy_(
self._run_swiftkv_layers(
positions,
hidden_states,
residual,
kv_states,
kv_caches,
swiftkv_metadata,
)
)
ctx.stream.synchronize()
with torch.cuda.graph(graph, stream=ctx.stream):
cuda_graph_hidden_states[:padded_size].copy_(
self._run_swiftkv_layers(
positions,
hidden_states,
residual,
kv_states,
kv_caches,
swiftkv_metadata,
)
)
self.cuda_graph_pool = graph.pool()
return graph
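# The captured graph reads its inputs from, and writes its output to, the
# persistent buffers in self.cuda_graph_tensors. Subsequent calls with the
# same padded batch size therefore only need to copy fresh inputs into those
# buffers via _prepare_cuda_graph and then replay the graph (see forward()).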

def forward(
self,
input_ids: Optional[torch.Tensor],
@@ -616,62 +654,36 @@ def forward(
for layer_idx, (k, v) in kv_states.items()
}

batch_size = hidden_states.size(0)
size = hidden_states.size(0)
if (self.use_inner_cuda_graph and not attn_metadata.use_cuda_graph
and not swiftkv_metadata.use_varlen and kv_caches[0].numel()
and batch_size <= self.cuda_graph_max_batch_size
and size <= self.cuda_graph_max_batch_size
and swiftkv_metadata.block_tables.numel()
and swiftkv_metadata.block_tables.size(1) <=
self.cuda_graph_max_num_blocks
):
# We implement our own (just-in-time) CUDA graph for the second half of
# the model, i.e. the layers that are skipped for prefill tokens.
(
padded_size,
positions,
hidden_states,
residual,
kv_states,
swiftkv_metadata,
) = self._prepare_cuda_graph_inputs(
batch_size,
padded_size = _padded_size(size)
if padded_size not in self.cuda_graphs:
print("Capture SwiftKV CUDA graph for batch size", padded_size)
self.cuda_graphs[padded_size] = self._capture_cuda_graph(
positions,
hidden_states,
residual,
kv_states,
kv_caches,
swiftkv_metadata,
)
self._prepare_cuda_graph(
positions,
hidden_states,
residual,
kv_states,
swiftkv_metadata,
)
g = self.cuda_graphs.get(padded_size)
cuda_graph_hidden_states = self.cuda_graph_tensors["hidden_states"]
if g is None:
g = torch.cuda.CUDAGraph()
# Run a few times first to ensure the captured graph does not
# include kernel launches for initial benchmarking (e.g., Triton
# autotune). Note that once is not enough for torch.jit.script.
for _ in range(2):
h = self._run_swiftkv_layers(
positions,
hidden_states,
residual,
kv_states,
kv_caches,
swiftkv_metadata,
)
cuda_graph_hidden_states[:padded_size].copy_(h)
print("Capture SwiftKV CUDA graph for batch size", padded_size)
with graph_capture() as c, torch.cuda.graph(g, stream=c.stream):
hidden_states = self._run_swiftkv_layers(
positions,
hidden_states,
residual,
kv_states,
kv_caches,
swiftkv_metadata,
)
cuda_graph_hidden_states[:padded_size].copy_(hidden_states)
self.cuda_graphs[padded_size] = g
else:
g.replay()
hidden_states = cuda_graph_hidden_states[:batch_size]
self.cuda_graphs[padded_size].replay()
hidden_states.copy_(self.cuda_graph_tensors["hidden_states"][:size])
else:
hidden_states = self._run_swiftkv_layers(
positions,

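Appended for context, not part of the commit: a minimal, self-contained sketch of the warm-up / capture / replay pattern that _capture_cuda_graph and forward implement above, written against plain PyTorch CUDA-graph APIs rather than vllm's graph_capture helper. The toy function, shapes, and buffer names are invented for illustration and assume a CUDA device is available.

import torch


def toy_step(x: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for _run_swiftkv_layers: read a static input buffer and write
    # the result into a static output buffer.
    out.copy_(torch.relu(x @ x.t()))


static_in = torch.zeros(8, 8, device="cuda")
static_out = torch.empty(8, 8, device="cuda")

# Warm up on a side stream so one-time work (autotuning, lazy init) is not
# captured into the graph.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(2):
        toy_step(static_in, static_out)
torch.cuda.current_stream().wait_stream(side)

# Capture once.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    toy_step(static_in, static_out)

# Replay with new data: copy fresh inputs into the static input buffer,
# replay the graph, then read the result out of the static output buffer.
static_in.copy_(torch.randn(8, 8, device="cuda"))
graph.replay()
result = static_out.clone()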