70 changes: 70 additions & 0 deletions examples/models/core/multimodal/README.md
@@ -12,6 +12,7 @@ We first describe three runtime modes for running multimodal models and how to r
- [CogVLM](#cogvlm)
- [Deplot](#deplot)
- [Fuyu](#fuyu)
- [Gemma3](#gemma3)
- [InternLM-XComposer2](#internlm-xcomposer2)
- [InternVL2](#internvl2)
- [Kosmos-2](#kosmos-2)
@@ -352,6 +353,75 @@ Currently, CogVLM only supports bfloat16 precision.
--engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu
```

## Gemma3

**NOTE: We only support Gemma3 VLMs in the PyTorch workflow.**

The Gemma3VL decoder requires a custom attention mask while processing images. During the context phase:
- Text tokens attend to other tokens in a causal fashion (standard autoregressive behavior)
- Image tokens attend to other tokens in a causal fashion AND attend bidirectionally to other tokens from the same image

**Reference:** [Gemma3 Model Documentation](https://huggingface.co/docs/transformers/en/model_doc/gemma3)

We support this custom mask with the FlashInfer attention backend.
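
For intuition, below is a minimal, self-contained sketch of how such a mask can be built; it is not the TensorRT-LLM implementation, and the `image_token_id` argument and flat token layout are assumptions for illustration only.

```python
import torch

def gemma3_style_mask(token_ids: torch.Tensor, image_token_id: int) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask where True means "may attend"."""
    seq_len = token_ids.shape[0]
    # Causal part: every token attends to itself and all earlier tokens.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    # Bidirectional part: tokens belonging to the same image attend to each other.
    is_image = token_ids == image_token_id
    # Give every contiguous run of image tokens its own id (text positions get 0).
    run_start = is_image & ~torch.roll(is_image, shifts=1, dims=0)
    run_start[0] = is_image[0]
    run_id = torch.cumsum(run_start.to(torch.int64), dim=0) * is_image.to(torch.int64)
    same_image = (run_id[:, None] == run_id[None, :]) & is_image[:, None] & is_image[None, :]

    # Example: [text, text, <img>, <img>, <img>, text]
    # -> the three <img> positions attend to each other in both directions,
    #    everything else stays strictly causal.
    return mask | same_image
```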

### Requirements

To ensure expected behavior with Gemma3VL, the following configurations are **required**:
- **Attention Backend**: Use the FlashInfer attention backend
- **Chunked Prefill**: Must be disabled
- **KV Cache Reuse**: Must be disabled
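
If you drive the model from Python through the LLM API instead of `trtllm-serve`, the same requirements can be expressed as constructor arguments. This is a sketch only: the field names mirror the YAML used in the serving section below and are assumed to be accepted as keyword arguments in your TensorRT-LLM version.

```python
# Sketch: expressing the Gemma3VL requirements through the Python LLM API.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="gemma-3-27b-it",
    attn_backend="FLASHINFER",                                 # required attention backend
    enable_chunked_prefill=False,                              # chunked prefill must be off
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),   # no KV cache reuse
)
```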

### Quick Start

#### 1. Download Model Weights

```bash
export MODEL_NAME="gemma-3-27b-it"
git clone https://huggingface.co/google/${MODEL_NAME}
```

#### 2. Interactive Testing

Use the `quickstart_multimodal.py` script for quick testing:

```bash
python3 examples/llm-api/quickstart_multimodal.py \
--model_dir ${MODEL_NAME}/ \
--modality image \
--image_format pil \
--attention_backend FLASHINFER \
--disable_kv_cache_reuse
```

#### 3. Model Serving

Serve the model using `trtllm-serve`, providing the required LLM API options through a YAML configuration file:

```bash
# Create the configuration file
cat > extra-llm-api-options.yaml << 'EOF'
cuda_graph_config: null
attn_backend: "FLASHINFER"
enable_chunked_prefill: false
kv_cache_config:
enable_block_reuse: false
EOF

# Serve the model
trtllm-serve ${MODEL_NAME}/ \
--backend pytorch \
--tp_size 1 \
--port 8000 \
--max_batch_size 4 \
--extra_llm_api_options extra-llm-api-options.yaml
```
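
Once the server is up, it can be queried like an OpenAI-compatible endpoint. The snippet below is an illustrative client example, not part of this PR; the image URL is a placeholder and the `openai` Python package is assumed to be installed.

```python
# Illustrative request against the server started above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")

response = client.chat.completions.create(
    model="gemma-3-27b-it",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/example.jpg"}},
        ],
    }],
    max_tokens=64,
)
print(response.choices[0].message.content)
```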

### Supported Model Variants

Currently supported Gemma3 variants: 4B, 12B, 27B


## InternLM-XComposer2

**NOTE: We only support InternLM-XComposer-VL-7b for now**
14 changes: 0 additions & 14 deletions tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -276,17 +276,3 @@ def _get_image_features(self, pixel_values):
attn_metadata=attn_metadata)[-1]
image_features = self.mm_projector(image_features)
return image_features


def _load_weights_into_hf_module(
model: torch.nn.Module,
weights: dict,
prefix: str,
model_name: str,
) -> None:
filtered_weights = filter_weights(prefix, weights)
missing_keys, _ = model.load_state_dict(filtered_weights)
if len(missing_keys) > 0:
raise KeyError(
f"Missing the following keys for the {model_name} in the checkpoint: "
f"[{', '.join(missing_keys)}].")