Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Clean up MiniCPM-V #6939

Merged
merged 25 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,13 @@ Vision Language Models
-
* - :code:`MiniCPM-V`
- MiniCPM-V
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
- :code:`openbmb/MiniCPM-V-2(Incoming...)`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
HwwwwwwwH marked this conversation as resolved.
Show resolved Hide resolved
-

.. note::
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork(:code:`HwwwH/MiniCPM-V-2`) for now.
HwwwwwwwH marked this conversation as resolved.
Show resolved Hide resolved
For more details, please see: See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
HwwwwwwwH marked this conversation as resolved.
Show resolved Hide resolved

----

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Expand Down
4 changes: 1 addition & 3 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,9 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
input_embeds: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
input_embeds)
attn_metadata, intermediate_tensors)
return model_output

def compute_logits(self, hidden_states: torch.Tensor,
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def forward(
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
Expand Down Expand Up @@ -463,11 +464,10 @@ def forward(
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
input_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, input_embeds)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(self, hidden_states: torch.Tensor,
Expand Down
Loading
Loading