Rework, including spec decoding cases
njhill committed Jul 26, 2024
1 parent 3af88df commit d2880ea
Showing 15 changed files with 104 additions and 88 deletions.
1 change: 1 addition & 0 deletions tests/spec_decode/e2e/conftest.py
@@ -157,6 +157,7 @@ def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
per_test_common_llm_kwargs, distinct_llm_kwargs,
seed):
print("CREATE LLM GENERATOR")
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
35 changes: 22 additions & 13 deletions tests/spec_decode/e2e/test_mlp_correctness.py
@@ -21,7 +21,8 @@

import pytest

from .conftest import run_greedy_equality_correctness_test, run_equality_correctness_test
from .conftest import (run_equality_correctness_test,
run_greedy_equality_correctness_test)

# main model
MAIN_MODEL = "JackFram/llama-160m"
@@ -94,31 +95,39 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
# Main model
"model": MAIN_MODEL,
# Speculative model
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("seed", [None])
def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int, temperature: float):
batch_size: int, output_len: int,
temperature: float):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test(baseline_llm_generator,
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=True,
force_output_len=True)

# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=False,
force_output_len=True)
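
The property exercised by this test can be reproduced in isolation with plain PyTorch. The following is a minimal sketch under assumed inputs (not the vLLM test harness): sampling with the same seeded torch.Generator reproduces outputs exactly, while unseeded sampling generally does not.

import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])

def sample(seed=None):
    # One generator per "request"; None falls back to the global RNG.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    return torch.multinomial(probs, num_samples=8, replacement=True,
                             generator=generator)

assert torch.equal(sample(seed=5), sample(seed=5))  # seeded runs match exactly
print(sample(), sample())  # unseeded runs are not expected to match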


@pytest.mark.parametrize(
"common_llm_kwargs",
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_seed.py
@@ -29,7 +29,7 @@
"output_len",
[
# Use smaller output len for fast test.
10,
20,
])
@pytest.mark.parametrize("seed", [None])
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
95 changes: 39 additions & 56 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
import torch.jit
@@ -36,7 +36,7 @@ def forward(
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
@@ -66,6 +66,9 @@ def forward(
probabilities.
shape = [batch_size, num_speculative_tokens]
seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
@@ -83,7 +86,7 @@ def forward(
target_probs,
draft_probs,
draft_token_ids,
generators,
seeded_seqs,
))

output_token_ids = self._create_output(
@@ -100,7 +103,7 @@ def _batch_modified_rejection_sampling(
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
@@ -117,23 +120,17 @@ def _batch_modified_rejection_sampling(

# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids, generators)
draft_token_ids, seeded_seqs)

recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)

seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators, k=k)

# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
generators=generators,
seed_indices=seed_indices,
# this arg is unused when None but torch.jit requires a list
non_seed_indices=non_seed_indices or [],
seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)

return accepted, recovered_token_ids
@@ -143,7 +140,7 @@ def _get_accepted(
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
@@ -178,24 +175,26 @@ def _get_accepted(
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]

seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators)

if len(seed_indices) == 0:
if not seeded_seqs:
uniform_rand = torch.rand_like(selected_target_probs)
else:
uniform_rand = torch.empty_like(selected_target_probs)

for idx in seed_indices:
uniform_rand[idx, :] = torch.rand(1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generators[idx])

if non_seed_indices:
uniform_rand[non_seed_indices, :] = torch.rand(
len(non_seed_indices),
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(
1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k,
dtype=self.probs_dtype,
device=target_probs.device)
@@ -272,27 +271,6 @@ def _smallest_positive_value(self) -> float:
"""
return torch.finfo(self.probs_dtype).tiny

# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@staticmethod
def _split_batch_by_seeded(
generators: List[Optional[torch.Generator]],
k: int = 1,
) -> Tuple[List[int], Optional[List[int]]]:

if all(generator is None for generator in generators):
seed_indices: List[int] = []
non_seed_indices: Optional[List[int]] = None
else:
seed_indices, non_seed_indices = [], []
for i, generator in enumerate(generators):
if generator is None:
non_seed_indices.extend(range(k * i, k * (i + 1)))
else:
seed_indices.extend(range(k * i, k * (i + 1)))

return seed_indices, non_seed_indices


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
@@ -304,9 +282,7 @@ def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
generators: List[Optional[torch.Generator]],
seed_indices: List[int],
non_seed_indices: List[int],
seeded_seqs: Dict[int, torch.Generator],
) -> torch.Tensor:

if num_samples > 1:
@@ -315,13 +291,20 @@
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])

q = torch.empty_like(probs)
if len(seed_indices) == 0:
if not seeded_seqs:
q.exponential_(1.0)
else:
q[non_seed_indices].exponential_(1.0)
for idx in seed_indices:
q[idx].exponential_(1.0, generator=generators[idx // k])
non_seeded_indices: List[int] = []
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.extend(list(range(start, end)))
else:
q[start:end].exponential_(1.0, generator=generator)
start = end
q[non_seeded_indices].exponential_(1.0)

return probs.div_(q).argmax(dim=1).view(-1, num_samples)
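
The seeded path of _multinomial above can be condensed into a self-contained sketch (assumed shapes, simplified from the function above rather than a drop-in replacement): dividing each probability row by Exponential(1) noise and taking the argmax samples from that categorical distribution without forcing a GPU<->CPU sync, and the k rows belonging to a seeded sequence draw their noise from that sequence's generator.

import torch
from typing import Dict

def multinomial_sketch(probs: torch.Tensor, k: int,
                       seeded_seqs: Dict[int, torch.Generator]) -> torch.Tensor:
    # probs: [num_seqs * k, vocab_size]; row block i belongs to sequence i.
    q = torch.empty_like(probs)
    for seq_idx in range(probs.shape[0] // k):
        rows = slice(seq_idx * k, (seq_idx + 1) * k)
        # Slicing yields a view, so the in-place fill writes into q;
        # generator=None falls back to the default RNG for unseeded sequences.
        q[rows].exponential_(1.0, generator=seeded_seqs.get(seq_idx))
    return probs.div(q).argmax(dim=1)

# Example: 2 sequences with k=2 speculative tokens each, sequence 0 seeded.
probs = torch.rand(4, 16).softmax(dim=1)
gen = torch.Generator().manual_seed(42)
print(multinomial_sketch(probs, k=2, seeded_seqs={0: gen}))
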
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Optional
from typing import Dict, Optional

import torch
import torch.jit
@@ -237,6 +237,6 @@ def forward(
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
raise NotImplementedError
19 changes: 17 additions & 2 deletions vllm/spec_decode/batch_expansion.py
@@ -3,6 +3,7 @@

import torch

from vllm import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -15,6 +16,8 @@
TargetSeqId = int
TokenId = int

DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()


class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Implements a speculative scorer that uses batch expansion to get
@@ -246,14 +249,25 @@ def _create_target_seq_group_metadata(
token_ids_to_score = self._get_token_ids_to_score(
proposal_token_ids[batch_index])

# Use simpler sampling parameters apart from for final token
# (in particular don't do seeded sampling) since those sampled tokens
# aren't used
sampling_params = input_seq_group_metadata.sampling_params
non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
if sampling_params.temperature else sampling_params

target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
for token_ids in token_ids_to_score:
last_index = len(token_ids_to_score) - 1
for i, token_ids in enumerate(token_ids_to_score):
target_sampling_params = sampling_params if i == last_index \
else non_bonus_sampling_params
target_seq_group_metadata_list.append(
self._create_single_target_seq_group_metadata(
input_seq_group_metadata,
input_seq_id,
next(target_seq_ids_iter),
token_ids,
sampling_params=target_sampling_params,
))

return target_seq_group_metadata_list
@@ -264,6 +278,7 @@ def _create_single_target_seq_group_metadata(
seq_id: SeqId,
target_seq_id: TargetSeqId,
token_ids: List[TokenId],
sampling_params: SamplingParams,
) -> SequenceGroupMetadata:
"""Create a single target SequenceGroupMetadata.
@@ -296,7 +311,7 @@ def _create_single_target_seq_group_metadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data=new_seq_data_dict,
sampling_params=seq_group_metadata.sampling_params,
sampling_params=sampling_params,
block_tables={
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
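
The hunk above can be illustrated with a small stand-alone sketch (FakeSamplingParams is a stand-in for vllm.SamplingParams; this is not the actual code path): when a sequence is expanded into one scoring row per speculative prefix, only the final row's sampled bonus token is kept, so the earlier rows can fall back to cheap default sampling parameters and skip seeded sampling.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeSamplingParams:  # stand-in for vllm.SamplingParams
    temperature: float = 1.0
    seed: Optional[int] = None

DEFAULT_SIMPLE = FakeSamplingParams()

def params_per_expanded_row(request_params: FakeSamplingParams,
                            num_rows: int) -> List[FakeSamplingParams]:
    # Greedy requests (temperature == 0) keep their own params for every row;
    # otherwise only the last row needs the request's full (seeded) params.
    non_bonus = DEFAULT_SIMPLE if request_params.temperature else request_params
    return [request_params if i == num_rows - 1 else non_bonus
            for i in range(num_rows)]

print(params_per_expanded_row(FakeSamplingParams(temperature=0.7, seed=5), 3))
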
4 changes: 3 additions & 1 deletion vllm/spec_decode/medusa_worker.py
@@ -57,9 +57,11 @@ def sampler_output(
seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)

generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory)
self.model_runner.pin_memory, generators)

model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=execute_model_req.previous_hidden_states.
4 changes: 3 additions & 1 deletion vllm/spec_decode/mlp_speculator_worker.py
@@ -38,9 +38,11 @@ def sampler_output(
(input_tokens, seq_lens,
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)

generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory)
self.model_runner.pin_memory, generators)

model_outputs = self.model_runner.model.generate_proposals(
input_ids=input_tokens,
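
Both worker changes above thread per-request generators into SamplingMetadata.prepare. A hypothetical sketch of the kind of bookkeeping this implies is shown below; the class name and behaviour are assumptions for illustration, not the model runner's actual implementation.

import torch
from typing import Dict, List, Optional

class GeneratorRegistry:
    """Illustrative per-request generator store (assumed design)."""

    def __init__(self, device: str = "cpu"):
        self.device = device
        self._generators: Dict[str, torch.Generator] = {}

    def register(self, request_id: str, seed: Optional[int]) -> None:
        # Only requests with an explicit seed get a dedicated generator.
        if seed is not None:
            self._generators[request_id] = torch.Generator(
                device=self.device).manual_seed(seed)

    def get_generators(
            self, finished_request_ids: Optional[List[str]] = None
    ) -> Dict[str, torch.Generator]:
        # Drop state for finished requests, then return the live generators
        # so they can be passed into sampling metadata.
        for request_id in finished_request_ids or []:
            self._generators.pop(request_id, None)
        return self._generators

registry = GeneratorRegistry()
registry.register("req-0", seed=1234)
registry.register("req-1", seed=None)
print(registry.get_generators(finished_request_ids=[]))
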
3 changes: 1 addition & 2 deletions vllm/spec_decode/ngram_worker.py
@@ -7,10 +7,9 @@
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase


class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
class NGramWorker(NonLLMProposerWorkerBase):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implements prompt lookup decoding,