Commit faa2f46

[TRTLLM-5059][feat] Enable KV-cache reuse and add E2E tests for llava-next (#7349)
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent d49374b commit faa2f46

7 files changed: +138 −18 lines

docs/source/reference/multimodal-feature-support-matrix.md

Lines changed: 2 additions & 2 deletions
@@ -2,10 +2,10 @@
 | Model | CUDA Graph | Encoder IFB | KV Cache Reuse | Chunked Prefill |
 | :----------------- | :--------- | :------------------ | :------------- | :-------------- |
-| Gemma 3 | Yes | Yes | No | No |
+| Gemma 3 | Yes | Yes | N/A | N/A |
 | HyperCLOVA | Yes | Yes | No | No |
 | VILA | Yes | No | No | No |
-| LLaVA-NeXT | Yes | Yes | No | No |
+| LLaVA-NeXT | Yes | Yes | Yes | No |
 | Llama 4 | Yes | Yes | No | No |
 | Mistral-Small-3.1 | Yes | Yes | No | No |
 | Phi-4-multimodal | Yes | Yes | No | No |

examples/llm-api/quickstart_multimodal.py

Lines changed: 0 additions & 1 deletion
@@ -154,7 +154,6 @@ def parse_arguments():
     parser = add_lora_args(parser)
     args = parser.parse_args()
 
-    args.disable_kv_cache_reuse = True  # kv cache reuse does not work for multimodal, force overwrite
     if args.kv_cache_fraction is None:
         args.kv_cache_fraction = 0.6  # lower the default kv cache fraction for multimodal
 
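
With the force-override removed, whether the quickstart reuses KV-cache blocks is now governed only by the --disable_kv_cache_reuse CLI flag. Below is a minimal sketch of how the remaining arguments might feed a KV-cache configuration; it is not code from this commit, build_kv_cache_config is a hypothetical helper, and the enable_block_reuse / free_gpu_memory_fraction field names are assumed from tensorrt_llm.llmapi.KvCacheConfig.

```python
# Illustrative sketch only -- not code from this commit.
from tensorrt_llm.llmapi import KvCacheConfig  # field names assumed


def build_kv_cache_config(args):
    """Map quickstart CLI args to a KV-cache config (hypothetical helper)."""
    if args.kv_cache_fraction is None:
        args.kv_cache_fraction = 0.6  # multimodal default kept from the example
    return KvCacheConfig(
        # Reuse stays on unless the user explicitly passes --disable_kv_cache_reuse.
        enable_block_reuse=not args.disable_kv_cache_reuse,
        free_gpu_memory_fraction=args.kv_cache_fraction,
    )
```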

tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 4 additions & 1 deletion
@@ -25,7 +25,8 @@
 from ..model_config import ModelConfig
 from .modeling_auto import AutoModelForCausalLM
 from .modeling_clip import CLIPVisionModel
-from .modeling_multimodal_utils import fuse_input_embeds
+from .modeling_multimodal_utils import (find_uncached_mm_embeds,
+                                        fuse_input_embeds)
 from .modeling_utils import (filter_weights, register_auto_model,
                              register_vision_encoder)
 
@@ -469,6 +470,8 @@ def forward(
                 ]
             else:
                 mm_embeds = self.mm_encoder.forward(multimodal_params)
+                mm_embeds = find_uncached_mm_embeds(
+                    mm_embeds, multimodal_params[:num_context_requests])
         else:
             mm_embeds = [
                 multimodal_param.multimodal_data["multimodal_embedding"]
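
For context, find_uncached_mm_embeds trims each context request's multimodal embeddings down to the portion not already covered by reused KV-cache blocks before they are fused into the input embeddings. The following is a rough standalone sketch of that idea, assuming hypothetical per-request bookkeeping of how many multimodal tokens were cache hits; it is not the implementation in modeling_multimodal_utils.

```python
import torch


# Conceptual sketch only; num_cached is hypothetical bookkeeping, not the
# actual multimodal_params structure.
def slice_uncached_mm_embeds(mm_embeds, num_cached):
    # Keep only the tail of each request's multimodal embedding that still needs
    # to be fused; a fully cached request contributes an empty slice.
    return [emb[cached:] for emb, cached in zip(mm_embeds, num_cached)]


# Example: the first request reuses 576 multimodal positions from the KV cache.
embeds = [torch.randn(1176, 4096), torch.randn(1176, 4096)]
print([e.shape[0] for e in slice_uncached_mm_embeds(embeds, [576, 0])])  # [600, 1176]
```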

tensorrt_llm/inputs/multimodal.py

Lines changed: 22 additions & 14 deletions
@@ -24,16 +24,17 @@ class MultimodalInput:
     """
 
     multimodal_positions: List[int]
-    """Starting positions of each multimodal chunk in the token sequence.
+    """Starting positions of each contiguous multimodal token chunk in the token sequence.
 
     Contains only the start position of each chunk, not all positions of multimodal tokens.
     This is different from mm_positions elsewhere which contains all positions.
     """
 
     multimodal_lengths: List[int]
-    """Length (number of tokens) of each multimodal item.
+    """Length of each contiguous multimodal token chunk, including any special tokens.
 
-    Combined with multimodal_positions, this defines the token spans for each multimodal item.
+    Each span is unique to its multimodal item and may include special tokens for some models
+    (e.g., image_end_token, image_break_token for mistral3) mixed with the actual multimodal tokens.
     """
 
     def __post_init__(self):
@@ -485,7 +486,13 @@ def hexdigest_to_int32(hex_digest: str) -> List[int]:
 
 def find_mm_token_lengths(mm_data: Dict[str, Any],
                           input_processor: Any) -> List[int]:
-    """Get multimodal token lengths from multimodal data items. """
+    """Get the maximum contiguous multimodal token lengths from multimodal data items.
+
+    Returns the total token count for each multimodal item, including any special tokens
+    (e.g., image_begin, image_end, image_break) that may be mixed with the actual
+    multimodal content tokens. The returned lengths represent the full contiguous chunk
+    from beginning to end, not just pure image/video/audio tokens.
+    """
 
     mm_items = {
         modality: items if isinstance(items, list) else [items]
@@ -528,22 +535,23 @@ def find_mm_token_positions(
         num_mm_tokens: List[int],
         vocab_size: Optional[int] = None,
         mm_token_ids: Optional[torch.Tensor] = None) -> List[int]:
-    """Get multimodal token positions using IDs > vocab_size and known lengths.
+    """Get starting positions of contiguous multimodal token chunks using known lengths.
+
+    This function finds multimodal tokens (with IDs > vocab_size or matching mm_token_ids)
+    and uses the provided lengths in num_mm_tokens to identify where each contiguous chunk starts.
+    Each chunk in num_mm_tokens is assumed to be a contiguous block of multimodal tokens for one multimodal item and may include special tokens (e.g., image_begin, image_end, image_break) within the chunk.
 
-    This function finds multimodal tokens (with IDs > vocab_size) and uses the
-    provided lengths in num_mm_tokens to identify where each chunk starts.
-    This works even when there are no gaps between different image sequences
-    (e.g., when all images use the same token IDs).
-    Note at least one of vocab_size or mm_token_ids must be provided. If mm_token_ids is provided, vocab_size is ignored.
+    Note: at least one of vocab_size or mm_token_ids must be provided. If mm_token_ids
+    is provided, vocab_size is ignored.
 
     Args:
         input_ids: Token sequence (tensor, list, or numpy array)
-        num_mm_tokens: List of lengths for each multimodal token chunk
-        vocab_size: Size of the model's vocabulary
-        mm_token_ids: Possible token ids for multimodal tokens
+        num_mm_tokens: List of contiguous chunk lengths for each multimodal item
+        vocab_size: Size of the model's vocabulary (used to identify tokens > vocab_size)
+        mm_token_ids: Specific token IDs that represent multimodal tokens
 
     Returns:
-        List of starting positions for each multimodal token chunk
+        List of starting positions for each contiguous multimodal token chunk
     """
     if mm_token_ids is None and vocab_size is None:
         raise ValueError(
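
The chunk-start logic these docstrings describe can be pictured with a small standalone example (not the library code): mark positions whose token ID exceeds the vocabulary size, or whose ID appears in mm_token_ids, then walk the known per-item chunk lengths to recover each chunk's starting position. This works even when two chunks sit back to back with no text tokens between them.

```python
import torch


# Standalone illustration of the docstring above; not the library implementation.
def chunk_starts(input_ids, num_mm_tokens, vocab_size=None, mm_token_ids=None):
    ids = torch.as_tensor(input_ids)
    if mm_token_ids is not None:
        mask = torch.isin(ids, torch.as_tensor(mm_token_ids))  # explicit multimodal IDs
    else:
        mask = ids > vocab_size  # out-of-vocabulary IDs are multimodal placeholders
    positions = torch.nonzero(mask, as_tuple=True)[0].tolist()
    starts, cursor = [], 0
    for length in num_mm_tokens:  # walk the known per-item chunk lengths
        starts.append(positions[cursor])
        cursor += length
    return starts


# Two 4-token image chunks placed back to back, with no text gap between them.
ids = [1, 7, 32001, 32002, 32001, 32003, 32001, 32002, 32001, 32003, 9]
print(chunk_starts(ids, num_mm_tokens=[4, 4], vocab_size=32000))  # -> [2, 6]
```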

tensorrt_llm/inputs/registry.py

Lines changed: 2 additions & 0 deletions
@@ -85,6 +85,8 @@ def get_vocab_size(self) -> Optional[int]:
 
     def get_mm_token_ids(self) -> Optional[Tensor]:
         """Return multimodal token IDs if available; otherwise None.
+
+        The returned token IDs should cover each multimodal item's full contiguous chunk, i.e. any special tokens should be included.
         """
         processor = self.get_processor()
         if processor is not None and getattr(processor, 'mm_token_ids',
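
A small hypothetical illustration of that contiguity requirement (the token IDs below are made up): if an image chunk looks like [img_begin, image tokens, img_break, image tokens, img_end], the special tokens must also be part of mm_token_ids, otherwise the matched positions have holes and the chunk is no longer a single contiguous span.

```python
import torch

# Made-up token IDs for illustration only.
IMAGE, IMG_BEGIN, IMG_BREAK, IMG_END = 32000, 32001, 32002, 32003
ids = torch.tensor([1, IMG_BEGIN, IMAGE, IMAGE, IMG_BREAK, IMAGE, IMAGE, IMG_END, 2])

with_special = torch.tensor([IMAGE, IMG_BEGIN, IMG_BREAK, IMG_END])
without_special = torch.tensor([IMAGE])

print(torch.isin(ids, with_special).int().tolist())     # [0, 1, 1, 1, 1, 1, 1, 1, 0] -> one contiguous span
print(torch.isin(ids, without_special).int().tolist())  # [0, 0, 1, 1, 0, 1, 1, 0, 0] -> holes break the chunk
```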

tests/integration/defs/test_e2e.py

Lines changed: 105 additions & 0 deletions
@@ -2425,6 +2425,8 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
         *accuracy_inputs[modality]["prompt"],
         "--media",
         *accuracy_inputs[modality]["media"],
+        # TODO: remove this once kv cache reuse is supported for all VLM models
+        "--disable_kv_cache_reuse",
     ]
     # NOTE: Qwen2-VL and Qwen2-5-VL model need larger max_num_tokens for video.
     if model_name in ["qwen2-vl-7b-instruct", "qwen2.5-vl-7b-instruct"
@@ -2510,6 +2512,96 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
         _check_mem_usage(running_log, [peak, 0, 0, 0])
 
 
+@pytest.mark.parametrize("modality", ["image", "video"])
+@pytest.mark.parametrize("model_name,model_path", [
+    ("llava-v1.6-mistral-7b", "llava-v1.6-mistral-7b-hf"),
+    ("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct"),
+])
+def test_ptp_quickstart_multimodal_kv_cache_reuse(llm_root, llm_venv,
+                                                  model_name, model_path,
+                                                  modality):
+    # NOTE: individual tests need to be enabled in
+    # tests/integration/test_lists/qa/examples_test_list.txt
+
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
+    test_data_root = Path(
+        os.path.join(llm_models_root(), "multimodals", "test_data"))
+    print(f"Accuracy test {model_name} {modality} mode with example inputs.")
+    if modality == "video" and model_name == "llava-v1.6-mistral-7b":
+        pytest.skip("Skipping video modality test for llava-v1.6-mistral-7b")
+
+    num_same_requests = 3  # test kv cache reuse with multiple same requests
+    accuracy_inputs = {
+        "image": {
+            "prompt": [
+                "Describe the natural environment in the image.",
+            ] * num_same_requests,
+            "media": [
+                str(test_data_root / "seashore.png"),
+            ] * num_same_requests,
+        },
+        "video": {
+            "prompt": [
+                "Tell me what you see in the video briefly.",
+            ] * num_same_requests,
+            "media": [
+                str(test_data_root / "OAI-sora-tokyo-walk.mp4"),
+            ] * num_same_requests,
+        },
+    }
+
+    expected_keywords = {
+        "llava-v1.6-mistral-7b": {
+            "image": [
+                ["ocean", "sky", "large", "waves", "shore", "blue"],
+            ] * num_same_requests,
+        },
+        "qwen2.5-vl-7b-instruct": {
+            "image": [
+                ["dramatic", "moody", "ocean", "stormy", "sky", "waves"],
+            ] * num_same_requests,
+            "video": [
+                ["woman", "neon", "night", "jacket", "wet"],
+            ] * num_same_requests,
+        },
+    }
+
+    cmd = [
+        str(example_root / "quickstart_multimodal.py"),
+        "--model_dir",
+        f"{llm_models_root()}/{model_path}",
+        "--modality",
+        modality,
+        "--prompt",
+        *accuracy_inputs[modality]["prompt"],
+        "--media",
+        *accuracy_inputs[modality]["media"],
+        "--max_batch_size",  # single request at a time to test kv cache reuse
+        "1",
+    ]
+    # NOTE: Qwen2-VL and Qwen2-5-VL model need larger max_num_tokens for video.
+    if model_name in ["qwen2-vl-7b-instruct", "qwen2.5-vl-7b-instruct"
+                      ] and modality == "video":
+        cmd.append("--max_num_tokens=16384")
+
+    output = llm_venv.run_cmd(cmd, caller=check_output)
+    match_ratio = 4.0 / 5
+    for prompt_output, prompt_keywords in zip(
+            parse_output(output), expected_keywords[model_name][modality]):
+        matches = [
+            keyword in prompt_output.lower() for keyword in prompt_keywords
+        ]
+        obs_match_ratio = 1. * sum(matches) / len(matches)
+        print(
+            f"Prompt output: {prompt_output}\nExpected keywords: {prompt_keywords}\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} given threshold {match_ratio}"
+        )
+        assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"
+    # TODO: Setting max_batch_size=1 and repeating the same request helps test KV cache reuse indirectly,
+    # but does not directly measure the KV cache hit rate. For a more direct test, we would need to enable
+    # return_perf_metrics=True, which is not currently supported by the quickstart example CLI.
+    print("All answers are correct!")
+
+
 @pytest.mark.parametrize("modality", ["image", "audio", "image_audio"])
 def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
     model_name = "Phi-4-multimodal-instruct"
@@ -2583,6 +2675,8 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
         "--load_lora",
         "--auto_model_name",
         "Phi4MMForCausalLM",
+        # TODO: remove this once kv cache reuse is supported for Phi-4-multimodal
+        "--disable_kv_cache_reuse",
     ]
     output = llm_venv.run_cmd(cmd, caller=check_output)
 
@@ -2683,7 +2777,12 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
         cmd.append("--max_seq_len=4096")
         cmd.append("--load_lora")
         cmd.append("--auto_model_name")
         cmd.append("Phi4MMForCausalLM")
+        # TODO: remove this once kv cache reuse is supported for Phi-4-multimodal
+        cmd.append("--disable_kv_cache_reuse")
+    elif model_name == "mistral-small-3.1-24b-instruct":
+        # TODO: remove this once kv cache reuse is supported for Mistral
+        cmd.append("--disable_kv_cache_reuse")
 
     output = llm_venv.run_cmd(cmd, caller=check_output)
 
@@ -2784,6 +2883,12 @@ def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
         cmd.append("--load_lora")
         cmd.append("--auto_model_name")
         cmd.append("Phi4MMForCausalLM")
+        # TODO: remove this once kv cache reuse is supported for Phi-4
+        cmd.append("--disable_kv_cache_reuse")
+
+    elif model_name == "mistral-small-3.1-24b-instruct":
+        # TODO: remove this once kv cache reuse is supported for Mistral
+        cmd.append("--disable_kv_cache_reuse")
 
     output = llm_venv.run_cmd(cmd, caller=check_output)
     print("output:", output)
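
The TODO in the new test notes that a direct hit-rate check would need per-request perf metrics, which the quickstart CLI does not expose. For reference, here is a heavily hedged sketch of what such a check might look like against the LLM API; the SamplingParams(return_perf_metrics=True) flag and especially the request_perf_metrics.kv_cache_metrics attribute path are assumptions not verified against this commit, and the model path is a placeholder.

```python
# Sketch only; not part of this commit. Attribute names below are assumptions.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(model="llava-v1.6-mistral-7b-hf",  # placeholder model path
          kv_cache_config=KvCacheConfig(enable_block_reuse=True))
params = SamplingParams(max_tokens=32, return_perf_metrics=True)  # assumed flag

prompt = "Describe the natural environment in the image."
first = llm.generate([prompt], params)[0]
second = llm.generate([prompt], params)[0]  # identical request should hit the cache

# Assumed attribute path for per-request KV-cache statistics.
metrics = second.outputs[0].request_perf_metrics.kv_cache_metrics
print("reused blocks on the repeated request:", metrics.num_reused_blocks)
```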

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 3 additions & 0 deletions
@@ -631,6 +631,9 @@ test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistr
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
 test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-False]
 test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
+test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image]
+test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image]
+test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video]
 test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio]
 test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]
 test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
