@@ -537,6 +537,7 @@ async def benchmark(
537
537
ignore_eos : bool ,
538
538
goodput_config_dict : Dict [str , float ],
539
539
max_concurrency : Optional [int ],
540
+ lora_modules : Optional [List [str ]],
540
541
):
541
542
if backend in ASYNC_REQUEST_FUNCS :
542
543
request_func = ASYNC_REQUEST_FUNCS [backend ]
@@ -562,6 +563,7 @@ async def benchmark(
562
563
multi_modal_content = test_mm_content ,
563
564
ignore_eos = ignore_eos ,
564
565
)
566
+
565
567
test_output = await request_func (request_func_input = test_input )
566
568
if not test_output .success :
567
569
raise ValueError (
@@ -570,6 +572,11 @@ async def benchmark(
570
572
else :
571
573
print ("Initial test run completed. Starting main benchmark run..." )
572
574
575
+ if lora_modules :
576
+ # For each input request, choose a LoRA module at random.
577
+ lora_modules = iter (
578
+ [random .choice (lora_modules ) for _ in range (len (input_requests ))])
579
+
573
580
if profile :
574
581
print ("Starting profiler..." )
575
582
profile_input = RequestFuncInput (model = model_id ,
@@ -616,8 +623,13 @@ async def limited_request_func(request_func_input, pbar):
616
623
tasks : List [asyncio .Task ] = []
617
624
async for request in get_request (input_requests , request_rate , burstiness ):
618
625
prompt , prompt_len , output_len , mm_content = request
619
- request_func_input = RequestFuncInput (model = model_id ,
620
- model_name = model_name ,
626
+ req_model_id , req_model_name = model_id , model_name
627
+ if lora_modules :
628
+ req_lora_module = next (lora_modules )
629
+ req_model_id , req_model_name = req_lora_module , req_lora_module
630
+
631
+ request_func_input = RequestFuncInput (model = req_model_id ,
632
+ model_name = req_model_name ,
621
633
prompt = prompt ,
622
634
api_url = api_url ,
623
635
prompt_len = prompt_len ,
@@ -900,6 +912,7 @@ def main(args: argparse.Namespace):
900
912
ignore_eos = args .ignore_eos ,
901
913
goodput_config_dict = goodput_config_dict ,
902
914
max_concurrency = args .max_concurrency ,
915
+ lora_modules = args .lora_modules ,
903
916
))
904
917
905
918
# Save config and results to json
@@ -1237,5 +1250,12 @@ def main(args: argparse.Namespace):
1237
1250
"If not specified, the model name will be the "
1238
1251
"same as the ``--model`` argument. " )
1239
1252
1253
+ parser .add_argument ("--lora-modules" ,
1254
+ nargs = '+' ,
1255
+ default = None ,
1256
+ help = "A subset of LoRA module names passed in when "
1257
+ "launching the server. For each request, the "
1258
+ "script chooses a LoRA module at random." )
1259
+
1240
1260
args = parser .parse_args ()
1241
1261
main (args )
0 commit comments