From a8ca5cb7e8a5cf859b603a080c62620e8e3163e8 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Fri, 6 Sep 2024 12:51:47 -0700
Subject: [PATCH] Support FlashInfer for BERT (#597)

---
 .../custom_modeling/flash_bert_modeling.py  |  4 +-
 .../models/custom_modeling/siglip.py        |  2 +-
 server/lorax_server/models/flash_bert.py    | 62 ++++++++++++++-----
 server/lorax_server/models/vlm_causal_lm.py |  2 +-
 server/lorax_server/utils/flash_attn.py     |  1 +
 5 files changed, 52 insertions(+), 19 deletions(-)

diff --git a/server/lorax_server/models/custom_modeling/flash_bert_modeling.py b/server/lorax_server/models/custom_modeling/flash_bert_modeling.py
index 3f5696f83..17595f96b 100644
--- a/server/lorax_server/models/custom_modeling/flash_bert_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_bert_modeling.py
@@ -58,7 +58,7 @@ def forward(self, hidden_states, cu_seqlens, max_s):
         qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight)
         q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(self.num_heads, dim=1)
 
-        attn_output = attention(q, k, v, None, None, cu_seqlens, max_s, self.softmax_scale)
+        attn_output = attention(q, k, v, None, None, cu_seqlens, max_s, self.softmax_scale, causal=False)
 
         hidden_states = torch.addmm(
             self.dense_bias,
@@ -162,7 +162,7 @@ def forward(self, hidden_states, cu_seqlens, max_s):
         qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight)
         q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(self.num_heads, dim=1)
 
-        attn_output = attention(q, k, v, None, None, cu_seqlens, max_s, self.softmax_scale)
+        attn_output = attention(q, k, v, None, None, cu_seqlens, max_s, self.softmax_scale, causal=False)
 
         hidden_states = torch.addmm(
             self.dense_bias,
diff --git a/server/lorax_server/models/custom_modeling/siglip.py b/server/lorax_server/models/custom_modeling/siglip.py
index 00c81e8f6..02f1309e4 100644
--- a/server/lorax_server/models/custom_modeling/siglip.py
+++ b/server/lorax_server/models/custom_modeling/siglip.py
@@ -50,7 +50,7 @@ def __init__(self, prefix, config: SiglipVisionConfig, weights):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         # TODO(travis): check why this is necessary
         pixel_values = pixel_values.to(self.patch_embedding.weight.dtype)
-        
+
         patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py
index ab8548615..8c92085b6 100644
--- a/server/lorax_server/models/flash_bert.py
+++ b/server/lorax_server/models/flash_bert.py
@@ -1,4 +1,5 @@
-from typing import Optional, Type
+from contextlib import nullcontext
+from typing import Any, ContextManager, Optional, Type
 
 import torch
 from opentelemetry import trace
@@ -14,6 +15,7 @@
     initialize_torch_distributed,
     weight_files,
 )
+from lorax_server.utils.state import FLASH_INFER
 
 tracer = trace.get_tracer(__name__)
 
@@ -109,6 +111,15 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.config = config
 
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_attention_heads
+        self.head_size = config.hidden_size // config.num_attention_heads
+
+        if FLASH_INFER:
+            from lorax_server.utils.flashinfer_attention import create_prefill_state
+
+            self.prefill_state = create_prefill_state(device=device)
+
         super(FlashBert, self).__init__(
             model_id=model_id,
             model=model,
@@ -147,18 +158,38 @@ def generate_token(self, batch: FlashEmbeddingClassificationBatch) -> None:
         raise NotImplementedError("This model does not support text generation")
         return None
 
+    def _forward_context(
+        self,
+        *,
+        cu_seqlens: torch.Tensor,
+        state: Optional[Any] = None,
+    ) -> ContextManager:
+        if not FLASH_INFER:
+            return nullcontext()
+
+        from lorax_server.utils.flashinfer_attention import use_prefill_state
+
+        return use_prefill_state(
+            state=(state if state is not None else self.prefill_state),
+            cu_seqlens=cu_seqlens,
+            num_heads=self.num_heads,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+        )
+
     def forward(self, batch: FlashEmbeddingClassificationBatch):
         return self.embed(batch)
 
     @tracer.start_as_current_span("embed")
     def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
-        embedding: torch.Tensor = self.model.forward(
-            input_ids=batch.input_ids,
-            token_type_ids=batch.token_type_ids,
-            position_ids=batch.position_ids,
-            cu_seqlens=batch.cu_seqlens,
-            max_s=batch.max_s,
-        )
+        with self._forward_context(cu_seqlens=batch.cu_seqlens):
+            embedding: torch.Tensor = self.model.forward(
+                input_ids=batch.input_ids,
+                token_type_ids=batch.token_type_ids,
+                position_ids=batch.position_ids,
+                cu_seqlens=batch.cu_seqlens,
+                max_s=batch.max_s,
+            )
         embedding = embedding.reshape(embedding.shape[0], -1)[:, : self.hidden_size]
 
         cpu_results = embedding.cpu().tolist()
@@ -166,13 +197,14 @@ def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
 
     @tracer.start_as_current_span("classify")
     def classify(self, batch: FlashEmbeddingClassificationBatch):
-        logits: torch.Tensor = self.model.forward(
-            input_ids=batch.input_ids,
-            token_type_ids=batch.token_type_ids,
-            position_ids=batch.position_ids,
-            cu_seqlens=batch.cu_seqlens,
-            max_s=batch.max_s,
-        )
+        with self._forward_context(cu_seqlens=batch.cu_seqlens):
+            logits: torch.Tensor = self.model.forward(
+                input_ids=batch.input_ids,
+                token_type_ids=batch.token_type_ids,
+                position_ids=batch.position_ids,
+                cu_seqlens=batch.cu_seqlens,
+                max_s=batch.max_s,
+            )
         probabilities = torch.nn.functional.softmax(logits, dim=2)
         confidence_scores, predictions = torch.max(probabilities, dim=2)
         predicted_token_class = [[self.config.id2label[t.item()] for t in prediction] for prediction in predictions]
diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py
index dbeacd964..5c529ac92 100644
--- a/server/lorax_server/models/vlm_causal_lm.py
+++ b/server/lorax_server/models/vlm_causal_lm.py
@@ -269,7 +269,7 @@ def __init__(
             raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
-        
+
         processor = processor_class.from_pretrained(
             model_id,
             revision=revision,
diff --git a/server/lorax_server/utils/flash_attn.py b/server/lorax_server/utils/flash_attn.py
index e34999a83..ea1eeac5e 100644
--- a/server/lorax_server/utils/flash_attn.py
+++ b/server/lorax_server/utils/flash_attn.py
@@ -137,6 +137,7 @@ def attention(
             k,
             v,
             causal=causal,
+            pos_encoding_mode="NONE",
            window_left=window_size_left,
            logits_soft_cap=softcap,
            sm_scale=softmax_scale,
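
Note: the sketch below is illustrative only and is not part of the patch. It shows, with stand-in names (FLASH_INFER, use_prefill_state, and forward_context here are assumptions, not the lorax_server implementation), the fallback shape that the new _forward_context helper in flash_bert.py follows: hand back a FlashInfer prefill-state context when the backend is enabled, otherwise a no-op nullcontext so embed() and classify() can wrap the model forward pass unconditionally.

# Illustrative sketch only; stand-in names, not the lorax_server implementation.
from contextlib import contextmanager, nullcontext
from typing import Any, ContextManager, List, Optional

FLASH_INFER = False  # assumption: in lorax_server this flag comes from the serving backend


@contextmanager
def use_prefill_state(state: Any, cu_seqlens: List[int], num_heads: int, num_kv_heads: int, head_size: int):
    # Stand-in for lorax_server.utils.flashinfer_attention.use_prefill_state:
    # a real implementation would plan the FlashInfer prefill kernel for the
    # ragged batch described by cu_seqlens before yielding.
    yield state


def forward_context(cu_seqlens: List[int], prefill_state: Optional[Any] = None) -> ContextManager:
    # No-op context when FlashInfer is disabled, so callers can always write
    # `with forward_context(...):` around the model forward pass.
    if not FLASH_INFER:
        return nullcontext()
    return use_prefill_state(
        state=prefill_state,
        cu_seqlens=cu_seqlens,
        num_heads=12,
        num_kv_heads=12,
        head_size=64,
    )


# Usage mirrors embed()/classify() in the patch: wrap the ragged-batch forward.
with forward_context(cu_seqlens=[0, 3, 7]):
    pass  # model.forward(..., cu_seqlens=..., max_s=...)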