@@ -23,7 +23,7 @@
 from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
 
 from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
-                        skip_pre_hopper)
+                        skip_pre_blackwell, skip_pre_hopper)
 from ..trt_test_alternative import popen
 from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
                             LlmapiAccuracyTestHarness, get_accuracy_task)
@@ -71,7 +71,9 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
                              ctx_server_config: Dict[str, Any],
                              gen_server_config: Dict[str, Any],
                              model_name: str,
-                             tensor_parallel_size: int = 1):
+                             tensor_parallel_size: int = 1,
+                             ctx_model: str = None,
+                             gen_model: str = None):
     temp_dir = tempfile.TemporaryDirectory()
     disaggregated_serving_config_path = os.path.join(
         temp_dir.name, "disaggregated_serving_config.yaml")
@@ -97,9 +99,20 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
 
     trtllm_serve_path = "trtllm-serve"
-    # Common arguments for both servers
-    common_args = [
+    # Build a separate launch command per server; ctx and gen may load different models
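+    # Fall back to the shared model when no per-server checkpoint is given.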
+    ctx_model = ctx_model or model_name
+    gen_model = gen_model or model_name
+    ctx_args = [
         trtllm_serve_path,
-        model_name,
+        ctx_model,
+        "--host",
+        "localhost",
+        "--backend",
+        "pytorch",
+    ]
+    gen_args = [
+        trtllm_serve_path,
+        gen_model,
         "--host",
         "localhost",
         "--backend",
@@ -125,11 +138,11 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
         map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
-    ctx_server_args = common_args + [
+    ctx_server_args = ctx_args + [
         "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
         f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
     ]
-    gen_server_args = common_args + [
+    gen_server_args = gen_args + [
         "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
         f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
     ]
@@ -226,17 +239,22 @@ def generate_async(prompt: str,
     disaggregated_server.wait()
 
 
-def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
-                      ctx_tp: int, gen_pp: int, gen_tp: int,
-                      test_set: LlmapiAccuracyTestHarness):
+def run_parallel_test(model_name: str,
+                      model_path: str,
+                      ctx_pp: int,
+                      ctx_tp: int,
+                      gen_pp: int,
+                      gen_tp: int,
+                      test_sets: List[LlmapiAccuracyTestHarness],
+                      ctx_model: str = None,
+                      gen_model: str = None):
     if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
         pytest.fail(
             f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
         )
-
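+    # Cap the KV cache at half of the currently free GPU memory.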
     kv_cache_config = {
         "free_gpu_memory_fraction": 0.5,
-        "enable_block_reuse": False
     }
     ctx_server_config = {
         "pipeline_parallel_size": ctx_pp,
@@ -270,10 +288,15 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
         }
     }
     with launch_disaggregated_llm(disaggregated_server_config,
-                                  ctx_server_config, gen_server_config,
-                                  model_path) as llm:
-        task = test_set(model_name)
-        task.evaluate(llm)
+                                  ctx_server_config,
+                                  gen_server_config,
+                                  model_path,
+                                  ctx_model=ctx_model,
+                                  gen_model=gen_model) as llm:
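+        # Run every requested accuracy task against the same deployment.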
+        for test_set in test_sets:
+            task = test_set(model_name)
+            task.evaluate(llm)
 
 
 @pytest.mark.timeout(3600)
@@ -511,14 +534,14 @@ def test_guided_decoding_with_eagle3(self, backend: str, mocker):
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_tp_pp_symmetric(self, tp, pp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
-                                 tp, get_accuracy_task(testset))
+                                 tp, [get_accuracy_task(testset)])
 
     @parametrize_with_ids("ctx_pp", [2, 4])
     @parametrize_with_ids("gen_tp", [1, 2])
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
-                                 gen_tp, get_accuracy_task(testset))
+                                 gen_tp, [get_accuracy_task(testset)])
523543
524544@pytest .mark .skip_less_device_memory (140000 )
@@ -702,3 +725,25 @@ def test_auto_dtype(self, overlap_scheduler):
         task.evaluate(llm)
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)
+
+
+@skip_pre_blackwell
+@pytest.mark.timeout(3600)
+class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
+    fp4_model = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_nvfp4_hf"
+    fp8_model = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_fp8_hf"
+
+    @pytest.mark.parametrize("ctxpp,gentp", [(2, 2)], ids=["ctxpp2gentp2"])
+    def test_mixed_ctx_gen_model(self, ctxpp, gentp):
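+        # Mixed-precision disagg: prefill on the NVFP4 checkpoint, decode on FP8.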
+        ctx_model = self.fp4_model
+        gen_model = self.fp8_model
+        return run_parallel_test("Qwen3/Qwen3-30B-A3B",
+                                 ctx_model,
+                                 ctx_pp=ctxpp,
+                                 ctx_tp=1,
+                                 gen_pp=1,
+                                 gen_tp=gentp,
+                                 test_sets=[GSM8K, MMLU],
+                                 ctx_model=ctx_model,
+                                 gen_model=gen_model)