Skip to content

Commit c8f98b3

Browse files
authored
[None] [feat] Update disagg gen-only benchmark. (#7917)
Signed-off-by: Xianjie <[email protected]>
1 parent 3328235 commit c8f98b3

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 31 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1180,17 +1180,10 @@ def _executor_loop_overlap(self):
11801180
torch.cuda.set_device(self.device_id)
11811181
# ensure the context is created, otherwise, some MPI calls will fail.
11821182
CUASSERT(cudart.cudaSetDevice(self.device_id))
1183-
if self.dist.rank == 0 and not self.is_warmup and self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver:
1184-
while self.executor_request_queue.get_request_queue_size(
1185-
) < self.benchmark_req_queues_size:
1186-
logger.info(
1187-
f"sleep 5 seconds, num_request_queue: {self.executor_request_queue.get_request_queue_size()}"
1188-
)
1189-
time.sleep(5)
1190-
11911183
with self._profiler() as profile_step:
11921184
iter_start_time = time.time()
11931185
iter_stats = None
1186+
can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
11941187
while True:
11951188
profile_step()
11961189
if self.enable_iter_perf_stats:
@@ -1199,6 +1192,36 @@ def _executor_loop_overlap(self):
11991192
scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
12001193
if scheduled_batch is None:
12011194
break
1195+
# In gen-only benchmarking mode, wait until the number of scheduled generation
1196+
# requests reaches the required threshold before starting forward pass,
1197+
# to ensure consistent batch sizes for accurate performance measurement.
1198+
if not self.is_warmup and not can_forward:
1199+
if self.enable_attention_dp:
1200+
local_can_forward = self.executor_request_queue.num_fetch_requests + \
1201+
len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size
1202+
all_can_forward = self.dist.tp_allgather(
1203+
local_can_forward)
1204+
if all(all_can_forward):
1205+
can_forward = True
1206+
time.sleep(10)
1207+
else:
1208+
if self.dist.rank == 0:
1209+
logger.info(
1210+
f"sleep 10 seconds, num_fetched_requests: {self.executor_request_queue.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}"
1211+
)
1212+
time.sleep(10)
1213+
continue
1214+
else:
1215+
if len(scheduled_batch.generation_requests
1216+
) < self.benchmark_req_queues_size:
1217+
if self.dist.rank == 0:
1218+
logger.info(
1219+
f"sleep 10 seconds, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}"
1220+
)
1221+
time.sleep(10)
1222+
continue
1223+
else:
1224+
can_forward = True
12021225

12031226
self._pause_requests(scheduled_batch.paused_requests)
12041227

0 commit comments

Comments (0)