@@ -228,9 +228,11 @@ def test_fp8_beam_search(self):
                           sampling_params=sampling_params,
                           extra_acc_spec="beam_width=4")

-    def test_eagle3(self):
+    @pytest.mark.parametrize(("overlap_scheduler", "eagle3_one_model"),
+                             [(False, True), (False, False)])
+    def test_eagle3(self, overlap_scheduler, eagle3_one_model):
         pytorch_config = dict(
-            disable_overlap_scheduler=True,
+            disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
         )
         kv_cache_config = KvCacheConfig(enable_block_reuse=False)
@@ -240,7 +242,8 @@ def test_eagle3(self):

         draft_len = 4
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir)
+                                          speculative_model_dir=eagle_model_dir,
+                                          eagle3_one_model=eagle3_one_model)

         with LLM(model=target_model_dir,
                  **pytorch_config,
@@ -249,6 +252,8 @@ def test_eagle3(self):
                  build_config=None) as llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)

     def test_ngram(self):
         pytorch_config = dict(disable_overlap_scheduler=True)
@@ -1641,9 +1646,11 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)

-    def test_eagle3(self):
+    @pytest.mark.parametrize(("overlap_scheduler", "eagle3_one_model"),
+                             [(False, True), (False, False)])
+    def test_eagle3(self, overlap_scheduler, eagle3_one_model):
         pytorch_config = dict(
-            disable_overlap_scheduler=True,
+            disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
         )
         kv_cache_config = KvCacheConfig(enable_block_reuse=False)
@@ -1653,7 +1660,8 @@ def test_eagle3(self):

         draft_len = 4
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir)
+                                          speculative_model_dir=eagle_model_dir,
+                                          eagle3_one_model=eagle3_one_model)

         llm = LLM(model=target_model_dir,
                   **pytorch_config,
0 commit comments