Commit 14f8cdd

[Feature] add mm token usage (#4570)
* add mm token usage
* fix unit test
* fix unit test
* fix unit test
* fix model path
* fix unit test
* fix unit test
* fix unit test
* remove uncomment
* change var name
* fix code style
* fix code style
* fix code style
* fix code style
* fix unit test
1 parent fc5cd1a commit 14f8cdd

File tree

9 files changed: +70 −20 lines changed
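At a glance, the change threads two new counters (num_input_image_tokens, num_input_video_tokens) from the multimodal input processors through the token processor and RequestOutput into the OpenAI-compatible usage block. A hedged sketch of the resulting usage payload; all counts are invented for illustration, and the nesting follows the serving_chat.py diff below:

# Hypothetical usage block of a chat completion response after this change;
# counts are made up, field names follow the serving_chat.py diff below.
usage = {
    "prompt_tokens": 1320,
    "completion_tokens": 42,
    "total_tokens": 1362,
    "prompt_tokens_details": {
        "cached_tokens": 256,
        "image_tokens": 900,  # from num_input_image_tokens
        "video_tokens": 0,    # from num_input_video_tokens
    },
    "completion_tokens_details": {"reasoning_tokens": 0},
}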

fastdeploy/engine/request.py

Lines changed: 10 additions & 0 deletions
@@ -447,6 +447,8 @@ class RequestOutput:
         encoder_prompt_token_ids: The token IDs of the encoder prompt.
                                   None if decoder-only.
         num_cached_tokens: The number of tokens with prefix cache hit.
+        num_input_image_tokens: The number of input image tokens.
+        num_input_video_tokens: The number of input video tokens.
     """

     def __init__(
@@ -459,6 +461,8 @@ def __init__(
         finished: bool = False,
         metrics: Optional[RequestMetrics] = None,
         num_cached_tokens: Optional[int] = 0,
+        num_input_image_tokens: Optional[int] = 0,
+        num_input_video_tokens: Optional[int] = 0,
         error_code: Optional[int] = 200,
         error_msg: Optional[str] = None,
     ) -> None:
@@ -470,6 +474,8 @@ def __init__(
         self.finished = finished
         self.metrics = metrics
         self.num_cached_tokens = num_cached_tokens
+        self.num_input_image_tokens = num_input_image_tokens
+        self.num_input_video_tokens = num_input_video_tokens
         self.error_code = error_code
         self.error_msg = error_msg

@@ -512,6 +518,8 @@ def __repr__(self) -> str:
             f"outputs={self.outputs}, "
             f"finished={self.finished}, "
             f"num_cached_tokens={self.num_cached_tokens}, "
+            f"num_input_image_tokens={self.num_input_image_tokens}, "
+            f"num_input_video_tokens={self.num_input_video_tokens}, "
             f"metrics={self.metrics}, "
         )

@@ -534,6 +542,8 @@ def to_dict(self):
             "metrics": None if self.metrics is None else self.metrics.to_dict(),
             "finished": self.finished,
             "num_cached_tokens": self.num_cached_tokens,
+            "num_input_image_tokens": self.num_input_image_tokens,
+            "num_input_video_tokens": self.num_input_video_tokens,
             "error_code": self.error_code,
             "error_msg": self.error_msg,
         }
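A minimal round-trip sketch of the new fields. This assumes request_id is an accepted keyword and that the remaining constructor arguments may be omitted, which the hunks above do not show:

# Sketch only: constructor arguments other than those visible in the diff
# are assumed optional here; request_id is a hypothetical value.
out = RequestOutput(
    request_id="req_0",
    num_cached_tokens=128,
    num_input_image_tokens=900,
    num_input_video_tokens=0,
)
d = out.to_dict()
assert d["num_input_image_tokens"] == 900
assert d["num_input_video_tokens"] == 0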

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 32 additions & 13 deletions
@@ -276,6 +276,8 @@ async def chat_completion_stream_generator(
             if first_iteration:
                 num_prompt_tokens = len(prompt_token_ids)
                 num_cached_tokens = res.get("num_cached_tokens", 0)
+                num_input_image_tokens = res.get("num_input_image_tokens", 0)
+                num_input_video_tokens = res.get("num_input_video_tokens", 0)
                 for i in range(num_choices):
                     choice = ChatCompletionResponseStreamChoice(
                         index=i,
@@ -312,7 +314,11 @@ async def chat_completion_stream_generator(
                         prompt_tokens=num_prompt_tokens,
                         completion_tokens=0,
                         total_tokens=num_prompt_tokens,
-                        prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens),
+                        prompt_tokens_details=PromptTokenUsageInfo(
+                            cached_tokens=num_cached_tokens,
+                            image_tokens=num_input_image_tokens,
+                            video_tokens=num_input_video_tokens,
+                        ),
                         completion_tokens_details=CompletionTokenUsageInfo(reasoning_tokens=0),
                     )
                     yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
@@ -476,6 +482,8 @@ async def chat_completion_full_generator(
         draft_logprob_contents = [[] for _ in range(num_choices)]
         completion_token_ids = [[] for _ in range(num_choices)]
         num_cached_tokens = [0] * num_choices
+        num_input_image_tokens = [0] * num_choices
+        num_input_video_tokens = [0] * num_choices
         num_image_tokens = [0] * num_choices
         response_processor = ChatResponseProcessor(
             data_processor=self.engine_client.data_processor,
@@ -546,14 +554,15 @@ async def chat_completion_full_generator(
                     previous_num_tokens[idx] += data["outputs"].get("image_token_num")
                     num_image_tokens[idx] = data["outputs"].get("image_token_num")
                 choice = await self._create_chat_completion_choice(
-                    output=output,
-                    index=idx,
+                    data=data,
                     request=request,
-                    previous_num_tokens=previous_num_tokens[idx],
                     prompt_token_ids=prompt_token_ids,
                     prompt_tokens=prompt_tokens,
                     completion_token_ids=completion_token_ids[idx],
+                    previous_num_tokens=previous_num_tokens[idx],
                     num_cached_tokens=num_cached_tokens,
+                    num_input_image_tokens=num_input_image_tokens,
+                    num_input_video_tokens=num_input_video_tokens,
                     num_image_tokens=num_image_tokens,
                     logprob_contents=logprob_contents,
                     response_processor=response_processor,
@@ -571,11 +580,16 @@ async def chat_completion_full_generator(
             prompt_tokens=num_prompt_tokens,
             completion_tokens=num_generated_tokens,
             total_tokens=num_prompt_tokens + num_generated_tokens,
-            prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=sum(num_cached_tokens)),
+            prompt_tokens_details=PromptTokenUsageInfo(
+                cached_tokens=sum(num_cached_tokens),
+                image_tokens=sum(num_input_image_tokens),
+                video_tokens=sum(num_input_video_tokens),
+            ),
             completion_tokens_details=CompletionTokenUsageInfo(
                 reasoning_tokens=num_reasoning_tokens, image_tokens=sum(num_image_tokens)
             ),
         )
+
         choices = sorted(choices, key=lambda x: x.index)
         res = ChatCompletionResponse(
             id=request_id,
@@ -589,18 +603,21 @@ async def chat_completion_full_generator(

     async def _create_chat_completion_choice(
         self,
-        output: dict,
-        index: int,
+        data: dict,
         request: ChatCompletionRequest,
-        previous_num_tokens: int,
         prompt_token_ids: list,
         prompt_tokens: str,
         completion_token_ids: list,
+        previous_num_tokens: int,
         num_cached_tokens: list,
+        num_input_image_tokens: list,
+        num_input_video_tokens: list,
         num_image_tokens: list,
         logprob_contents: list,
         response_processor: ChatResponseProcessor,
     ) -> ChatCompletionResponseChoice:
+        idx = int(data["request_id"].split("_")[-1])
+        output = data["outputs"]

         if output is not None and output.get("metrics") and output["metrics"].get("request_start_time"):
             work_process_metrics.e2e_request_latency.observe(
@@ -621,13 +638,15 @@ async def _create_chat_completion_choice(
             message.content = output["text"]

         logprobs_full_res = None
-        if logprob_contents[index]:
-            logprobs_full_res = LogProbs(content=logprob_contents[index])
+        if logprob_contents[idx]:
+            logprobs_full_res = LogProbs(content=logprob_contents[idx])

         has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
         max_tokens = request.max_completion_tokens or request.max_tokens
-        num_cached_tokens[index] = output.get("num_cached_tokens", 0)
-        num_image_tokens[index] = output.get("num_image_tokens", 0)
+        num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
+        num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
+        num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0)
+        num_image_tokens[idx] = output.get("num_image_tokens", 0)

         finish_reason = "stop"
         if has_no_token_limit or previous_num_tokens != max_tokens:
@@ -640,7 +659,7 @@ async def _create_chat_completion_choice(
             finish_reason = "recover_stop"

         return ChatCompletionResponseChoice(
-            index=index,
+            index=idx,
             message=message,
             logprobs=logprobs_full_res,
             finish_reason=finish_reason,
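_create_chat_completion_choice no longer receives an explicit index: it derives the choice index from the trailing _<n> suffix of the request id, and it now reads num_cached_tokens and the new counters from the top-level data dict rather than from outputs. A standalone sketch of the index parsing; the "<base id>_<choice>" format is inferred from the split above, and the base id here is invented:

# The suffix convention is inferred from int(data["request_id"].split("_")[-1]);
# the base id is a made-up example.
data = {"request_id": "chatcmpl-abc123_1"}
idx = int(data["request_id"].split("_")[-1])
assert idx == 1  # choice index within the response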

fastdeploy/input/ernie4_5_vl_processor/process.py

Lines changed: 4 additions & 0 deletions
@@ -193,6 +193,8 @@ def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=N
             "labels": [],
             "cur_position": 0,
             "video_cnt": 0,
+            "num_input_image_tokens": 0,
+            "num_input_video_tokens": 0,
             "mm_positions": [],
             "mm_hashes": [],
         }
@@ -357,6 +359,7 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
         outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
+        outputs["num_input_image_tokens"] += num_tokens

         pos_ids = self._compute_3d_positions(1, patches_h, patches_w, outputs["cur_position"])
         outputs["position_ids"].extend(pos_ids)
@@ -428,6 +431,7 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
         outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+        outputs["num_input_video_tokens"] += num_tokens

         pos_ids = self._compute_3d_positions(num_frames, patches_h, patches_w, outputs["cur_position"])
         outputs["position_ids"].extend(pos_ids)
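Both _add_image and _add_video bump the counter by exactly the number of placeholder ids they append to input_ids, so the counters always equal the multimodal positions reserved in the prompt. A reduced sketch of that invariant; image_patch_id and the token counts are stand-ins, not values from the processor:

# Reduced sketch of the accumulation invariant; values are illustrative.
outputs = {"input_ids": [], "num_input_image_tokens": 0}
image_patch_id = 0  # stand-in placeholder id
for num_tokens in (64, 256):  # e.g. two images of different sizes
    outputs["input_ids"].extend([image_patch_id] * num_tokens)
    outputs["num_input_image_tokens"] += num_tokens
assert outputs["num_input_image_tokens"] == outputs["input_ids"].count(image_patch_id)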

fastdeploy/input/paddleocr_vl_processor/process.py

Lines changed: 4 additions & 1 deletion
@@ -143,9 +143,10 @@ def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=N
             "labels": [],
             "cur_position": 0,
             "video_cnt": 0,
+            "num_input_image_tokens": 0,
+            "num_input_video_tokens": 0,
             "fps": [],
             "mm_positions": [],
-            "mm_hashes": [],
             "vit_seqlen": [],
             "vit_position_ids": [],
         }
@@ -354,6 +355,7 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_token_id] * num_tokens)
         outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
+        outputs["num_input_image_tokens"] += int(num_tokens)

         outputs["images"].append(ret["pixel_values"])
         if not uuid:
@@ -414,6 +416,7 @@ def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) ->
         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.video_token_id] * num_tokens)
         outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+        outputs["num_input_video_tokens"] += int(num_tokens)

         outputs["images"].append(ret["pixel_values"])
         if not uuid:

fastdeploy/input/qwen_vl_processor/process.py

Lines changed: 4 additions & 0 deletions
@@ -142,6 +142,8 @@ def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=N
             "labels": [],
             "cur_position": 0,
             "video_cnt": 0,
+            "num_input_image_tokens": 0,
+            "num_input_video_tokens": 0,
             "fps": [],
             "mm_positions": [],
             "mm_hashes": [],
@@ -351,6 +353,7 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_token_id] * num_tokens)
         outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
+        outputs["num_input_image_tokens"] += int(num_tokens)

         outputs["images"].append(ret["pixel_values"])
         if not uuid:
@@ -409,6 +412,7 @@ def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) ->
         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.video_token_id] * num_tokens)
         outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+        outputs["num_input_video_tokens"] += int(num_tokens)

         outputs["images"].append(ret["pixel_values"])
         if not uuid:
if not uuid:

fastdeploy/output/token_processor.py

Lines changed: 6 additions & 0 deletions
@@ -289,6 +289,9 @@ def _process_batch_output_use_zmq(self, receive_datas):
            if task.messages is not None:
                result.prompt = task.messages
            result.num_cached_tokens = task.num_cached_tokens
+           if task.get("multimodal_inputs", None):
+               result.num_input_image_tokens = task.multimodal_inputs.get("num_input_image_tokens", 0)
+               result.num_input_video_tokens = task.multimodal_inputs.get("num_input_video_tokens", 0)

            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
            result = self._process_per_token(task, i, token_ids, result, is_prefill)
@@ -655,6 +658,9 @@ def _process_batch_output(self):
            if task.messages is not None:
                result.prompt = task.messages
            result.num_cached_tokens = task.num_cached_tokens
+           if task.get("multimodal_inputs", None):
+               result.num_input_image_tokens = task.multimodal_inputs.get("num_input_image_tokens", 0)
+               result.num_input_video_tokens = task.multimodal_inputs.get("num_input_video_tokens", 0)

            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"

tests/ce/server/test_logprobs.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ def test_unstream_with_logprobs():
         "bytes": [231, 137, 155, 233, 161, 191],
         "top_logprobs": None,
     }
+
     assert resp_json["usage"]["prompt_tokens"] == 22
     assert resp_json["usage"]["completion_tokens"] == 3
     assert resp_json["usage"]["total_tokens"] == 25

tests/entrypoints/openai/test_max_streaming_tokens.py

Lines changed: 8 additions & 5 deletions
@@ -387,10 +387,10 @@ async def test_create_chat_completion_choice(self):
                     "text": "Normal AI response",
                     "reasoning_content": "Normal reasoning",
                     "tool_call": None,
-                    "num_cached_tokens": 3,
                     "num_image_tokens": 2,
                     "raw_prediction": "raw_answer_0",
                 },
+                "num_cached_tokens": 3,
                 "finished": True,
                 "previous_num_tokens": 2,
             },
@@ -416,10 +416,10 @@ async def test_create_chat_completion_choice(self):
                     "text": "Edge case response",
                     "reasoning_content": None,
                     "tool_call": None,
-                    "num_cached_tokens": 0,
                     "num_image_tokens": 0,
                     "raw_prediction": None,
                 },
+                "num_cached_tokens": 0,
                 "finished": True,
                 "previous_num_tokens": 1,
             },
@@ -446,18 +446,21 @@ async def test_create_chat_completion_choice(self):
         mock_response_processor.enable_multimodal_content.return_value = False
         completion_token_ids = [[], []]
         num_cached_tokens = [0, 0]
+        num_input_image_tokens = [0, 0]
+        num_input_video_tokens = [0, 0]
         num_image_tokens = [0, 0]

         for idx, case in enumerate(test_cases):
             actual_choice = await self.chat_serving._create_chat_completion_choice(
-                output=case["test_data"]["outputs"],
-                index=idx,
+                data=case["test_data"],
                 request=case["mock_request"],
-                previous_num_tokens=case["test_data"]["previous_num_tokens"],
                 prompt_token_ids=prompt_token_ids,
                 prompt_tokens=prompt_tokens,
                 completion_token_ids=completion_token_ids[idx],
+                previous_num_tokens=case["test_data"]["previous_num_tokens"],
                 num_cached_tokens=num_cached_tokens,
+                num_input_image_tokens=num_input_image_tokens,
+                num_input_video_tokens=num_input_video_tokens,
                 num_image_tokens=num_image_tokens,
                 logprob_contents=logprob_contents,
                 response_processor=mock_response_processor,

tests/output/test_process_batch_output.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def __init__(self):
     def get(self, key: str, default_value=None):
         if hasattr(self, key):
             return getattr(self, key)
-        elif hasattr(self.sampling_params, key):
+        elif hasattr(self, "sampling_params") and hasattr(self.sampling_params, key):
             return getattr(self.sampling_params, key)
         else:
             return default_value
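The extra hasattr(self, "sampling_params") check lets minimal fakes omit sampling_params entirely; previously the elif raised AttributeError for any unknown key on such a fake. A self-contained sketch of the fixed behavior:

class FakeTask:
    # Mirrors the fake Request's get() after this fix.
    def get(self, key, default_value=None):
        if hasattr(self, key):
            return getattr(self, key)
        # Without the hasattr(self, "sampling_params") guard, this branch
        # raised AttributeError on fakes defining no sampling_params at all.
        if hasattr(self, "sampling_params") and hasattr(self.sampling_params, key):
            return getattr(self.sampling_params, key)
        return default_value

assert FakeTask().get("multimodal_inputs") is None  # falls through, no crash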
