
Commit 1ad2dcd

brb-nv authored and dominicshanshan committed
[None][chore] Add docs for Gemma3 VLMs (NVIDIA#6880)
Signed-off-by: Balaram Buddharaju <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent 60ea9a3 · commit 1ad2dcd

2 files changed: +70 additions, -14 deletions


examples/models/core/multimodal/README.md

Lines changed: 70 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@ We first describe three runtime modes for running multimodal models and how to r
 - [CogVLM](#cogvlm)
 - [Deplot](#deplot)
 - [Fuyu](#fuyu)
+- [Gemma3](#gemma3)
 - [InternLM-XComposer2](#internlm-xcomposer2)
 - [InternVL2](#internvl2)
 - [Kosmos-2](#kosmos-2)
```
@@ -352,6 +353,75 @@ Currently, CogVLM only support bfloat16 precision.

## Gemma3

**NOTE: We only support Gemma3 VLMs in the PyTorch workflow.**

The Gemma3VL decoder requires a custom attention mask while processing images. During the context phase:
- Text tokens attend to other tokens in a causal fashion (standard autoregressive behavior).
- Image tokens attend to other tokens in a causal fashion AND attend to other tokens from the same image in a bidirectional manner.
**Reference:** [Gemma3 Model Documentation](https://huggingface.co/docs/transformers/en/model_doc/gemma3)

We support this custom mask with the FlashInfer attention backend. A minimal sketch of the resulting mask pattern is shown below.
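The following sketch illustrates the mask described above. It is for illustration only, not the repository's implementation, and the `image_ids` encoding (-1 for text tokens, a shared non-negative id per image) is an assumption:

```python
import torch

def gemma3_context_mask(image_ids: torch.Tensor) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask; True means "query may attend to key"."""
    seq_len = image_ids.shape[0]
    # Causal part: every token attends to itself and all earlier tokens.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Bidirectional part: query and key tokens belong to the same image.
    same_image = (image_ids[:, None] == image_ids[None, :]) & (image_ids[:, None] >= 0)
    return causal | same_image

# Example: 2 text tokens, a 3-token image (id 0), then 2 more text tokens.
ids = torch.tensor([-1, -1, 0, 0, 0, -1, -1])
print(gemma3_context_mask(ids).int())
```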
### Requirements

To ensure expected behavior with Gemma3VL, the following configurations are **required** (see the sketch after this list for equivalent programmatic settings):
- **Attention Backend**: Use the FlashInfer attention backend
- **Chunked Prefill**: Must be disabled
- **KV Cache Reuse**: Must be disabled
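For programmatic use, the same requirements can be expressed through the LLM API. This is a minimal sketch under the assumption that the argument names (`attn_backend`, `enable_chunked_prefill`, `KvCacheConfig`) mirror the YAML keys used in the serving step below; they may differ across TensorRT-LLM versions:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Assumed argument names; they mirror the YAML options shown in "Model Serving".
llm = LLM(
    model="google/gemma-3-27b-it",
    attn_backend="FLASHINFER",      # FlashInfer attention backend (required)
    enable_chunked_prefill=False,   # chunked prefill disabled (required)
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),  # no KV cache reuse
)
```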
### Quick Start

#### 1. Download Model Weights

```bash
export MODEL_NAME="gemma-3-27b-it"
git clone https://huggingface.co/google/${MODEL_NAME}
```
#### 2. Interactive Testing

Use the `quickstart_multimodal.py` script for quick testing:

```bash
python3 examples/llm-api/quickstart_multimodal.py \
    --model_dir ${MODEL_NAME}/ \
    --modality image \
    --image_format pil \
    --attention_backend FLASHINFER \
    --disable_kv_cache_reuse
```
#### 3. Model Serving

Serve the model with `trtllm-serve`, passing the required LLM API arguments through a YAML file:

```bash
# Create the configuration file
cat > extra-llm-api-options.yaml << 'EOF'
cuda_graph_config: null
attn_backend: "FLASHINFER"
enable_chunked_prefill: false
kv_cache_config:
  enable_block_reuse: false
EOF

# Serve the model
trtllm-serve ${MODEL_NAME}/ \
    --backend pytorch \
    --tp_size 1 \
    --port 8000 \
    --max_batch_size 4 \
    --extra_llm_api_options extra-llm-api-options.yaml
```
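Once the server is running, it exposes an OpenAI-compatible HTTP API. The following client sketch assumes the server is on `localhost:8000`, the `openai` Python package is installed, and the image URL is a placeholder:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="gemma-3-27b-it",  # the served model name
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }],
    max_tokens=64,
)
print(response.choices[0].message.content)
```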
### Supported Model Variants

Currently supported Gemma3 variants: 4B, 12B, and 27B.
## InternLM-XComposer2

**NOTE: We only support InternLM-XComposer-VL-7b for now**

tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 0 additions & 14 deletions
```diff
@@ -284,17 +284,3 @@ def _get_image_features(self, pixel_values):
                                      attn_metadata=attn_metadata)[-1]
         image_features = self.mm_projector(image_features)
         return image_features
-
-
-def _load_weights_into_hf_module(
-    model: torch.nn.Module,
-    weights: dict,
-    prefix: str,
-    model_name: str,
-) -> None:
-    filtered_weights = filter_weights(prefix, weights)
-    missing_keys, _ = model.load_state_dict(filtered_weights)
-    if len(missing_keys) > 0:
-        raise KeyError(
-            f"Missing the following keys for the {model_name} in the checkpoint: "
-            f"[{', '.join(missing_keys)}].")
```
