diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index e8c5786066170..7a2c95594ddcd 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -606,7 +606,6 @@ def forward(
             :class:`LlavaNextImageInputs`
         """
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
@@ -618,9 +617,14 @@ def forward(
                     self.language_model.model.get_input_embeddings,
                     lambda _: self._process_image_input(image_input),
                 )
-                input_ids = None
             else:
-                inputs_embeds = None
+                inputs_embeds = self.language_model.model.get_input_embeddings(
+                    input_ids)
+
+        # always pass the input via `inputs_embeds`
+        # to make sure the computation graph is consistent
+        # for `torch.compile` integration
+        input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index a526a5dccd398..e7088edb97b2b 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -564,8 +564,13 @@ def forward(
 
         vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
 
+        # always pass the input via `inputs_embeds`
+        # to make sure the computation graph is consistent
+        # for `torch.compile` integration
+        input_ids = None
+
         output = self.llm(
-            input_ids=None,
+            input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 3c34227767e05..ba798833e26a9 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -15,6 +15,7 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.selector import _Backend
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -713,6 +714,7 @@ def forward(
         return image_features
 
 
+@support_torch_compile
 class MolmoModel(nn.Module):
 
     def __init__(
@@ -1141,7 +1143,6 @@ def forward(
         **kwargs: object,
     ) -> SamplerOutput:
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
@@ -1156,10 +1157,13 @@ def forward(
                     image_input["image_input_idx"],
                     image_input["seq_len"],
                 )
-
-                input_ids = None
             else:
-                inputs_embeds = None
+                inputs_embeds = self.model.embed_tokens(input_ids)
+
+        # always pass the input via `inputs_embeds`
+        # to make sure the computation graph is consistent
+        # for `torch.compile` integration
+        input_ids = None
 
         hidden_states = self.model(
             input_ids=input_ids,
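
The three hunks above apply one pattern: embed the token ids up front on the text-only path, then hand the inner language model `inputs_embeds` only, with `input_ids` set to `None`, so torch.compile traces a single consistent graph for both the text and multimodal paths. The sketch below is illustrative only and is not vLLM code; the class and argument names (`ToyLanguageModel`, `ToyMultiModalModel`, `image_embeds`) are invented for the example.

from typing import Optional

import torch
import torch.nn as nn


class ToyLanguageModel(nn.Module):
    """Inner LM that only ever consumes embeddings, never token ids."""

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # `input_ids` is always None here, so the compiled graph has a single,
        # stable input signature regardless of modality.
        assert input_ids is None and inputs_embeds is not None
        return self.layer(inputs_embeds)


class ToyMultiModalModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.language_model = ToyLanguageModel()

    def forward(
        self,
        input_ids: torch.Tensor,
        image_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if image_embeds is not None:
            # Multimodal path: embeddings come from a vision encoder.
            inputs_embeds = image_embeds
        else:
            # Text-only path: embed the token ids here instead of inside the
            # language model, so both paths look identical downstream.
            inputs_embeds = self.language_model.embed_tokens(input_ids)

        # Always pass the input via `inputs_embeds`; dropping `input_ids`
        # keeps the computation graph consistent for torch.compile.
        input_ids = None
        return self.language_model(input_ids, inputs_embeds=inputs_embeds)


# Usage: both calls exercise the same code path inside ToyLanguageModel.
model = ToyMultiModalModel()
text_out = model(torch.randint(0, 128, (4,)))
mm_out = model(torch.randint(0, 128, (4,)), image_embeds=torch.randn(4, 16))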