diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py
index 11f68587d..001cd584a 100644
--- a/server/lorax_server/models/flash_bert.py
+++ b/server/lorax_server/models/flash_bert.py
@@ -1,10 +1,8 @@
 from lorax_server.models.types import GeneratedText
 import torch
-from pathlib import Path
 from torch import nn
 from typing import Type, List, Optional
-from safetensors import safe_open
 from transformers.activations import ACT2FN
 from transformers.models.bert import BertConfig
 from transformers import AutoTokenizer
@@ -40,39 +38,30 @@ class FlashBatch(ABC):
     def __len__(self):
         return self.size
 
+    def from_pb(self, *args, **kwargs):
+        return None
 
 
 class BertEmbeddings:
     def __init__(self, prefix, weights, device, dtype, config: BertConfig):
-        self.word_embeddings_weight = (
-            weights.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device)
-        )
+        self.word_embeddings_weight = weights.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device)
         self.token_type_embeddings_weight = (
-            weights.get_tensor(f"{prefix}.token_type_embeddings.weight") .to(dtype) .to(device))
+            weights.get_tensor(f"{prefix}.token_type_embeddings.weight").to(dtype).to(device)
+        )
 
         if config.position_embedding_type == "absolute":
             self.position_embeddings_weight = (
-                weights.get_tensor(f"{prefix}.position_embeddings.weight")
-                .to(dtype)
-                .to(device)
+                weights.get_tensor(f"{prefix}.position_embeddings.weight").to(dtype).to(device)
             )
         else:
-            raise NotImplementedError(
-                "FlashBert only supports absolute position embeddings"
-            )
+            raise NotImplementedError("FlashBert only supports absolute position embeddings")
 
-        self.layer_norm = FastLayerNorm.load(
-            prefix=f"{prefix}.LayerNorm", weights=weights, eps=config.layer_norm_eps
-        )
+        self.layer_norm = FastLayerNorm.load(prefix=f"{prefix}.LayerNorm", weights=weights, eps=config.layer_norm_eps)
 
     def forward(self, input_ids, token_type_ids, position_ids):
         inputs_embeds = nn.functional.embedding(input_ids, self.word_embeddings_weight)
-        token_type_embeds = nn.functional.embedding(
-            token_type_ids, self.token_type_embeddings_weight
-        )
-        position_embeds = nn.functional.embedding(
-            position_ids, self.position_embeddings_weight
-        )
+        token_type_embeds = nn.functional.embedding(token_type_ids, self.token_type_embeddings_weight)
+        position_embeds = nn.functional.embedding(position_ids, self.position_embeddings_weight)
 
         inputs_embeds += position_embeds
 
@@ -89,24 +78,14 @@ def __init__(self, prefix, weights, device, dtype, config: BertConfig):
         value_weight = weights.get_tensor(f"{prefix}.self.value.weight")
         value_bias = weights.get_tensor(f"{prefix}.self.value.bias")
 
-        self.qkv_weight = (
-            torch.cat([query_weight, key_weight, value_weight]).T.to(dtype).to(device)
-        )
-        self.qkv_bias = (
-            torch.cat([query_bias, key_bias, value_bias]).to(dtype).to(device)
-        )
+        self.qkv_weight = torch.cat([query_weight, key_weight, value_weight]).T.to(dtype).to(device)
+        self.qkv_bias = torch.cat([query_bias, key_bias, value_bias]).to(dtype).to(device)
 
-        self.dense_weight = (
-            weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
-        )
-        self.dense_bias = (
-            weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
-        )
+        self.dense_weight = weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
+        self.dense_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
 
         self.layer_norm = FastLayerNorm.load(
-            prefix=f"{prefix}.output.LayerNorm",
-            weights=weights,
-            eps=config.layer_norm_eps
+            prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps
         )
 
         self.head_size = config.hidden_size // config.num_attention_heads
@@ -117,9 +96,7 @@ def forward(self, hidden_states, cu_seqlens, max_s):
         residual = hidden_states
 
         qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight)
-        q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
-            self.num_heads, dim=1
-        )
+        q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(self.num_heads, dim=1)
 
         attn_output = torch.empty_like(q)
         attention(q, k, v, attn_output, cu_seqlens, max_s, self.softmax_scale)
@@ -136,18 +113,10 @@ def forward(self, hidden_states, cu_seqlens, max_s):
 
 class BertLayer:
     def __init__(self, prefix, weights, device, dtype, config: BertConfig):
-        self.attention = BertAttention(
-            f"{prefix}.attention", weights, device, dtype, config
-        )
+        self.attention = BertAttention(f"{prefix}.attention", weights, device, dtype, config)
 
-        self.intermediate_weight = (
-            weights.get_tensor(f"{prefix}.intermediate.dense.weight")
-            .T.to(dtype)
-            .to(device)
-        )
-        self.intermediate_bias = (
-            weights.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device)
-        )
+        self.intermediate_weight = weights.get_tensor(f"{prefix}.intermediate.dense.weight").T.to(dtype).to(device)
+        self.intermediate_bias = weights.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device)
 
         act = config.hidden_act
         self.intermediate_act_fn = (
@@ -155,31 +124,21 @@ def __init__(self, prefix, weights, device, dtype, config: BertConfig):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none",
             )
         )
 
-        self.output_weight = (
-            weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
-        )
-        self.output_bias = (
-            weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
-        )
+        self.output_weight = weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
+        self.output_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
         self.layer_norm = FastLayerNorm.load(
-            prefix=f"{prefix}.output.LayerNorm",
-            weights=weights,
-            eps=config.layer_norm_eps
+            prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps
         )
 
     def forward(self, hidden_states, cu_seqlens, max_s):
         hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s)
         residual = hidden_states
 
-        hidden_states = torch.addmm(
-            self.intermediate_bias, hidden_states, self.intermediate_weight
-        )
+        hidden_states = torch.addmm(self.intermediate_bias, hidden_states, self.intermediate_weight)
         hidden_states = self.intermediate_act_fn(hidden_states)
         hidden_states = torch.addmm(
             self.output_bias,
@@ -193,8 +152,7 @@ def forward(self, hidden_states, cu_seqlens, max_s):
 class BertEncoder:
     def __init__(self, prefix, weights, device, dtype, config: BertConfig):
         self.layers = [
-            BertLayer(f"{prefix}.layer.{i}", weights, device, dtype, config)
-            for i in range(config.num_hidden_layers)
+            BertLayer(f"{prefix}.layer.{i}", weights, device, dtype, config) for i in range(config.num_hidden_layers)
         ]
 
     def forward(self, hidden_states, cu_seqlens, max_s):
@@ -218,11 +176,11 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
 
 class FlashBert(Model):
     def __init__(
-            self,
-            model_id: str,
-            revision: Optional[str] = None,
-            dtype: Optional[torch.dtype] = None,
-    ):
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -239,10 +197,10 @@ def __init__(
         config = BertConfig.from_pretrained(model_id)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-                filenames,
-                device,
-                dtype,
-                process_group=self.process_group,
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
         )
         model = FlashBertModel(weights, device, dtype, config)
 
@@ -262,14 +220,17 @@ def __init__(
 
     @property
     def batch_type(self) -> Type[FlashBatch]:
         return FlashBatch
-
+
+    def warmup(self, batch: FlashBatch, max_new_tokens: int) -> int | None:
+        # return super().warmup(batch, max_new_tokens)
+        return 42  # lol
+
     def generate_token(self, batch: FlashBatch) -> torch.Tuple[List[GeneratedText] | FlashBatch | None]:
         return None
-
+
     def forward(self, batch: FlashBatch):
         return self.embed(batch)
 
-    @tracer.start_as_current_span("embed")
     def embed(self, batch: FlashBatch) -> Embedding:
         embedding = self.model.forward(
@@ -281,19 +242,17 @@ def __init__(
         )
         cpu_results = embedding.view(-1).tolist()
 
-        return Embedding(
-            values=cpu_results[:self.hidden_size]
-        )
-
+        return Embedding(values=cpu_results[: self.hidden_size])
+
     def tokenize_to_batch(self, inputs) -> FlashBatch:
         tokens = self.tokenizer(inputs, return_token_type_ids=True)
         num_tokens = len(tokens["input_ids"])
         position_ids = range(num_tokens)
         return FlashBatch(
-            input_ids=torch.tensor(tokens['input_ids'], dtype=torch.int32, device=self.device),
-            token_type_ids=torch.tensor(tokens['token_type_ids'], dtype=torch.int32, device=self.device),
+            input_ids=torch.tensor(tokens["input_ids"], dtype=torch.int32, device=self.device),
+            token_type_ids=torch.tensor(tokens["token_type_ids"], dtype=torch.int32, device=self.device),
             position_ids=torch.tensor(position_ids, dtype=torch.int32, device=self.device),
             cu_seqlens=torch.tensor([0, num_tokens], dtype=torch.int32, device=self.device),
             max_s=num_tokens,
             size=1,
-        )
\ No newline at end of file
+        )
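
For reviewers, a minimal usage sketch of the embedding path this diff introduces; it is not part of the patch, and the checkpoint id and input string are placeholder assumptions. It only exercises names the diff defines: FlashBert.__init__(model_id, ...), tokenize_to_batch(), and embed().

# Sketch only: assumes the module path shown in the diff header and a BERT-style safetensors checkpoint.
from lorax_server.models.flash_bert import FlashBert

model = FlashBert(model_id="bert-base-uncased")  # placeholder model id
batch = model.tokenize_to_batch("hello world")   # single sequence -> FlashBatch with size=1
embedding = model.embed(batch)                   # Embedding(values=...) truncated to hidden_size
print(len(embedding.values))                     # 768 for a bert-base sized config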