
Commit a8c8005 (parent: 49bcaa4)

add mixed gen/ctx model test

Signed-off-by: Lizhi Zhou <[email protected]>

4 files changed (+61, -17 lines)

tests/integration/defs/accuracy/references/gsm8k.yaml (+1, -0)

@@ -80,6 +80,7 @@ Qwen3/Qwen3-8B:
     kv_cache_quant_algo: FP8
     accuracy: 87.1114
 Qwen3/Qwen3-30B-A3B:
+  - accuracy: 83.43
   - quant_algo: FP8_BLOCK_SCALES
     accuracy: 84.36
   - quant_algo: FP8
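
Note on the new row: it carries no quant_algo key. The mixed-model test added in this commit evaluates Qwen3/Qwen3-30B-A3B with an NVFP4 context server and an FP8 generation server, so neither quantization-specific reference row fits, and a keyless row can act as the fallback. A rough sketch of such most-specific-match lookup (illustrative only; pick_reference is not the harness's real API):

    # Illustrative: rows with extra keys must match the run's quantization
    # metadata; a row holding only "accuracy" matches any run as a fallback,
    # and the most specific matching row wins.
    def pick_reference(rows, run_meta):
        best = None
        for row in rows:
            extra = {k: v for k, v in row.items() if k != "accuracy"}
            if all(run_meta.get(k) == v for k, v in extra.items()):
                if best is None or len(extra) > len(best[0]):
                    best = (extra, row["accuracy"])
        if best is None:
            raise KeyError("no matching accuracy reference")
        return best[1]

    rows = [
        {"accuracy": 83.43},  # new fallback row from this diff
        {"quant_algo": "FP8_BLOCK_SCALES", "accuracy": 84.36},
    ]
    assert pick_reference(rows, {}) == 83.43
    assert pick_reference(rows, {"quant_algo": "FP8_BLOCK_SCALES"}) == 84.36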

tests/integration/defs/accuracy/test_disaggregated_serving.py (+58, -17)

@@ -23,7 +23,7 @@
 from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
 
 from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
-                        skip_pre_hopper)
+                        skip_pre_blackwell, skip_pre_hopper)
 from ..trt_test_alternative import popen
 from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
                             LlmapiAccuracyTestHarness, get_accuracy_task)
@@ -71,7 +71,9 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
                              ctx_server_config: Dict[str, Any],
                              gen_server_config: Dict[str, Any],
                              model_name: str,
-                             tensor_parallel_size: int = 1):
+                             tensor_parallel_size: int = 1,
+                             ctx_model: str = None,
+                             gen_model: str = None):
     temp_dir = tempfile.TemporaryDirectory()
     disaggregated_serving_config_path = os.path.join(
         temp_dir.name, "disaggregated_serving_config.yaml")
@@ -97,9 +99,19 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
 
     trtllm_serve_path = "trtllm-serve"
     # Common arguments for both servers
-    common_args = [
+    ctx_model = ctx_model or model_name
+    gen_model = gen_model or model_name
+    ctx_args = [
         trtllm_serve_path,
-        model_name,
+        ctx_model,
+        "--host",
+        "localhost",
+        "--backend",
+        "pytorch",
+    ]
+    gen_args = [
+        trtllm_serve_path,
+        gen_model,
         "--host",
         "localhost",
         "--backend",
@@ -125,11 +137,11 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
         env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
         env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
             map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
-    ctx_server_args = common_args + [
+    ctx_server_args = ctx_args + [
         "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
         f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
     ]
-    gen_server_args = common_args + [
+    gen_server_args = gen_args + [
         "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
         f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
     ]
@@ -226,17 +238,21 @@ def generate_async(prompt: str,
     disaggregated_server.wait()
 
 
-def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
-                      ctx_tp: int, gen_pp: int, gen_tp: int,
-                      test_set: LlmapiAccuracyTestHarness):
+def run_parallel_test(model_name: str,
+                      model_path: str,
+                      ctx_pp: int,
+                      ctx_tp: int,
+                      gen_pp: int,
+                      gen_tp: int,
+                      test_sets: List[LlmapiAccuracyTestHarness],
+                      ctx_model: str = None,
+                      gen_model: str = None):
     if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
         pytest.fail(
             f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
         )
-
     kv_cache_config = {
         "free_gpu_memory_fraction": 0.5,
-        "enable_block_reuse": False
     }
     ctx_server_config = {
         "pipeline_parallel_size": ctx_pp,
@@ -270,10 +286,14 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
         }
     }
     with launch_disaggregated_llm(disaggregated_server_config,
-                                  ctx_server_config, gen_server_config,
-                                  model_path) as llm:
-        task = test_set(model_name)
-        task.evaluate(llm)
+                                  ctx_server_config,
+                                  gen_server_config,
+                                  model_path,
+                                  ctx_model=ctx_model,
+                                  gen_model=gen_model) as llm:
+        for test_set in test_sets:
+            task = test_set(model_name)
+            task.evaluate(llm)
 
 
 @pytest.mark.timeout(3600)
@@ -511,14 +531,14 @@ def test_guided_decoding_with_eagle3(self, backend: str, mocker):
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_tp_pp_symmetric(self, tp, pp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
-                                 tp, get_accuracy_task(testset))
+                                 tp, [get_accuracy_task(testset)])
 
     @parametrize_with_ids("ctx_pp", [2, 4])
     @parametrize_with_ids("gen_tp", [1, 2])
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
-                                 gen_tp, get_accuracy_task(testset))
+                                 gen_tp, [get_accuracy_task(testset)])
 
 
 @pytest.mark.skip_less_device_memory(140000)
@@ -702,3 +722,24 @@ def test_auto_dtype(self, overlap_scheduler):
         task.evaluate(llm)
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)
+
+
+@skip_pre_blackwell
+@pytest.mark.timeout(3600)
+class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
+    fp4_model = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_nvfp4_hf"
+    fp8_model = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_fp8_hf"
+
+    @pytest.mark.parametrize("ctxpp,gentp", [(2, 2)], ids=["ctxpp2gentp2"])
+    def test_mixed_ctx_gen_model(self, ctxpp, gentp):
+        ctx_model = self.fp4_model
+        gen_model = self.fp8_model
+        return run_parallel_test("Qwen3/Qwen3-30B-A3B",
+                                 ctx_model,
+                                 ctx_pp=ctxpp,
+                                 ctx_tp=1,
+                                 gen_pp=1,
+                                 gen_tp=gentp,
+                                 test_sets=[GSM8K, MMLU],
+                                 ctx_model=ctx_model,
+                                 gen_model=gen_model)
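
Taken together, these changes let the context and generation servers load different checkpoints while keeping existing single-model call sites working: both overrides default to model_name, and the formerly shared common_args is split into per-server argument lists. A minimal standalone sketch of that resolution (the build_serve_args helper and the /models/... paths below are illustrative, not part of the commit):

    # Sketch of the per-server argument resolution added to
    # launch_disaggregated_llm: each server gets its own checkpoint, and
    # both fall back to model_name when no override is passed.
    def build_serve_args(model_name, ctx_model=None, gen_model=None):
        ctx_model = ctx_model or model_name
        gen_model = gen_model or model_name
        common = ["--host", "localhost", "--backend", "pytorch"]
        ctx_args = ["trtllm-serve", ctx_model] + common
        gen_args = ["trtllm-serve", gen_model] + common
        return ctx_args, gen_args

    # Mixed-precision disaggregation as exercised by test_mixed_ctx_gen_model:
    # an NVFP4 checkpoint serves the context (prefill) phase and an FP8
    # checkpoint serves generation (decode); the paths are hypothetical
    # stand-ins for the llm_models_root() locations used in the test.
    ctx_args, gen_args = build_serve_args(
        "Qwen3/Qwen3-30B-A3B",
        ctx_model="/models/saved_models_Qwen3-30B-A3B_nvfp4_hf",
        gen_model="/models/saved_models_Qwen3-30B-A3B_fp8_hf",
    )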

tests/integration/test_lists/qa/llm_function_sanity.txt (+1, -0)

@@ -25,6 +25,7 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4]
 accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
+accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2]
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=FLASHINFER]
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=TRTLLM]
 accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_auto_dtype

tests/integration/test_lists/test-db/l0_dgx_b200.yml (+1, -0)

@@ -68,6 +68,7 @@ l0_dgx_b200:
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True]
 - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp4-cuda_graph=True]
 - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp4-cuda_graph=True]
+- accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
 - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-CUTLASS]
 - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-TRTLLM]
