
[Bugfix] Fix PaliGemma MMP #6930

Merged · 1 commit · Jul 30, 2024
7 changes: 2 additions & 5 deletions vllm/model_executor/models/paligemma.py
@@ -9,7 +9,6 @@
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
     def __init__(self, vision_hidden_size: int, projection_dim: int):
         super().__init__()

-        self.linear = ColumnParallelLinear(vision_hidden_size,
-                                           projection_dim,
-                                           bias=True)
+        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
Collaborator

QQ: Why not ReplicatedLinear?

@ywang96 (Member, Author) · Jul 30, 2024

We haven't officially supported quantized VLMs, so we're keeping these layers as nn.Linear for now. When we do, we'll change them for all VLMs in one PR, so it's better to keep them consistent until then.


     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states, _ = self.linear(image_features)
+        hidden_states = self.linear(image_features)
         return hidden_states
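
For reference, the block below is a minimal, self-contained sketch of the projector as it looks after this change, assembled from the diff above. The example dimensions (a 1152-dim vision hidden size and a 2048-dim projection) and the dummy input are illustrative assumptions, not values taken from the PaliGemma config.

```python
import torch
import torch.nn as nn


class PaliGemmaMultiModalProjector(nn.Module):
    """Projects vision-tower features into the language model's hidden size."""

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()
        # Plain nn.Linear instead of ColumnParallelLinear: it returns the
        # output tensor directly rather than an (output, bias) tuple.
        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear(image_features)
        return hidden_states


# Illustrative usage with assumed shapes (batch, num_patches, vision_hidden_size).
projector = PaliGemmaMultiModalProjector(vision_hidden_size=1152,
                                         projection_dim=2048)
image_features = torch.randn(1, 256, 1152)
print(projector(image_features).shape)  # torch.Size([1, 256, 2048])
```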

