[SpecDec] Remove Batch Expansion (2/3) #9298

Merged · 4 commits · Oct 12, 2024
52 changes: 39 additions & 13 deletions tests/spec_decode/test_scorer.py
@@ -1,3 +1,6 @@
import random
from typing import List

import pytest
import torch

@@ -10,31 +13,45 @@
from .utils import create_batch, create_worker


def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
def create_proposal(propose_lens: List[int], vocab_size: int,
device: str) -> SpeculativeProposals:
proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
batch_size = len(propose_lens)
max_propose_len = max(propose_lens)
proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size),
device=device)
proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
proposal_lens = torch.tensor([propose_len] * batch_size, device=device)

proposal_token_ids = torch.full((batch_size, max_propose_len),
fill_value=-1,
device=device)
for i in range(batch_size):
proposal_token_ids[i][:propose_lens[i]] = torch.argmax(
proposal_probs[i][:propose_lens[i]], dim=-1)

propose_lens = torch.tensor(propose_lens, device=device)
return SpeculativeProposals(proposal_token_ids, proposal_probs,
proposal_lens)
propose_lens)


def assert_score_equal(score1: SpeculativeScores,
score2: SpeculativeScores) -> None:
assert torch.allclose(score1.probs, score2.probs)
assert torch.allclose(score1.logprobs, score2.logprobs)
assert torch.equal(score1.token_ids, score2.token_ids)
assert torch.equal(
score1.token_ids,
score2.token_ids), f"{score1.token_ids}, {score2.token_ids}"


@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
@pytest.mark.parametrize('mixed_propose_len', [True])
@pytest.mark.parametrize('device', ['cuda'])
def test_scoroer(model_name: str, batch_size: int, propose_len: int,
device: str) -> None:
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
mixed_propose_len: bool, device: str) -> None:
"""
Compare the batch expansion scorer and mqa scorer return the same score
Check that the batch expansion scorer and the MQA scorer return the same score.
We test both the case where all queries share the same propose length and the
case where propose lengths differ across queries.
"""
seed = 0
block_size = 32
@@ -46,13 +63,22 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int,
should_modify_greedy_probs_inplace = True

vocab_size = scorer_worker.vocab_size
proposals = create_proposal(batch_size, propose_len, vocab_size, device)

if not mixed_propose_len:
propose_lens = [max_propose_len] * batch_size
else:
non_zero_cnt = random.randint(0, batch_size)
propose_lens = [max_propose_len
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
random.shuffle(propose_lens)

proposals = create_proposal(propose_lens, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
propose_len,
max_propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=propose_len)
num_lookahead_slots=max_propose_len)

batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
vocab_size)
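For context, a minimal sketch (separate from the diff) of the padding scheme the updated create_proposal exercises when proposal lengths are mixed: token ids are padded with -1 past each request's proposal length. The make_padded_proposals name and the example lengths are invented for illustration.

import torch

def make_padded_proposals(propose_lens, vocab_size, device="cpu"):
    # Hypothetical helper mirroring create_proposal above: each row is padded
    # with -1 beyond that request's proposal length.
    batch_size = len(propose_lens)
    max_len = max(propose_lens)
    probs = torch.rand(batch_size, max_len, vocab_size, device=device)
    token_ids = torch.full((batch_size, max_len), -1,
                           dtype=torch.long, device=device)
    for i, n in enumerate(propose_lens):
        token_ids[i, :n] = torch.argmax(probs[i, :n], dim=-1)
    return token_ids, probs

tokens, _ = make_padded_proposals([3, 0, 2], vocab_size=8)
# tokens might look like:
# tensor([[ 5,  1,  7],
#         [-1, -1, -1],
#         [ 4,  6, -1]])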
7 changes: 2 additions & 5 deletions vllm/attention/backends/blocksparse_attn.py
@@ -186,11 +186,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool

# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None

_cached_prefill_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None
69 changes: 42 additions & 27 deletions vllm/attention/backends/flash_attn.py
@@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata):
# Maximum query length in the batch.
max_query_len: Optional[int]

# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int]
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int]

# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
@@ -173,9 +170,9 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
decode_query_len=0,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
@@ -202,12 +199,14 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
decode_query_len=self.decode_query_len,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
query_start_loc=self.query_start_loc[self.num_prefills:]
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
@@ -413,9 +412,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
decode_query_len = max(decode_query_lens)
max_decode_query_len = max(decode_query_lens)
else:
decode_query_len = 1
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
@@ -468,7 +467,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
decode_query_len=decode_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
@@ -714,20 +713,37 @@ def unified_flash_attention(

if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
_, num_head, head_dim = decode_query.shape
decode_query = decode_query.reshape(-1, decode_meta.decode_query_len,
num_head, head_dim)
decode_output = flash_attn_with_kvcache(
q=decode_query,
k_cache=key_cache,
v_cache=value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
if decode_meta.max_decode_query_len > 1:
decode_output = flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)

if prefill_output is None:
assert decode_output is not None
@@ -739,7 +755,6 @@
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert decode_meta is not None
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
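A brief sketch (not part of the diff) of how the variable-length decode path packs queries, assuming query_start_loc plays the role of cu_seqlens_q for flash_attn_varlen_func and max_decode_query_len that of max_seqlen_q; the lengths below are made up.

import torch

# Three decode requests with different numbers of query tokens this step,
# e.g. because their accepted speculative lengths differ.
decode_query_lens = [4, 1, 3]

# Cumulative offsets into the packed query tensor: a leading zero, then one
# entry per request.
cu_seqlens_q = torch.zeros(len(decode_query_lens) + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(
    torch.tensor(decode_query_lens, dtype=torch.int32), dim=0)
print(cu_seqlens_q)                    # tensor([0, 4, 5, 8], dtype=torch.int32)

max_seqlen_q = max(decode_query_lens)  # 4, i.e. max_decode_query_len

# All 8 query tokens are packed into one (8, num_heads, head_dim) tensor;
# cu_seqlens_q tells the varlen kernel where each request's slice begins.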
7 changes: 2 additions & 5 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -121,11 +121,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# so far).
context_lens_tensor: Optional[torch.Tensor]

# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None

_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
2 changes: 1 addition & 1 deletion vllm/attention/backends/utils.py
@@ -313,7 +313,7 @@ def graph_capture_get_metadata_for_batch(
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
decode_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
7 changes: 2 additions & 5 deletions vllm/attention/backends/xformers.py
@@ -118,11 +118,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None

# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None

# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
42 changes: 34 additions & 8 deletions vllm/spec_decode/mqa_scorer.py
@@ -18,6 +18,7 @@ def score_proposals(
target_seq_id_start = max(
get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
all_proposal_tokens = proposals.proposal_token_ids.tolist()
all_proposal_lengths = proposals.proposal_lens.tolist()
for i, seq_group_metadata in enumerate(
execute_model_req.seq_group_metadata_list):
seq_data_dict = seq_group_metadata.seq_data
@@ -27,7 +28,8 @@
seq_data: SequenceData = seq_data_dict[seq_id]
prompt_token_ids = seq_data.get_prompt_token_ids()
output_token_ids = seq_data.get_output_token_ids()
proposal_token_ids = all_proposal_tokens[i]
proposal_token_ids = all_proposal_tokens[
i][:all_proposal_lengths[i]]
new_output_token_ids = [*output_token_ids, *proposal_token_ids]

target_seq_id = target_seq_id_start + i
@@ -62,18 +64,42 @@

target_sampler_output = target_sampler_output[0]

bs, k = proposals.proposal_token_ids.shape
all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1)

all_probs = target_sampler_output.sampled_token_probs.reshape(
bs, k + 1, self._vocab_size)
all_logprobs = target_sampler_output.logprobs.reshape(
bs, k + 1, self._vocab_size)
k = execute_model_req.num_lookahead_slots
bs = len(execute_model_req.seq_group_metadata_list)
target_token_ids = target_sampler_output.sampled_token_ids
target_probs = target_sampler_output.sampled_token_probs
target_logprobs = target_sampler_output.logprobs
# If all requests have the same number of query tokens, we can avoid
# the per-request loop below and simply reshape the sampler output.
if min(all_proposal_lengths) == k:
bs, _ = proposals.proposal_token_ids.shape
all_tokens = target_token_ids.reshape(bs, k + 1)
all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
else:
all_tokens = target_token_ids.new_full(size=(bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape,
self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))
target_token_ids = target_token_ids.flatten()
start_loc = 0
for i, proposed_len in enumerate(all_proposal_lengths):
output_len = proposed_len + 1
end_loc = start_loc + output_len
all_tokens[
i, :output_len] = target_token_ids[start_loc:end_loc]
all_probs[i, :output_len] = target_probs[start_loc:end_loc]
all_logprobs[
i, :output_len] = target_logprobs[start_loc:end_loc]
start_loc = end_loc

hidden_states = None
if target_sampler_output.hidden_states is not None:
hidden_states = target_sampler_output.hidden_states.reshape(
bs, (k + 1), -1)

return SpeculativeScores(probs=all_probs,
token_ids=all_tokens,
logprobs=all_logprobs,
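A standalone sketch (not part of the diff) of the scatter logic added above for mixed proposal lengths: each request contributes proposal_len + 1 target tokens to the flat sampler output, which are copied into a -1-padded (bs, k + 1) tensor. The token values are dummies.

import torch

k = 3                       # num_lookahead_slots
proposal_lens = [3, 0, 2]   # per-request proposal lengths
# Flat sampler output: proposal_len + 1 target tokens per request (8 total).
flat_tokens = torch.arange(sum(n + 1 for n in proposal_lens))

all_tokens = flat_tokens.new_full((len(proposal_lens), k + 1), fill_value=-1)
start = 0
for i, n in enumerate(proposal_lens):
    end = start + n + 1
    all_tokens[i, :n + 1] = flat_tokens[start:end]
    start = end

print(all_tokens)
# tensor([[ 0,  1,  2,  3],
#         [ 4, -1, -1, -1],
#         [ 5,  6,  7, -1]])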
6 changes: 0 additions & 6 deletions vllm/spec_decode/spec_decode_worker.py
@@ -188,12 +188,6 @@ def create_worker(
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend.")

if ngram_prompt_lookup_max > 0:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"NGramWorker does not support MQA scorer.")

if "model_config" in draft_worker_kwargs and \
draft_worker_kwargs["model_config"].max_model_len < \
scorer_worker.model_config.max_model_len: