Skip to content

Commit

Permalink
adding acceptance rate test for large output length
Browse files Browse the repository at this point in the history
  • Loading branch information
abhigoyal1997 committed Aug 16, 2024
1 parent 9587b05 commit 99484ae
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 9 deletions.
23 changes: 14 additions & 9 deletions tests/spec_decode/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
ensure_all_accepted=ensure_all_accepted)


def run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False):
def run_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero (or when temperature is > 0 and seeded).
Expand Down Expand Up @@ -359,3 +361,6 @@ def run_equality_correctness_test(baseline_llm_generator,

if ensure_all_accepted:
assert acceptance_rate == 1.0

if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
42 changes: 42 additions & 0 deletions tests/spec_decode/e2e/test_mlp_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Eager mode keeps the test fast by skipping cuda graph capture.
        "enforce_eager": True,
        # Speculative decoding requires the v2 block manager.
        "use_v2_block_manager": True,
        # Keep stats logging enabled so spec-decode metrics are emitted.
        "disable_log_stats": False,
        # Precision used for both models.
        "dtype": PRECISION,
        # Target (scoring) model.
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
                                 batch_size: int, output_len: int):
    """Check the MLP speculator's draft acceptance rate on long generations.

    Runs greedy (temperature 0.0), seeded generation for a large output
    length and asserts the measured acceptance rate is at least 0.6, for
    both a single-sequence and a large (32) batch.
    """
    run_equality_correctness_test(baseline_llm_generator,
                                  test_llm_generator,
                                  batch_size=batch_size,
                                  max_output_len=output_len,
                                  temperature=0.0,
                                  seeded=True,
                                  force_output_len=True,
                                  expected_acceptance_rate=0.6)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
Expand Down

0 comments on commit 99484ae

Please sign in to comment.