diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py
index 26955415d..58231ba23 100644
--- a/server/lorax_server/models/__init__.py
+++ b/server/lorax_server/models/__init__.py
@@ -105,6 +105,8 @@ def get_model(
     if model_type == "bert":
         from lorax_server.models.flash_bert import FlashBert
 
+        if config_dict["architectures"][0] == "BertForTokenClassification":
+            return FlashBert(model_id, revision=revision, dtype=dtype, classification_head=True)
         return FlashBert(model_id, revision=revision, dtype=dtype)
 
     if model_type == "distilbert":
diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py
index 457be72e3..d18e344f7 100644
--- a/server/lorax_server/models/flash_bert.py
+++ b/server/lorax_server/models/flash_bert.py
@@ -18,6 +18,12 @@
 tracer = trace.get_tracer(__name__)
 
 
+def _format_prefix(prefix, name):
+    if prefix is None:
+        return name
+    return f"{prefix}.{name}"
+
+
 class BertEncoder:
     def __init__(self, prefix, weights, device, dtype, config: BertConfig):
         self.layers = [
@@ -31,10 +37,10 @@ def forward(self, hidden_states, cu_seqlens, max_s):
 
 
 class FlashBertModel(torch.nn.Module):
-    def __init__(self, weights, device, dtype, config: BertConfig):
+    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
         super().__init__()
-        self.embeddings = BertEmbeddings("embeddings", weights, device, dtype, config)
-        self.encoder = BertEncoder("encoder", weights, device, dtype, config)
+        self.embeddings = BertEmbeddings(_format_prefix(prefix, "embeddings"), weights, device, dtype, config)
+        self.encoder = BertEncoder(_format_prefix(prefix, "encoder"), weights, device, dtype, config)
 
     def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
         embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
@@ -43,12 +49,30 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
         return encoder_outputs[cu_seqlens[:-1]]
 
 
+class FlashBertModelForClassification(torch.nn.Module):
+    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
+        super().__init__()
+        self.embeddings = BertEmbeddings(_format_prefix(prefix, "embeddings"), weights, device, dtype, config)
+        self.encoder = BertEncoder(_format_prefix(prefix, "encoder"), weights, device, dtype, config)
+        self.classifier_weight = weights.get_tensor("classifier.weight").to(dtype).to(device)
+        self.classifier_bias = weights.get_tensor("classifier.bias").to(dtype).to(device)
+
+    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
+        embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
+        encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s)
+        batch_size = encoder_outputs.shape[0] // max_s
+        encoder_outputs = encoder_outputs.reshape(batch_size, max_s, -1)
+        logits = torch.nn.functional.linear(encoder_outputs, self.classifier_weight, self.classifier_bias)
+        return logits
+
+
 class FlashBert(Model):
     def __init__(
         self,
         model_id: str,
         revision: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        classification_head: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -71,9 +95,15 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-        model = FlashBertModel(weights, device, dtype, config)
+        prefix = None if (model_id == "WhereIsAI/UAE-Large-V1") else "bert"
+        if classification_head:
+            model = FlashBertModelForClassification(prefix, weights, device, dtype, config)
+        else:
+            model = FlashBertModel(prefix, weights, device, dtype, config)
+        self.classification_head_enabled = classification_head
 
         self.hidden_size = config.hidden_size
+        self.config = config
 
         super(FlashBert, self).__init__(
             model_id=model_id,
@@ -98,6 +128,10 @@ def supports_embeddings(self) -> bool:
     def supports_text_generation(self) -> bool:
         return False
 
+    @property
+    def supports_classification(self) -> bool:
+        return self.classification_head_enabled
+
     def warmup(self, batch: FlashEmbeddingClassificationBatch, max_new_tokens: int) -> int | None:
         # Note: This is meant to 1) preallocate the memory by doing a forward pass
         # and then just returning the max seqlen since for embeddings we are never generating
@@ -125,3 +159,17 @@ def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
         cpu_results = embedding.cpu().tolist()
 
         return cpu_results
+
+    @tracer.start_as_current_span("classify")
+    def classify(self, batch: FlashEmbeddingClassificationBatch):
+        logits: torch.Tensor = self.model.forward(
+            input_ids=batch.input_ids,
+            token_type_ids=batch.token_type_ids,
+            position_ids=batch.position_ids,
+            cu_seqlens=batch.cu_seqlens,
+            max_s=batch.max_s,
+        )
+        probabilities = torch.nn.functional.softmax(logits, dim=2)
+        confidence_scores, predictions = torch.max(probabilities, dim=2)
+        predicted_token_class = [[self.config.id2label[t.item()] for t in prediction] for prediction in predictions]
+        return predicted_token_class, confidence_scores.cpu().tolist()
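
For reference, a minimal standalone sketch (not part of the patch) of what FlashBertModelForClassification.forward does with the packed encoder output: the flash encoder emits a (batch_size * max_s, hidden_size) tensor, which is reshaped back to (batch_size, max_s, hidden_size) before the token-level linear head is applied. All sizes below are made up for illustration.

import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, max_s, hidden_size, num_labels = 2, 8, 768, 9

# Stand-ins for the encoder output and the loaded classifier weights.
encoder_outputs = torch.randn(batch_size * max_s, hidden_size)
classifier_weight = torch.randn(num_labels, hidden_size)
classifier_bias = torch.randn(num_labels)

# Unpack the flattened tokens back into per-sequence rows, then apply
# the linear head at every token position.
encoder_outputs = encoder_outputs.reshape(batch_size, max_s, -1)
logits = torch.nn.functional.linear(encoder_outputs, classifier_weight, classifier_bias)
assert logits.shape == (batch_size, max_s, num_labels)

Note that the reshape assumes every sequence in the batch occupies exactly max_s positions; shorter sequences would need to be padded for the batch_size = encoder_outputs.shape[0] // max_s arithmetic to hold.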
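
Similarly, a minimal sketch of the post-processing in classify(): softmax over the label dimension, torch.max for the per-token confidence and label id, then a lookup through config.id2label. The id2label mapping below is hypothetical; the real one comes from the model config.

import torch

# Hypothetical label map; in the server it is self.config.id2label.
id2label = {0: "O", 1: "B-PER", 2: "I-PER"}

logits = torch.randn(2, 8, len(id2label))  # (batch, seq, num_labels)

# Normalize over the label dimension, then take the most likely label
# per token along with its probability as a confidence score.
probabilities = torch.nn.functional.softmax(logits, dim=2)
confidence_scores, predictions = torch.max(probabilities, dim=2)
predicted_token_class = [[id2label[t.item()] for t in prediction] for prediction in predictions]

print(predicted_token_class[0])             # e.g. ["B-PER", "O", ...]
print(confidence_scores.cpu().tolist()[0])  # per-token confidences in (0, 1]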