Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove intermediate states copying in Mllama #617

Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions vllm/model_executor/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,20 +509,26 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutput]:
encoder_states = ()
encoder_states = torch.empty(
(len(self.output_hidden_states), hidden_states.size(0),
hidden_states.size(1), hidden_states.size(2)),
dtype=hidden_states.dtype,
device=hidden_states.device)
hidden_states_idx = 0

for i, encoder_layer in enumerate(self.layers):
if i in self.output_hidden_states:
encoder_states = encoder_states + (hidden_states, )
encoder_states[hidden_states_idx] = hidden_states
hidden_states_idx += 1
hidden_states = encoder_layer(
hidden_states,
attention_mask,
)

if len(self.layers) - 1 in self.output_hidden_states:
encoder_states = encoder_states + (hidden_states, )
encoder_states[hidden_states_idx] = hidden_states
jkaniecki marked this conversation as resolved.
Show resolved Hide resolved

return hidden_states, encoder_states
return hidden_states, encoder_states.permute(1, 2, 3, 0)


class MllamaVisionModel(nn.Module):
Expand Down Expand Up @@ -658,8 +664,6 @@ def forward(self, pixel_values: torch.Tensor,
attention_mask=attention_mask,
)
hidden_state, intermediate_hidden_states = output[0], output[1]
intermediate_hidden_states = torch.stack(intermediate_hidden_states,
dim=-1)

# apply global encoder
hidden_state = self.layernorm_post(hidden_state)
Expand Down
Loading