Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speculative Decoding] EAGLE Implementation with Top-1 proposer #6830

Merged
merged 54 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
3c590b1
initial changes to support EAGLE
abhigoyal1997 Jul 26, 2024
5f5bed1
handling hidden_states in case of bonus tokens since EAGLE will need it
abhigoyal1997 Jul 26, 2024
023e72d
enabling CUDA graph
abhigoyal1997 Jul 26, 2024
8ac1570
adding E2E test and formatting
abhigoyal1997 Jul 26, 2024
b379948
minor bug fix in graph capture
abhigoyal1997 Jul 26, 2024
aef9c00
fixing broadcasting of hidden states in distributed worker
abhigoyal1997 Jul 26, 2024
c8d63bd
formatting
abhigoyal1997 Jul 26, 2024
733ca4f
Merge branch 'main' of github.fkinternal.com:abhinav-goyal/vllm into …
abhigoyal1997 Jul 27, 2024
1a0aa60
formatting
abhigoyal1997 Jul 27, 2024
83b3dd8
Merge branch 'main' of github.fkinternal.com:abhinav-goyal/vllm into …
abhigoyal1997 Jul 31, 2024
b1f05ac
Masking position=0 in inputs for EAGLE
abhigoyal1997 Jul 31, 2024
bdee07c
reformatting
abhigoyal1997 Jul 31, 2024
441374f
Fixing the order of execution for scorer and proposer in non-driver w…
abhigoyal1997 Jul 31, 2024
0d1cbae
Adding hidden state propagation to _execute_model_spmd
abhigoyal1997 Aug 1, 2024
b60384a
Adding CUDA graph tests for medusa and eagle. Renaming mlp to medusa …
abhigoyal1997 Aug 1, 2024
7b6a0e6
Moving hidden states shift to spec_decode_worker
abhigoyal1997 Aug 1, 2024
9d806b3
formatting
abhigoyal1997 Aug 1, 2024
e1e3175
Merge branch 'vllm-project:main' into eagle
abhigoyal1997 Aug 1, 2024
8db174f
Adding vocab truncation to EAGLE
abhigoyal1997 Aug 2, 2024
b6b0548
Minor changes and fixes. Adding expand model request and hidden state…
abhigoyal1997 Aug 2, 2024
f9cbd49
Merge branch 'vllm-project:main' into eagle
abhigoyal1997 Aug 2, 2024
89184a1
Merge branch 'main' into eagle
abhigoyal1997 Aug 3, 2024
cf8b685
Merge branch 'vllm-project:main' into eagle
abhigoyal1997 Aug 4, 2024
3c24f4b
Merge branch 'vllm-project:main' into eagle
abhigoyal1997 Aug 5, 2024
a94ea89
Merge branch 'main' into eagle
abhigoyal1997 Aug 5, 2024
eaa586c
Removing commented code and a minor comment fix
abhigoyal1997 Aug 6, 2024
38e2b5c
formatting
abhigoyal1997 Aug 6, 2024
2f17900
Merge branch 'vllm-project:main' into eagle
abhigoyal1997 Aug 6, 2024
c5f8d15
adding comments to clarify compatibility of eagle checkpoint in eagle.py
abhigoyal1997 Aug 7, 2024
53ab660
Merge branch 'vllm-project:main' into eagle
abhigoyal1997 Aug 7, 2024
7f46c68
fixing model_cls resolution in eagle
abhigoyal1997 Aug 7, 2024
5e5d214
fixing model_cls resolution in eagle
abhigoyal1997 Aug 7, 2024
17c0fc6
Merge branch 'main' into eagle
abhigoyal1997 Aug 9, 2024
ad04e7f
adding doctrings to EAGLE and Medusa models
abhigoyal1997 Aug 13, 2024
90bee1d
fixing hidden states handling in batch expansion
abhigoyal1997 Aug 14, 2024
88c20e6
making HiddenStates a dataclass and renaming last_non_bonus_hidden_st…
abhigoyal1997 Aug 14, 2024
3ff257b
Merge branch 'hidden_states_fix' of github.fkinternal.com:abhinav-goy…
abhigoyal1997 Aug 14, 2024
1753d9a
reformatting
abhigoyal1997 Aug 14, 2024
99484ae
adding acceptance rate test for large output length
abhigoyal1997 Aug 16, 2024
2e51385
fixing hidden states manipulation for batch expansion
abhigoyal1997 Aug 16, 2024
faa2e28
Merge branch 'hidden_states_fix' of github.fkinternal.com:abhinav-goy…
abhigoyal1997 Aug 16, 2024
d8bcff0
print acceptance rate in spec decode tests
abhigoyal1997 Aug 16, 2024
1654d4d
Updating HiddenStates to handle prefill step as well
abhigoyal1997 Aug 16, 2024
6954ead
changing expected acceptance rate for test
abhigoyal1997 Aug 17, 2024
08b3cd5
Merge branch 'vllm-project:main' into hidden_states_fix
abhigoyal1997 Aug 17, 2024
5815ccc
Merge branch 'vllm-project:main' into hidden_states_fix
abhigoyal1997 Aug 18, 2024
601c816
Merge branch 'hidden_states_fix' into eagle
abhigoyal1997 Aug 19, 2024
df87143
Adding explanation for trucated vocab and merging main
abhigoyal1997 Aug 19, 2024
f906cef
formatting
abhigoyal1997 Aug 19, 2024
2147583
Merge branch 'main' of github.fkinternal.com:abhinav-goyal/vllm into …
abhigoyal1997 Aug 20, 2024
3febb95
Fixing compatibility of `worker.multi_step_worker.MultiStepWorker` wi…
abhigoyal1997 Aug 20, 2024
90582e2
Merge branch 'main' into eagle
abhigoyal1997 Aug 21, 2024
af5552b
Merge branch 'main' into eagle
abhigoyal1997 Aug 22, 2024
284468d
adding comment
abhigoyal1997 Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 268 additions & 0 deletions tests/spec_decode/e2e/test_eagle_correctness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.

With those tests, we can say at least, EAGLE would not break the
correctess for the target model outputs.
abhigoyal1997 marked this conversation as resolved.
Show resolved Hide resolved
"""

import pytest

from .conftest import run_greedy_equality_correctness_test

# main model
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4

# precision
PRECISION = "float32"
abhigoyal1997 marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
abhigoyal1997 marked this conversation as resolved.
Show resolved Hide resolved

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
abhigoyal1997 marked this conversation as resolved.
Show resolved Hide resolved
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Precision
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Precision
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Precision
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


abhigoyal1997 marked this conversation as resolved.
Show resolved Hide resolved
if __name__ == "__main__":
import pytest
pytest.main([__file__])
68 changes: 56 additions & 12 deletions tests/spec_decode/e2e/test_medusa_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
Expand All @@ -80,6 +81,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
Expand Down Expand Up @@ -116,10 +160,10 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
Expand Down Expand Up @@ -165,9 +209,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
Expand Down Expand Up @@ -208,9 +252,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
Expand Down
Loading
Loading