From 8aea36d9f1f73b46843e3bdeaa6f79cbb5a5bdb2 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Wed, 24 Jul 2024 08:58:31 -0700
Subject: [PATCH] [Bugfix] Fix speculative decode seeded test (#6743)

Signed-off-by: Alvant
---
 tests/spec_decode/e2e/conftest.py  |  3 ++-
 tests/spec_decode/e2e/test_seed.py | 22 +++++++++++++++++-----
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index bd1ea43f0b101..f9f246436c0f7 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -191,7 +191,8 @@ def generator_inner():
                 and llm.llm_engine.log_stats):
             for sate_logger in llm.llm_engine.stat_loggers.values():
                 sate_logger.local_interval = 0
-        set_random_seed(seed)
+        if seed is not None:
+            set_random_seed(seed)
 
         yield llm
         del llm
diff --git a/tests/spec_decode/e2e/test_seed.py b/tests/spec_decode/e2e/test_seed.py
index 792d7cba0f270..394a53f03ed46 100644
--- a/tests/spec_decode/e2e/test_seed.py
+++ b/tests/spec_decode/e2e/test_seed.py
@@ -21,7 +21,8 @@
         "num_speculative_tokens": 3,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
 @pytest.mark.parametrize("batch_size", [1, 8, 32])
 @pytest.mark.parametrize("temperature", [0.1, 1.0])
 @pytest.mark.parametrize(
@@ -30,15 +31,26 @@
         # Use smaller output len for fast test.
         10,
     ])
-@pytest.mark.parametrize("seed", [1])
-def test_seeded_consistency(baseline_llm_generator, batch_size: int,
-                            temperature: float, output_len: int):
+@pytest.mark.parametrize("seed", [None])
+def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
+                            batch_size: int, temperature: float,
+                            output_len: int):
     """Verify outputs are consistent across multiple runs with same seed
     """
     run_equality_correctness_test(baseline_llm_generator,
-                                  baseline_llm_generator,
+                                  test_llm_generator,
                                   batch_size,
                                   max_output_len=output_len,
                                   temperature=temperature,
                                   seeded=True,
                                   force_output_len=True)
+
+    # Ensure this same test does fail if we _don't_ include per-request seeds
+    with pytest.raises(AssertionError):
+        run_equality_correctness_test(baseline_llm_generator,
+                                      test_llm_generator,
+                                      batch_size,
+                                      max_output_len=output_len,
+                                      temperature=temperature,
+                                      seeded=False,
+                                      force_output_len=True)
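
Note (illustration, not part of the patch): the updated test gives the
baseline and test engines different engine-level seeds (1 vs 5) and sets the
global seed to None, so any remaining output equality must come from
per-request seeds. A minimal sketch of that property using vLLM's public API,
assuming a GPU, the vllm package, and an illustrative small model (the model
name, prompt, and seed values below are placeholders, not from the patch):

    from vllm import LLM, SamplingParams

    prompts = ["The president of the United States is"]
    # Per-request seed: with temperature > 0, this alone should make
    # sampling reproducible for a given model.
    params = SamplingParams(temperature=1.0, seed=42, max_tokens=10)

    # Engine-level seeds intentionally differ, mirroring the test's 1 vs 5.
    llm_a = LLM(model="JackFram/llama-68m", seed=1)
    out_a = llm_a.generate(prompts, params)
    del llm_a  # release the first engine before starting the second

    llm_b = LLM(model="JackFram/llama-68m", seed=5)
    out_b = llm_b.generate(prompts, params)

    # Outputs should agree because of the per-request seed, not the
    # engine-level seed; dropping seed=42 from SamplingParams would let
    # them diverge, which is what the added pytest.raises block checks.
    assert out_a[0].outputs[0].text == out_b[0].outputs[0].text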