diff --git a/examples/models/core/multimodal/README.md b/examples/models/core/multimodal/README.md
index a45e3fc724b..6bbfbb954a0 100644
--- a/examples/models/core/multimodal/README.md
+++ b/examples/models/core/multimodal/README.md
@@ -12,6 +12,7 @@ We first describe three runtime modes for running multimodal models and how to r
 - [CogVLM](#cogvlm)
 - [Deplot](#deplot)
 - [Fuyu](#fuyu)
+- [Gemma3](#gemma3)
 - [InternLM-XComposer2](#internlm-xcomposer2)
 - [InternVL2](#internvl2)
 - [Kosmos-2](#kosmos-2)
@@ -352,6 +353,75 @@ Currently, CogVLM only support bfloat16 precision.
     --engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu
 ```
 
+## Gemma3
+
+**NOTE: We only support Gemma3 VLMs in the PyTorch workflow.**
+
+The Gemma3VL decoder requires a custom attention mask while processing images. During the context phase:
+- Text tokens attend to other tokens in a causal fashion (standard autoregressive behavior).
+- Image tokens attend to other tokens in a causal fashion AND attend to tokens from the same image in a bidirectional manner.
+
+**Reference:** [Gemma3 Model Documentation](https://huggingface.co/docs/transformers/en/model_doc/gemma3)
+
+We support this custom mask with the FlashInfer attention backend.
+
+### Requirements
+
+To ensure expected behavior with Gemma3VL, the following configurations are **required**:
+- **Attention Backend**: Use the FlashInfer attention backend
+- **Chunked Prefill**: Must be disabled
+- **KV Cache Reuse**: Must be disabled
+
+### Quick Start
+
+#### 1. Download Model Weights
+
+```bash
+export MODEL_NAME="gemma-3-27b-it"
+git clone https://huggingface.co/google/${MODEL_NAME}
+```
+
+#### 2. Interactive Testing
+
+Use the `quickstart_multimodal.py` script for quick testing:
+
+```bash
+python3 examples/llm-api/quickstart_multimodal.py \
+    --model_dir ${MODEL_NAME}/ \
+    --modality image \
+    --image_format pil \
+    --attention_backend FLASHINFER \
+    --disable_kv_cache_reuse
+```
+
+#### 3. Model Serving
+
+Serve the model using `trtllm-serve`, passing the required LLM API arguments through a YAML file:
+
+```bash
+# Create the configuration file
+cat > extra-llm-api-options.yaml << 'EOF'
+cuda_graph_config: null
+attn_backend: "FLASHINFER"
+enable_chunked_prefill: false
+kv_cache_config:
+  enable_block_reuse: false
+EOF
+
+# Serve the model
+trtllm-serve ${MODEL_NAME}/ \
+    --backend pytorch \
+    --tp_size 1 \
+    --port 8000 \
+    --max_batch_size 4 \
+    --extra_llm_api_options extra-llm-api-options.yaml
+```
+
+### Supported Model Variants
+
+Currently supported Gemma3 variants: 4B, 12B, and 27B.
+
+
 ## InternLM-XComposer2
 
 **NOTE: We only support InternLM-XComposer-VL-7b for now**
diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
index b4ddb486cd1..072429c7deb 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -276,17 +276,3 @@ def _get_image_features(self, pixel_values):
                                       attn_metadata=attn_metadata)[-1]
         image_features = self.mm_projector(image_features)
         return image_features
-
-
-def _load_weights_into_hf_module(
-    model: torch.nn.Module,
-    weights: dict,
-    prefix: str,
-    model_name: str,
-) -> None:
-    filtered_weights = filter_weights(prefix, weights)
-    missing_keys, _ = model.load_state_dict(filtered_weights)
-    if len(missing_keys) > 0:
-        raise KeyError(
-            f"Missing the following keys for the {model_name} in the checkpoint: "
-            f"[{', '.join(missing_keys)}].")
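
For reference beyond the diff: the custom context-phase mask the new README section describes combines a causal mask with bidirectional attention inside each image. Below is a minimal PyTorch sketch of that construction; the function name, the flat `token_ids` layout, and `image_token_id` are hypothetical illustration, not the actual implementation, which lives in the FlashInfer attention backend.

```python
import torch


def build_gemma3_context_mask(token_ids: torch.Tensor,
                              image_token_id: int) -> torch.Tensor:
    """Return a [seq_len, seq_len] bool mask where True means "may attend"."""
    seq_len = token_ids.shape[0]
    # Causal part: every token attends to itself and all earlier tokens.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    # Bidirectional part: tokens belonging to the same image attend to each
    # other regardless of position. Label each contiguous run of image
    # tokens with a distinct id (text tokens keep id 0).
    is_image = token_ids == image_token_id
    run_starts = is_image & ~torch.cat([is_image.new_zeros(1), is_image[:-1]])
    run_ids = torch.cumsum(run_starts, dim=0) * is_image
    same_image = ((run_ids[:, None] == run_ids[None, :])
                  & is_image[:, None] & is_image[None, :])

    return mask | same_image


# Example: two text tokens, a three-token image, one more text token.
ids = torch.tensor([101, 102, 9, 9, 9, 103])
print(build_gemma3_context_mask(ids, image_token_id=9).int())
# Rows 2-4 (the image) attend to each other bidirectionally; every other
# position stays strictly causal.
```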
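Similarly, a hedged sketch of querying the server started in the README's serving step: `trtllm-serve` exposes an OpenAI-compatible API, so a standard `openai` client call against the `/v1` route should work. The image URL, `api_key` value, and `max_tokens` are placeholders; adjust the base URL to match the `--port` used when serving.

```python
from openai import OpenAI

# The server above does not check API keys, but the client requires one.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")

response = client.chat.completions.create(
    model="gemma-3-27b-it",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/photo.png"}},
        ],
    }],
    max_tokens=64,
)
print(response.choices[0].message.content)
```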