@@ -489,6 +489,51 @@ def test_chunked_prefill(self, attn_backend):
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @skip_pre_hopper
+    @pytest.mark.skip_less_mpi_world_size(8)
+    @parametrize_with_ids("cuda_graph", [False, True])
+    @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
+                                                         (8, 1, 8)],
+                             ids=["tp8", "tp8ep4", "tp8ep8"])
+    def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size):
+        with LLM(
+                f"{llm_models_root()}/llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8",
+                tensor_parallel_size=tp_size,
+                # Keep this low to avoid warmup OOM in CI
+                max_seq_len=8192,
+                pipeline_parallel_size=pp_size,
+                moe_expert_parallel_size=ep_size,
+                use_cuda_graph=cuda_graph) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @skip_pre_hopper
+    @pytest.mark.skip_less_mpi_world_size(8)
+    @parametrize_with_ids("cuda_graph", [False, True])
+    @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 8)],
+                             ids=["tp8ep8"])
+    def test_fp8_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
+        with LLM(
+                f"{llm_models_root()}/llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8",
+                tensor_parallel_size=tp_size,
+                # Keep this low to avoid warmup OOM in CI
+                max_seq_len=8192,
+                pipeline_parallel_size=pp_size,
+                moe_expert_parallel_size=ep_size,
+                enable_chunked_prefill=True,
+                max_num_tokens=256,
+                use_cuda_graph=cuda_graph) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @skip_pre_hopper
     @pytest.mark.skip_less_mpi_world_size(8)
     @parametrize_with_ids("torch_compile", [True, False])
@@ -587,6 +632,50 @@ def test_fp4(self, cuda_graph, tp_size, pp_size, ep_size):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @skip_pre_hopper
+    @pytest.mark.skip_less_mpi_world_size(4)
+    @parametrize_with_ids("cuda_graph", [True])
+    @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 4)],
+                             ids=["tp4ep4"])
+    def test_fp8_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
+        with LLM(
+                f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct-FP8",
+                tensor_parallel_size=tp_size,
+                max_seq_len=22000,
+                pipeline_parallel_size=pp_size,
+                moe_expert_parallel_size=ep_size,
+                enable_chunked_prefill=True,
+                max_num_tokens=256,
+                use_cuda_graph=cuda_graph) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @skip_pre_blackwell
+    @pytest.mark.skip_less_mpi_world_size(8)
+    @parametrize_with_ids("cuda_graph", [True])
+    @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 4)],
+                             ids=["tp4ep4"])
+    def test_fp4_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
+        with LLM(
+                f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct-FP4",
+                tensor_parallel_size=tp_size,
+                pipeline_parallel_size=pp_size,
+                moe_expert_parallel_size=ep_size,
+                max_seq_len=22000,
+                enable_chunked_prefill=True,
+                max_num_tokens=256,
+                use_cuda_graph=cuda_graph) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+            assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestMistral7B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "mistralai/Mistral-7B-v0.1"
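
The new chunked-prefill tests pair enable_chunked_prefill=True with max_num_tokens=256 while keeping a much larger max_seq_len, so a long prompt is prefilled over several fixed-size scheduler iterations rather than in one pass. The snippet below is only a minimal illustration of that chunking arithmetic in plain Python; chunk_prompt is a made-up helper for this sketch, not a TensorRT-LLM API.

def chunk_prompt(prompt_token_ids, max_num_tokens=256):
    """Split a prompt into prefill chunks of at most max_num_tokens tokens."""
    return [
        prompt_token_ids[start:start + max_num_tokens]
        for start in range(0, len(prompt_token_ids), max_num_tokens)
    ]


# Example: a 1000-token prompt is prefilled in four chunks of 256, 256, 256, and 232 tokens.
chunks = chunk_prompt(list(range(1000)))
assert [len(c) for c in chunks] == [256, 256, 256, 232]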