                                 SamplingParams, TorchCompileConfig)
 from tensorrt_llm.quantization import QuantAlgo

-from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper,
+from ..conftest import (get_device_count, get_device_memory, llm_models_root,
+                        parametrize_with_ids, skip_no_hopper,
                         skip_post_blackwell, skip_pre_ada, skip_pre_blackwell,
                         skip_pre_hopper)
 from .accuracy_core import (GSM8K, MMLU, CnnDailymail, GPQADiamond,
@@ -509,19 +510,26 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
     MODEL_PATH = f"{llm_models_root()}

     @skip_pre_blackwell
-    @pytest.mark.skip_less_mpi_world_size(8)
     @parametrize_with_ids("cuda_graph", [False, True])
-    @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
-                                                         (8, 1, 8)],
-                             ids=["tp8", "tp8ep4", "tp8ep8"])
+    @pytest.mark.parametrize(
+        "tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4), (8, 1, 8), (4, 1, 1),
+                                    (4, 1, 2), (4, 1, 4)],
+        ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4"])
     def test_auto_dtype(self, cuda_graph, tp_size, pp_size, ep_size):
+        if get_device_memory() < 270000 and get_device_count() < 8:
+            pytest.skip("Not enough memory for this test")
+        if get_device_count() != tp_size * pp_size:
+            pytest.skip("Device count mismatch with world size")
+
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
         with LLM(
                 self.MODEL_PATH,
                 tensor_parallel_size=tp_size,
                 # Keep this low to avoid warmup OOM in CI
                 max_seq_len=8192,
                 pipeline_parallel_size=pp_size,
                 moe_expert_parallel_size=ep_size,
+                kv_cache_config=kv_cache_config,
                 cuda_graph_config=CudaGraphConfig()
                 if cuda_graph else None) as llm:
             task = MMLU(self.MODEL_NAME)
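For context on the runtime gating added above: a minimal, self-contained sketch of what the get_device_count / get_device_memory guards are doing, assuming (not confirmed here) that the conftest helpers report the number of visible GPUs and per-GPU memory in MiB; the real helpers imported from ..conftest may differ in units and implementation.

import pytest
import torch


def get_device_count() -> int:
    # Hypothetical stand-in for the conftest helper: GPUs visible to this process.
    return torch.cuda.device_count()


def get_device_memory() -> int:
    # Hypothetical stand-in: total memory of device 0, in MiB (assumed unit).
    return torch.cuda.get_device_properties(0).total_memory // (1024 * 1024)


def guard_world_size(tp_size: int, pp_size: int) -> None:
    # Skip rather than fail when the node cannot host tp_size * pp_size ranks,
    # so one parametrized test list can serve both 4-GPU and 8-GPU runners.
    if get_device_count() != tp_size * pp_size:
        pytest.skip("Device count mismatch with world size")

This mirrors the intent of the replaced skip_less_mpi_world_size(8) marker while letting the new tp4* variants run on smaller nodes.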
@@ -547,20 +555,27 @@ def test_chunked_prefill(self, attn_backend):
             task.evaluate(llm)

     @skip_pre_hopper
-    @pytest.mark.skip_less_mpi_world_size(8)
+    @pytest.mark.skip_less_device_memory(80000)
     @parametrize_with_ids("cuda_graph", [False, True])
-    @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
-                                                         (8, 1, 8)],
-                             ids=["tp8", "tp8ep4", "tp8ep8"])
+    @pytest.mark.parametrize(
+        "tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4), (8, 1, 8), (4, 1, 1),
+                                    (4, 1, 2), (4, 1, 4)],
+        ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4"])
     def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size):
+        if get_device_memory() < 140000 and get_device_count() < 8:
+            pytest.skip("Not enough memory for this test")
+        if get_device_count() != tp_size * pp_size:
+            pytest.skip("Device count mismatch with world size")
+
         with LLM(
                 f"{llm_models_root()},
                 tensor_parallel_size=tp_size,
                 # Keep this low to avoid warmup OOM in CI
                 max_seq_len=8192,
                 pipeline_parallel_size=pp_size,
                 moe_expert_parallel_size=ep_size,
-                use_cuda_graph=cuda_graph) as llm:
+                cuda_graph_config=CudaGraphConfig()
+                if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
@@ -583,7 +598,8 @@ def test_fp8_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
                 moe_expert_parallel_size=ep_size,
                 enable_chunked_prefill=True,
                 max_num_tokens=256,
-                use_cuda_graph=cuda_graph) as llm:
+                cuda_graph_config=CudaGraphConfig()
+                if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
@@ -622,16 +638,21 @@ def test_fp8_eagle3(self, tp_size, pp_size, ep_size, torch_compile):
             task.evaluate(llm)


+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.skip_less_host_memory(100000)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

     @skip_pre_hopper
-    @pytest.mark.skip_less_mpi_world_size(8)
     @parametrize_with_ids("cuda_graph", [False, True])
-    @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
-                                                         (8, 1, 8)],
-                             ids=["tp8", "tp8ep4", "tp8ep8"])
+    @pytest.mark.parametrize(
+        "tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4), (8, 1, 8), (4, 1, 1),
+                                    (4, 1, 2), (4, 1, 4)],
+        ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4"])
     def test_auto_dtype(self, cuda_graph, tp_size, pp_size, ep_size):
+        if get_device_count() != tp_size * pp_size:
+            pytest.skip("Device count mismatch with world size")
+
         model_path = f"{llm_models_root()}
         with LLM(
                 model_path,
@@ -648,11 +669,13 @@ def test_auto_dtype(self, cuda_graph, tp_size, pp_size, ep_size):
             task.evaluate(llm)

     @skip_pre_hopper
-    @pytest.mark.skip_less_mpi_world_size(8)
     @parametrize_with_ids("cuda_graph", [True])
     @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 8), (4, 1, 1)],
                              ids=["tp8ep8", "tp4"])
     def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size):
+        if get_device_count() != tp_size * pp_size:
+            pytest.skip("Device count mismatch with world size")
+
         model_path = f"{llm_models_root()}
         with LLM(
                 model_path,
@@ -661,6 +684,7 @@ def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size):
                 max_seq_len=8192,
                 pipeline_parallel_size=pp_size,
                 moe_expert_parallel_size=ep_size,
+                kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.8),
                 cuda_graph_config=CudaGraphConfig()
                 if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
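And a sketch of the KvCacheConfig addition, which caps the KV-cache pool at 80% of free GPU memory so warmup and activation buffers keep some headroom; the import path and the standalone constructor call are assumptions based on how the diff uses these names.

from tensorrt_llm.llmapi import KvCacheConfig

# Reserve 80% of the currently free GPU memory for the KV-cache pool,
# leaving the remainder for activations and CUDA-graph warmup.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)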
@@ -670,11 +694,13 @@ def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size):
             task.evaluate(llm)

     @skip_pre_blackwell
-    @pytest.mark.skip_less_mpi_world_size(8)
     @parametrize_with_ids("cuda_graph", [True])
     @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 8), (4, 1, 1)],
                              ids=["tp8ep8", "tp4"])
     def test_fp4(self, cuda_graph, tp_size, pp_size, ep_size):
+        if get_device_count() != tp_size * pp_size:
+            pytest.skip("Device count mismatch with world size")
+
         model_path = f"{llm_models_root()}
         with LLM(
                 model_path,
@@ -706,7 +732,8 @@ def test_fp8_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
                 moe_expert_parallel_size=ep_size,
                 enable_chunked_prefill=True,
                 max_num_tokens=256,
-                use_cuda_graph=cuda_graph) as llm:
+                cuda_graph_config=CudaGraphConfig()
+                if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
@@ -715,7 +742,7 @@ def test_fp8_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
             task.evaluate(llm)

     @skip_pre_blackwell
-    @pytest.mark.skip_less_mpi_world_size(8)
+    @pytest.mark.skip_less_mpi_world_size(4)
     @parametrize_with_ids("cuda_graph", [True])
     @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 4)],
                              ids=["tp4ep4"])
@@ -728,7 +755,8 @@ def test_fp4_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
                 max_seq_len=22000,
                 enable_chunked_prefill=True,
                 max_num_tokens=256,
-                use_cuda_graph=cuda_graph) as llm:
+                cuda_graph_config=CudaGraphConfig()
+                if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)