
Commit 953f4fd

[None][fix] acceptance rate calculation fix in benchmark_serving (#6746)
Signed-off-by: Zero Zeng <[email protected]>
1 parent 2c86cee commit 953f4fd

File tree

10 files changed: +94 −81 lines changed

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 4 additions & 0 deletions

@@ -860,6 +860,10 @@ struct Result
     /// one token can be generated per iteration. Used for speculative decoding statistics.
     SizeType32 decodingIter{0};

+    /// @brief The average number of decoded tokens per iteration. For standard model it is 1.
+    /// For speculative decoding model >= 1 -- number of draft tokens accepted per step + 1.
+    float avgDecodedTokensPerIter{0.0f};
+
     /// @brief The index of the output sequence of this result where 0 <= sequenceIndex < numReturnSequences.
     /// In beam search (beamWidth > 1), this index will be always zero because all beams to be returned are included
     /// in this result.
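
The new field is what the benchmark-side acceptance-rate calculation builds on: it reports accepted draft tokens per step plus one. A minimal sketch of recovering an acceptance rate from it, assuming the caller knows the draft length configured at serving time (draft_len is a hypothetical parameter, not part of this commit):

def acceptance_rate(avg_decoded_tokens_per_iter: float, draft_len: int) -> float:
    # avg_decoded_tokens_per_iter = accepted draft tokens per step + 1 target token,
    # so subtracting 1 isolates the accepted drafts.
    if draft_len <= 0:
        return 0.0
    return max(0.0, avg_decoded_tokens_per_iter - 1.0) / draft_len

# Example: 2.5 decoded tokens per step with 3 draft tokens proposed -> 0.5 acceptance.
print(acceptance_rate(2.5, 3))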

cpp/tensorrt_llm/batch_manager/llmRequest.cpp

Lines changed: 1 addition & 0 deletions

@@ -200,6 +200,7 @@ std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int

     result.finishReasons = sliceBeams(mFinishReasons);
     result.decodingIter = mDecodingIter;
+    result.avgDecodedTokensPerIter = getAvgDecodedTokensPerIter();

     if (hasAdditionalOutputs())
     {

cpp/tensorrt_llm/executor/serialization.cpp

Lines changed: 3 additions & 0 deletions

@@ -895,6 +895,7 @@ Result Serialization::deserializeResult(std::istream& is)
     result.finishReasons = su::deserialize<std::vector<FinishReason>>(is);
     result.contextPhaseParams = su::deserialize<std::optional<ContextPhaseParams>>(is);
     result.decodingIter = su::deserialize<SizeType32>(is);
+    result.avgDecodedTokensPerIter = su::deserialize<float>(is);
     result.sequenceIndex = su::deserialize<SizeType32>(is);
     result.isSequenceFinal = su::deserialize<bool>(is);
     result.requestPerfMetrics = su::deserialize<std::optional<RequestPerfMetrics>>(is);
@@ -915,6 +916,7 @@ void Serialization::serialize(Result const& result, std::ostream& os)
     su::serialize(result.finishReasons, os);
     su::serialize(result.contextPhaseParams, os);
     su::serialize(result.decodingIter, os);
+    su::serialize(result.avgDecodedTokensPerIter, os);
     su::serialize(result.sequenceIndex, os);
     su::serialize(result.isSequenceFinal, os);
     su::serialize(result.requestPerfMetrics, os);
@@ -935,6 +937,7 @@ size_t Serialization::serializedSize(Result const& result)
     totalSize += su::serializedSize(result.finishReasons);
     totalSize += su::serializedSize(result.contextPhaseParams);
     totalSize += su::serializedSize(result.decodingIter);
+    totalSize += su::serializedSize(result.avgDecodedTokensPerIter);
     totalSize += su::serializedSize(result.sequenceIndex);
     totalSize += su::serializedSize(result.isSequenceFinal);
     totalSize += su::serializedSize(result.requestPerfMetrics);

cpp/tensorrt_llm/nanobind/executor/request.cpp

Lines changed: 6 additions & 4 deletions

@@ -851,7 +851,7 @@ void initRequestBindings(nb::module_& m)

    auto resultSetstate = [](tle::Result& self, nb::tuple const& state)
    {
-        if (state.size() != 13)
+        if (state.size() != 14)
        {
            throw std::runtime_error("Invalid Request state!");
        }
@@ -867,16 +867,17 @@ void initRequestBindings(nb::module_& m)
        result.sequenceIndex = nb::cast<SizeType32>(state[8]);
        result.isSequenceFinal = nb::cast<bool>(state[9]);
        result.decodingIter = nb::cast<SizeType32>(state[10]);
-        result.contextPhaseParams = nb::cast<std::optional<tle::ContextPhaseParams>>(state[11]);
-        result.requestPerfMetrics = nb::cast<std::optional<tle::RequestPerfMetrics>>(state[12]);
+        result.avgDecodedTokensPerIter = nb::cast<float>(state[11]);
+        result.contextPhaseParams = nb::cast<std::optional<tle::ContextPhaseParams>>(state[12]);
+        result.requestPerfMetrics = nb::cast<std::optional<tle::RequestPerfMetrics>>(state[13]);
        new (&self) tle::Result(result);
    };

    auto resultGetstate = [](tle::Result const& self)
    {
        return nb::make_tuple(self.isFinal, self.outputTokenIds, self.cumLogProbs, self.logProbs, self.contextLogits,
            self.generationLogits, self.encoderOutput, self.finishReasons, self.sequenceIndex, self.isSequenceFinal,
-            self.decodingIter, self.contextPhaseParams, self.requestPerfMetrics);
+            self.decodingIter, self.avgDecodedTokensPerIter, self.contextPhaseParams, self.requestPerfMetrics);
    };

    nb::class_<tle::Result>(m, "Result")
@@ -893,6 +894,7 @@ void initRequestBindings(nb::module_& m)
        .def_rw("sequence_index", &tle::Result::sequenceIndex)
        .def_rw("is_sequence_final", &tle::Result::isSequenceFinal)
        .def_rw("decoding_iter", &tle::Result::decodingIter)
+        .def_rw("avg_decoded_tokens_per_iter", &tle::Result::avgDecodedTokensPerIter)
        .def_rw("context_phase_params", &tle::Result::contextPhaseParams)
        .def_rw("request_perf_metrics", &tle::Result::requestPerfMetrics)
        .def_rw("additional_outputs", &tle::Result::additionalOutputs)

cpp/tensorrt_llm/pybind/executor/request.cpp

Lines changed: 6 additions & 4 deletions

@@ -795,7 +795,7 @@ void initRequestBindings(pybind11::module_& m)

    auto resultSetstate = [](py::tuple const& state)
    {
-        if (state.size() != 13)
+        if (state.size() != 14)
        {
            throw std::runtime_error("Invalid Request state!");
        }
@@ -811,16 +811,17 @@ void initRequestBindings(pybind11::module_& m)
        result.sequenceIndex = state[8].cast<SizeType32>();
        result.isSequenceFinal = state[9].cast<bool>();
        result.decodingIter = state[10].cast<SizeType32>();
-        result.contextPhaseParams = state[11].cast<std::optional<tle::ContextPhaseParams>>();
-        result.requestPerfMetrics = state[12].cast<std::optional<tle::RequestPerfMetrics>>();
+        result.avgDecodedTokensPerIter = state[11].cast<float>();
+        result.contextPhaseParams = state[12].cast<std::optional<tle::ContextPhaseParams>>();
+        result.requestPerfMetrics = state[13].cast<std::optional<tle::RequestPerfMetrics>>();
        return std::make_unique<tle::Result>(result);
    };

    auto resultGetstate = [](tle::Result const& self)
    {
        return py::make_tuple(self.isFinal, self.outputTokenIds, self.cumLogProbs, self.logProbs, self.contextLogits,
            self.generationLogits, self.encoderOutput, self.finishReasons, self.sequenceIndex, self.isSequenceFinal,
-            self.decodingIter, self.contextPhaseParams, self.requestPerfMetrics);
+            self.decodingIter, self.avgDecodedTokensPerIter, self.contextPhaseParams, self.requestPerfMetrics);
    };

    py::class_<tle::Result>(m, "Result")
@@ -837,6 +838,7 @@ void initRequestBindings(pybind11::module_& m)
        .def_readwrite("sequence_index", &tle::Result::sequenceIndex)
        .def_readwrite("is_sequence_final", &tle::Result::isSequenceFinal)
        .def_readwrite("decoding_iter", &tle::Result::decodingIter)
+        .def_readwrite("avg_decoded_tokens_per_iter", &tle::Result::avgDecodedTokensPerIter)
        .def_readwrite("context_phase_params", &tle::Result::contextPhaseParams)
        .def_readwrite("request_perf_metrics", &tle::Result::requestPerfMetrics)
        .def_readwrite("additional_outputs", &tle::Result::additionalOutputs)

tensorrt_llm/executor/result.py

Lines changed: 4 additions & 0 deletions

@@ -158,6 +158,9 @@ def __init__(self,
         self.postproc_params = postproc_params
         self.disaggregated_params = None
         self.decoding_iter = 0
+        # Average decoded tokens per runtime iteration; set when the first LLM response arrives.
+        # None indicates not yet available (e.g., before first step/stream).
+        self.avg_decoded_tokens_per_iter: Optional[float] = None
         self._done = False
         self.metrics_dict = {}

@@ -331,6 +334,7 @@ def _handle_response(self,
            self._done = response_result.is_final
            context_phase_params = response_result.context_phase_params
            self.decoding_iter = response_result.decoding_iter
+            self.avg_decoded_tokens_per_iter = response_result.avg_decoded_tokens_per_iter
            if context_phase_params is not None:
                self.disaggregated_params = DisaggregatedParams(
                    request_type="context_only",
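
With the attribute wired into the Python result object, callers can read it off each finished generation. A minimal sketch, assuming the high-level tensorrt_llm.LLM entry point and a hypothetical model path; the value stays None when the runtime never reports it:

from tensorrt_llm import LLM

llm = LLM(model="/path/to/model")  # hypothetical path
for result in llm.generate(["Explain speculative decoding in one sentence."]):
    # Populated from the executor Result in _handle_response above;
    # values above 1.0 indicate draft tokens are being accepted.
    print(result.avg_decoded_tokens_per_iter)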

tensorrt_llm/serve/openai_protocol.py

Lines changed: 4 additions & 0 deletions

@@ -128,6 +128,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
         "including encountering the EOS token"),
     )
     disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
+    avg_decoded_tokens_per_iter: Optional[float] = Field(default=None)


 class CompletionResponse(OpenAIBaseModel):
@@ -155,6 +156,7 @@ class CompletionResponseStreamChoice(OpenAIBaseModel):
         "to stop, None if the completion finished for some other reason "
         "including encountering the EOS token"),
     )
+    avg_decoded_tokens_per_iter: Optional[float] = Field(default=None)


 class CompletionStreamResponse(OpenAIBaseModel):
@@ -392,6 +394,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
     stop_reason: Optional[Union[int, str]] = None

     disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
+    avg_decoded_tokens_per_iter: Optional[float] = Field(default=None)


 class ChatCompletionResponse(OpenAIBaseModel):
@@ -419,6 +422,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     logprobs: Optional[ChatCompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None
+    avg_decoded_tokens_per_iter: Optional[float] = Field(default=None)


 class ChatCompletionStreamResponse(OpenAIBaseModel):
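
On the wire, the field appears as an optional number on every completion and chat choice, streaming or not. A small client-side sketch against a trtllm-serve endpoint (host, port, and model name are assumptions):

import requests

resp = requests.post("http://localhost:8000/v1/completions",
                     json={"model": "my-model", "prompt": "Hello", "max_tokens": 16})
choice = resp.json()["choices"][0]
# null/None when the runtime does not report the metric (e.g., no speculative decoding).
print(choice.get("avg_decoded_tokens_per_iter"))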

tensorrt_llm/serve/postprocess_handlers.py

Lines changed: 5 additions & 1 deletion

@@ -160,7 +160,8 @@ def yield_first_chat(num_tokens: int,

        choice = ChatCompletionResponseStreamChoice(index=i,
                                                    delta=delta_message,
-                                                    finish_reason=None)
+                                                    finish_reason=None,
+                                                    avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None))
        if args.return_logprobs:
            logprobs = output.logprobs_diff
            token_ids = output.token_ids_diff
@@ -224,6 +225,7 @@ def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocAr
            finish_reason=output.finish_reason,
            stop_reason=output.stop_reason,
            disaggregated_params=disaggregated_params,
+            avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
        )

        if args.return_logprobs:
@@ -293,6 +295,7 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
            token_ids=None if args.detokenize else output.token_ids_diff,
            finish_reason = output.finish_reason,
            stop_reason = output.stop_reason,
+            avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
        )
        chunk = CompletionStreamResponse(model=args.model, choices=[choice])
        if include_continuous_usage:
@@ -337,6 +340,7 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo
            context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(),
            stop_reason=output.stop_reason,
            finish_reason=output.finish_reason,
+            avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
        )

        completion_tokens += output.length

tensorrt_llm/serve/scripts/backend_request_func.py

Lines changed: 36 additions & 22 deletions

@@ -45,7 +45,7 @@ class RequestFuncOutput:
    tpot: float = 0.0  # avg next-token latencies
    prompt_len: int = 0
    error: str = ""
-    decode_iteration: int = 0  # Number of decoding iterations
+    avg_decoded_tokens_per_iter: float = 0.0  # Average tokens decoded per iteration


 async def async_request_trt_llm(
@@ -82,7 +82,6 @@ async def async_request_trt_llm(
    ttft = 0.0
    st = time.perf_counter()
    most_recent_timestamp = st
-    decode_iteration_count = 0  # Track decoding iterations
    try:
        async with request_session.post(url=api_url, json=payload) as response:
            if response.status == 200:
@@ -108,22 +107,27 @@
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

-                            # Increment decode iteration for each chunk
-                            decode_iteration_count += 1
                        most_recent_timestamp = timestamp

+                        # Extract avg_decoded_tokens_per_iter from TensorRT-LLM response
+                        if "avg_decoded_tokens_per_iter" in data:
+                            output.avg_decoded_tokens_per_iter = data[
+                                "avg_decoded_tokens_per_iter"]
+
                    output.latency = most_recent_timestamp - st
-                    output.decode_iteration = decode_iteration_count
+
                else:
                    content = await response.content.read()
                    data = json.loads(content.decode())
                    output.ttft = -1
                    output.itl = []
                    output.generated_text = data["text_output"]
                    output.latency = time.perf_counter() - st
-                    # For non-streaming, estimate decode_iteration as number of output tokens
-                    output.decode_iteration = len(output.generated_text.split(
-                    )) if output.generated_text else 1
+
+                    # Extract avg_decoded_tokens_per_iter from non-streaming TensorRT-LLM response
+                    if "avg_decoded_tokens_per_iter" in data:
+                        output.avg_decoded_tokens_per_iter = data[
+                            "avg_decoded_tokens_per_iter"]

            else:
                output.error = response.reason or ""
@@ -138,6 +142,7 @@

    if pbar:
        pbar.update(1)
+
    return output


@@ -183,7 +188,6 @@ async def async_request_openai_completions(
    generated_text = ""
    st = time.perf_counter()
    most_recent_timestamp = st
-    decode_iteration_count = 0  # Track decoding iterations
    try:
        async with request_session.post(url=api_url,
                                        json=payload,
@@ -220,11 +224,13 @@
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

-                            # Increment decode iteration for each chunk with text
-                            if text is not None:
-                                decode_iteration_count += 1
                        most_recent_timestamp = timestamp
                        generated_text += text or ""
+
+                        # Extract avg_decoded_tokens_per_iter from streaming response
+                        if "avg_decoded_tokens_per_iter" in choices[0]:
+                            output.avg_decoded_tokens_per_iter = choices[
+                                0]["avg_decoded_tokens_per_iter"]
                    elif usage := data.get("usage"):
                        output.output_tokens = usage.get(
                            "completion_tokens")
@@ -237,7 +243,6 @@
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
-                    output.decode_iteration = decode_iteration_count
                else:
                    content = await response.content.read()
                    data = json.loads(content.decode())
@@ -248,8 +253,11 @@
                    output.ttft = -1
                    output.itl = []
                    output.output_tokens = data["usage"]["completion_tokens"]
-                    # For non-streaming, estimate decode_iteration as number of output tokens
-                    output.decode_iteration = output.output_tokens if output.output_tokens > 0 else 1
+                    # Extract avg_decoded_tokens_per_iter if available
+                    choice = data["choices"][0]
+                    if "avg_decoded_tokens_per_iter" in choice:
+                        output.avg_decoded_tokens_per_iter = choice[
+                            "avg_decoded_tokens_per_iter"]
            else:
                output.error = response.reason or ""
                output.success = False
@@ -263,6 +271,7 @@

    if pbar:
        pbar.update(1)
+
    return output


@@ -322,7 +331,6 @@ async def async_request_openai_chat_completions(
    ttft = 0.0
    st = time.perf_counter()
    most_recent_timestamp = st
-    decode_iteration_count = 0  # Track decoding iterations
    try:
        async with request_session.post(url=api_url,
                                        json=payload,
@@ -353,10 +361,12 @@
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

-                            # Increment decode iteration for each chunk with content
-                            if content is not None:
-                                decode_iteration_count += 1
                        generated_text += content or ""
+
+                        # Extract avg_decoded_tokens_per_iter from streaming chat response
+                        if "avg_decoded_tokens_per_iter" in choices[0]:
+                            output.avg_decoded_tokens_per_iter = choices[
+                                0]["avg_decoded_tokens_per_iter"]
                    elif usage := data.get("usage"):
                        output.output_tokens = usage.get(
                            "completion_tokens")
@@ -365,7 +375,6 @@

                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
-                    output.decode_iteration = decode_iteration_count
                else:
                    content = await response.content.read()
                    data = json.loads(content.decode())
@@ -375,8 +384,12 @@
                    output.itl = []
                    output.latency = time.perf_counter() - st
                    output.ttft = -1
-                    # For non-streaming, estimate decode_iteration as number of output tokens
-                    output.decode_iteration = output.output_tokens if output.output_tokens > 0 else 1
+
+                    # Extract avg_decoded_tokens_per_iter if available
+                    choice = data["choices"][0]
+                    if "avg_decoded_tokens_per_iter" in choice:
+                        output.avg_decoded_tokens_per_iter = choice[
+                            "avg_decoded_tokens_per_iter"]

            else:
                output.error = response.reason or ""
@@ -391,6 +404,7 @@

    if pbar:
        pbar.update(1)
+
    return output