Support BERTModel (first encoder-only embedding model) (vllm-project#9056)

Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Co-authored-by: Andrew Feldman <[email protected]>
Co-authored-by: afeldman-nm <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: laishzh <[email protected]>
Co-authored-by: Max de Bayser <[email protected]>
Co-authored-by: Max de Bayser <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
8 people authored Oct 17, 2024
1 parent bb76538 commit 343f8e0
Showing 6 changed files with 497 additions and 15 deletions.
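
For orientation, here is a minimal usage sketch of what this change enables. It is not taken from the commit itself: it assumes vLLM's offline LLM.encode() embedding API and the XFORMERS backend selection that the test below applies via VLLM_ATTENTION_BACKEND; exact flags and output fields may differ across vLLM versions.

import os

# Encoder-only support in this change is wired up for the XFORMERS backend,
# so select it explicitly (mirrors the monkeypatch in the test below).
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

from vllm import LLM

# BAAI/bge-base-en-v1.5 is the encoder-only (BERT) model exercised by the tests.
llm = LLM(model="BAAI/bge-base-en-v1.5", dtype="half")
outputs = llm.encode(["A robot dreams for the first time."])
print(len(outputs[0].outputs.embedding))  # 768 for a BERT-base sized encoder
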
14 changes: 12 additions & 2 deletions tests/models/embedding/language/test_embedding.py
@@ -6,21 +6,31 @@

from ..utils import check_embeddings_close

# Model, Guard
MODELS = [
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-multilingual-gemma2",
]

ENCODER_ONLY = [
"BAAI/bge-base-en-v1.5",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
monkeypatch,
hf_runner,
vllm_runner,
example_prompts,
model: str,
model,
dtype: str,
) -> None:
if model in ENCODER_ONLY:
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")

# Each prompt in example_prompts ends with "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
@@ -33,7 +43,7 @@ def test_models(
is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)

with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype, max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)

check_embeddings_close(
7 changes: 5 additions & 2 deletions vllm/attention/backends/abstract.py
@@ -15,8 +15,11 @@

class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
ENCODER = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V
ENCODER = auto(
) # Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto(
) # Attention between dec. Q and enc. K/V for encoder-decoder


class AttentionBackend(ABC):
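
As a rough sketch of how a model is expected to consume the new ENCODER_ONLY member (hypothetical layer code, not part of this diff; it assumes the attn_type keyword on Attention.forward that the existing encoder-decoder path already uses):

import torch
from torch import nn
from vllm.attention import Attention, AttentionType


class ToyBertSelfAttention(nn.Module):
    """Hypothetical encoder-only self-attention block."""

    def __init__(self, hidden_size: int, num_heads: int, cache_config=None):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.attn = Attention(num_heads,
                              self.head_dim,
                              scale=self.head_dim**-0.5,
                              cache_config=cache_config)

    def forward(self, hidden_states: torch.Tensor, kv_cache, attn_metadata):
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        # ENCODER_ONLY: bidirectional (non-causal) mask, prefill-only,
        # and no KV-cache writes in the backend.
        return self.attn(q, k, v, kv_cache, attn_metadata,
                         attn_type=AttentionType.ENCODER_ONLY)
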
59 changes: 50 additions & 9 deletions vllm/attention/backends/xformers.py
@@ -287,13 +287,15 @@ def _get_attn_bias(
* Appropriate attention bias value given the attention type
'''

if attn_type == AttentionType.DECODER:
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
return attn_metadata.attn_bias
elif attn_type == AttentionType.ENCODER:
return attn_metadata.encoder_attn_bias
else:
# attn_type == AttentionType.ENCODER_DECODER
elif attn_type == AttentionType.ENCODER_DECODER:
return attn_metadata.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _set_attn_bias(
@@ -313,7 +315,8 @@ def _set_attn_bias(
encoder/decoder cross-attention
'''

if attn_type == AttentionType.DECODER:
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
attn_metadata.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
attn_metadata.encoder_attn_bias = attn_bias
@@ -371,6 +374,12 @@ def _get_seq_len_block_table_args(
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
elif attn_type == AttentionType.ENCODER_ONLY:
assert is_prompt, "Should not have decode for encoder only model."

# No block tables associated with encoder attention
return (attn_metadata.seq_lens_tensor,
attn_metadata.max_prefill_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")

@@ -479,7 +488,10 @@ def forward(
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
Used for encoder branch of encoder-decoder models.
* ENCODER_ONLY: no KV caching; uses the normal attention
attributes (seq_lens/seq_lens_tensor/max_seq_len).
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
@@ -509,6 +521,7 @@ def forward(
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")

elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
@@ -609,6 +622,8 @@ def forward(
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out
else:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have prefix attention.")

assert prefill_meta.query_start_loc is not None
assert prefill_meta.max_query_len is not None
@@ -638,6 +653,8 @@ def forward(
output[:num_prefill_tokens] = out

if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")

(
seq_lens_arg,
@@ -703,36 +720,60 @@ def _run_memory_efficient_xformers_forward(
None, :].expand(value.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])

# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
attn_bias = _get_attn_bias(attn_metadata, attn_type)
if attn_bias is None:
if self.alibi_slopes is None:

# Cross attention block of decoder branch of encoder-decoder
# model uses seq_lens for dec / encoder_seq_lens for enc
if (attn_type == AttentionType.ENCODER_DECODER):
assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens is not None

# Default enc/dec cross-attention mask is non-causal
# Cross-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)

# Encoder branch of encoder-decoder model uses
# attn_metadata.encoder_seq_lens
elif attn_type == AttentionType.ENCODER:

assert attn_metadata.encoder_seq_lens is not None

# Default encoder self-attention mask is non-causal
# Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.encoder_seq_lens)
else:

# Self-attention block of encoder-only model just
# uses the seq_lens directly.
elif attn_type == AttentionType.ENCODER_ONLY:
assert attn_metadata.seq_lens is not None

# Default decoder self-attention mask is causal
# Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens)

# Self-attention block of decoder branch just
# uses the seq_lens directly
elif attn_type == AttentionType.DECODER:
assert attn_metadata.seq_lens is not None

# Decoder self-attention mask is causal
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens)
else:
raise ValueError(f"Unknown AttentionType: {attn_type}")

if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
attn_bias = [attn_bias]
else:
assert attn_type == AttentionType.DECODER
assert attn_metadata.seq_lens is not None
attn_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads, query.dtype,
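
To make the mask distinction above concrete, here is a small standalone PyTorch illustration (not vLLM code) of the additive biases the branches build for seq_lens = [2, 3]: a block-diagonal non-causal mask for ENCODER/ENCODER_ONLY versus a block-diagonal causal mask for DECODER (0.0 = attend, -inf = masked).

import torch


def block_diagonal_mask(seq_lens, causal):
    """Additive bias equivalent to xformers' BlockDiagonalMask (causal=False)
    or BlockDiagonalCausalMask (causal=True) for sequences packed into one
    flat token dimension."""
    total = sum(seq_lens)
    bias = torch.full((total, total), float("-inf"))
    start = 0
    for n in seq_lens:
        block = torch.zeros(n, n)
        if causal:
            # Token i may only attend to tokens <= i within its own sequence.
            block = block.masked_fill(
                torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1),
                float("-inf"))
        bias[start:start + n, start:start + n] = block
        start += n
    return bias


seq_lens = [2, 3]
print(block_diagonal_mask(seq_lens, causal=False))  # ENCODER / ENCODER_ONLY
print(block_diagonal_mask(seq_lens, causal=True))   # DECODER
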
12 changes: 10 additions & 2 deletions vllm/model_executor/layers/pooler.py
@@ -12,6 +12,7 @@ class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST = 0
ALL = 1
CLS = 2


class Pooler(nn.Module):
@@ -23,12 +24,13 @@ class Pooler(nn.Module):
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
pooling_type: The type of pooling to use (LAST, ALL, CLS).
normalize: Whether to normalize the pooled data.
"""

def __init__(self, pooling_type: PoolingType, normalize: bool):
super().__init__()

self.pooling_type = pooling_type
self.normalize = normalize

@@ -38,10 +40,16 @@ def forward(
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools specific information from hidden states based on metadata."""

prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens

if self.pooling_type == PoolingType.LAST:
if self.pooling_type == PoolingType.CLS:
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
dim=0)[:-1]
pooled_data = hidden_states[first_token_flat_indices]
elif self.pooling_type == PoolingType.LAST:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
pooled_data = hidden_states[last_token_flat_indices]
elif self.pooling_type == PoolingType.ALL:
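
Finally, a tiny standalone example (not from the diff) of the index arithmetic in the new PoolingType.CLS branch: for flattened hidden states with prompt_lens = [3, 2], the first-token (CLS) indices come out as [0, 3].

import torch

# Flattened hidden states for two prompts of lengths 3 and 2 (hidden size 4).
prompt_lens = torch.tensor([3, 2])
hidden_states = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 4)

# Same index arithmetic as the PoolingType.CLS branch above.
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]

print(first_token_flat_indices)                 # tensor([0, 3])
print(hidden_states[first_token_flat_indices])  # rows 0 and 3 -> CLS embeddings
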