@@ -1180,17 +1180,10 @@ def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
-        if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
-            while self.executor_request_queue.get_request_queue_size(
-            ) < self.benchmark_req_queues_size:
-                logger.info(
-                    f"sleep 5 seconds, num_request_queue: {self.executor_request_queue.get_request_queue_size()}"
-                )
-                time.sleep(5)
-
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
+            can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
                 profile_step()
                 if self.enable_iter_perf_stats:
@@ -1199,6 +1192,36 @@ def _executor_loop_overlap(self):
                 scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
                 if scheduled_batch is None:
                     break
+                # In gen-only benchmarking mode, wait until the number of scheduled generation
+                # requests reaches the required threshold before starting the forward pass,
+                # to ensure consistent batch sizes for accurate performance measurement.
+                if not self.is_warmup and not can_forward:
+                    if self.enable_attention_dp:
+                        local_can_forward = self.executor_request_queue.num_fetch_requests + \
+                            len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size
+                        all_can_forward = self.dist.tp_allgather(
+                            local_can_forward)
+                        if all(all_can_forward):
+                            can_forward = True
+                            time.sleep(10)
+                        else:
+                            if self.dist.rank == 0:
+                                logger.info(
+                                    f"sleep 10 seconds, num_fetched_requests: {self.executor_request_queue.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}"
+                                )
+                            time.sleep(10)
+                            continue
+                    else:
+                        if len(scheduled_batch.generation_requests
+                               ) < self.benchmark_req_queues_size:
+                            if self.dist.rank == 0:
+                                logger.info(
+                                    f"sleep 10 seconds, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}"
+                                )
+                            time.sleep(10)
+                            continue
+                        else:
+                            can_forward = True

                 self._pause_requests(scheduled_batch.paused_requests)
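The gating logic in this hunk reduces to a simple pattern: each rank checks whether it has accumulated enough generation requests, the per-rank verdicts are combined with a tensor-parallel all-gather, and no rank runs a forward pass until all agree. Below is a minimal, self-contained sketch of that pattern, assuming a stand-in `FakeTPGroup` in place of the real `self.dist` communicator and a `poll_interval` parameter in place of the hard-coded 10-second sleep; these names are illustrative and not part of the TensorRT-LLM API.

```python
import time
from typing import List


class FakeTPGroup:
    """Stand-in for the tensor-parallel communicator (illustrative only)."""

    def __init__(self, other_rank_flags: List[bool]):
        self.other_rank_flags = other_rank_flags

    def tp_allgather(self, local_flag: bool) -> List[bool]:
        # The real communicator gathers one flag per TP rank; here we just
        # combine a pre-seeded list of "other ranks" with the local value.
        return self.other_rank_flags + [local_flag]


def can_start_forward(num_fetched_requests: int,
                      num_scheduled_gen_requests: int,
                      threshold: int,
                      tp_group: FakeTPGroup,
                      poll_interval: float = 10.0) -> bool:
    """Return True once every rank has reached the request threshold.

    Otherwise sleep for `poll_interval` seconds and tell the caller to
    re-schedule and try again on the next loop iteration.
    """
    local_can_forward = (num_fetched_requests +
                         num_scheduled_gen_requests) >= threshold
    all_can_forward = tp_group.tp_allgather(local_can_forward)
    if all(all_can_forward):
        return True
    time.sleep(poll_interval)
    return False


# Example: two other ranks are ready, and the local rank has 48 fetched plus
# 16 scheduled generation requests against a threshold of 64, so forwarding starts.
assert can_start_forward(48, 16, 64, FakeTPGroup([True, True]), poll_interval=0)
```

Note that the actual change also sleeps once more after the threshold is met (presumably to let in-flight transfers settle before the first measured iteration), and when attention data parallelism is disabled it checks only `len(scheduled_batch.generation_requests)` against the threshold.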
12041227