[Model] Consolidate ViTs attention implementation without mask (#10893)
Signed-off-by: Isotr0py <[email protected]>
Isotr0py authored Dec 4, 2024
1 parent 01d079f commit 10398b4
Showing 9 changed files with 109 additions and 226 deletions.
63 changes: 63 additions & 0 deletions vllm/attention/layer.py
@@ -3,6 +3,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
@@ -168,6 +169,68 @@ def extra_repr(self) -> str:
return s


class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
):
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype=None,
block_size=16,
is_attention_free=False)
if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
attn_backend = _Backend.XFORMERS

self.attn_backend = attn_backend if attn_backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS
} else _Backend.TORCH_SDPA

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape: batch_size x seq_len x hidden_size"""
# TODO(Isotr0py): Use existing backend implementations and support FA2
bsz, q_len, _ = query.size()
kv_len = key.size(1)

query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(query,
key,
value,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
out = F.scaled_dot_product_attention(query,
key,
value,
scale=self.scale)
out = out.transpose(1, 2)
return out.view(bsz, q_len, -1)


def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
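For reference, a minimal usage sketch of the new MultiHeadAttention layer as the vision modules below call it. The head count, head size, and sequence length are hypothetical, and it assumes a vLLM installation where get_attn_backend resolves to Torch SDPA or xFormers:

import torch
from vllm.attention.layer import MultiHeadAttention

# Hypothetical ViT-style sizes, for illustration only.
bsz, seq_len, num_heads, head_size = 2, 577, 16, 64

attn = MultiHeadAttention(num_heads=num_heads,
                          head_size=head_size,
                          scale=head_size**-0.5)

# Inputs are flat (batch, seq_len, num_heads * head_size); the layer
# splits heads internally and dispatches to xFormers or Torch SDPA.
q = torch.randn(bsz, seq_len, num_heads * head_size)
k = torch.randn(bsz, seq_len, num_heads * head_size)
v = torch.randn(bsz, seq_len, num_heads * head_size)

out = attn(q, k, v)  # -> (bsz, seq_len, num_heads * head_size)

Because there is no KV cache and no attention mask, the layer is a drop-in replacement for the per-model xFormers/SDPA dispatch removed in the files below.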
45 changes: 4 additions & 41 deletions vllm/model_executor/models/blip.py
@@ -4,11 +4,10 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig

from vllm.attention.selector import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -22,8 +21,6 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData

from .utils import get_vit_attn_backend


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
@@ -205,11 +202,8 @@ def __init__(
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"BLIP does not support {self.attn_backend} backend now.")
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
@@ -220,41 +214,10 @@ def forward(
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()

qkv_states, _ = self.qkv(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)

if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)

out = out.view(bsz, tgt_len, -1)
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.projection(out)

return attn_output, None
46 changes: 4 additions & 42 deletions vllm/model_executor/models/clip.py
@@ -5,11 +5,10 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPVisionConfig

from vllm.attention.selector import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -25,8 +24,6 @@
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData

from .utils import get_vit_attn_backend


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
@@ -235,11 +232,8 @@ def __init__(
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"CLIP does not support {self.attn_backend} backend now.")
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
@@ -250,42 +244,10 @@ def forward(
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()

qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)

if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)

out = out.view(bsz, tgt_len, -1)
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out)

return attn_output, None
22 changes: 6 additions & 16 deletions vllm/model_executor/models/glm4_vision_encoder.py
@@ -8,6 +8,7 @@
from torch import nn
from torch.nn import LayerNorm

from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -77,27 +78,16 @@ def __init__(
quant_config=quant_config,
)

self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
self.scale)
self.output_dropout = torch.nn.Dropout(config.dropout_prob)

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, L, _ = x.shape
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
q, k, v = qkv.chunk(3, dim=-1)
q = q.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
k = k.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
v = v.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D

out = torch.nn.functional.scaled_dot_product_attention(q,
k,
v,
attn_mask=None,
dropout_p=0.,
is_causal=False)

output, _ = self.dense(out.transpose(1, 2).view(B, L, -1))

out = self.attn(q, k, v)
output, _ = self.dense(out)
output = self.output_dropout(output)
return output

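The GLM-4 encoder no longer permutes q/k/v to (B, H, L, D) before attention; the flat (B, L, heads * head_dim) chunks from the fused QKV projection are split per head inside MultiHeadAttention. A small shape sketch with hypothetical sizes:

import torch

B, L, H, D = 2, 1024, 16, 112       # hypothetical GLM-4 vision sizes
qkv = torch.randn(B, L, 3 * H * D)  # fused QKV projection output
q, k, v = qkv.chunk(3, dim=-1)      # each (B, L, H * D), passed to self.attn

# MultiHeadAttention.forward splits heads itself before dispatching:
q_heads = q.view(B, L, H, D)
print(q_heads.shape)                # torch.Size([2, 1024, 16, 112])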
25 changes: 4 additions & 21 deletions vllm/model_executor/models/idefics2_vision_model.py
@@ -21,8 +21,8 @@
from torch import nn
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2Config, Idefics2VisionConfig)
from xformers import ops as xops

from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -141,35 +141,18 @@ def __init__(
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.is_causal = False
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)

def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv, _ = self.qkv_proj(
hidden_states
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
query_states = query_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
# see: https://facebookresearch.github.io/xformers/components/ops.html
out = xops.memory_efficient_attention_forward(
query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale,
)
out = out.view(batch_size, q_len, -1)
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out)
return attn_output

28 changes: 4 additions & 24 deletions vllm/model_executor/models/intern_vit.py
@@ -12,7 +12,7 @@
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.attention.selector import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
@@ -25,8 +25,6 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from .utils import get_vit_attn_backend

NORM2FN = {
'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm,
@@ -183,10 +181,8 @@ def __init__(
prefix=f"{prefix}.proj",
)

self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"InternViT does not support {self.attn_backend} backend now.")
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)

def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1:
@@ -209,23 +205,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)

q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)

if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(q,
k,
v,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(1, 2)

out = out.view(B, N, -1)
out = self.attn(q, k, v)
out, _ = self.proj(out)
return out

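Both dispatch targets of the consolidated layer compute the same unmasked, non-causal attention. A self-contained sanity sketch (random data, CPU/float32, hypothetical sizes) comparing the Torch SDPA path against a plain softmax reference; the xFormers path would additionally require the xformers package:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, H, D = 2, 64, 4, 32           # hypothetical sizes
scale = D ** -0.5
q, k, v = (torch.randn(B, S, H, D) for _ in range(3))

# Torch SDPA path used by MultiHeadAttention: heads-first (B, H, S, D) layout.
sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    scale=scale).transpose(1, 2)

# Plain reference: softmax(q @ k^T * scale) @ v, per head, no mask.
scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
ref = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)

assert torch.allclose(sdpa, ref, atol=1e-5)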