add support for classification in bert (#573)
magdyksaleh authored Aug 14, 2024
1 parent 2e5d97c commit a3f91a6
Showing 2 changed files with 54 additions and 4 deletions.
2 changes: 2 additions & 0 deletions server/lorax_server/models/__init__.py
@@ -105,6 +105,8 @@ def get_model(
     if model_type == "bert":
         from lorax_server.models.flash_bert import FlashBert
 
+        if config_dict["architectures"][0] == "BertForTokenClassification":
+            return FlashBert(model_id, revision=revision, dtype=dtype, classification_head=True)
         return FlashBert(model_id, revision=revision, dtype=dtype)
 
     if model_type == "distilbert":
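For context, get_model keys this dispatch off the architectures field of the checkpoint's Hugging Face config. A minimal sketch of the same check done standalone with transformers.AutoConfig (the model ids below are illustrative public checkpoints, not anything this commit references):

# Sketch: distinguishing a token-classification checkpoint from a plain
# encoder via the config's architectures list, as get_model does above.
# Assumes transformers is installed; the model ids are illustrative only.
from transformers import AutoConfig

for model_id in ("bert-base-uncased", "dslim/bert-base-NER"):
    config = AutoConfig.from_pretrained(model_id)
    # Mirrors config_dict["architectures"][0] in get_model
    wants_classification_head = config.architectures[0] == "BertForTokenClassification"
    print(model_id, config.architectures, "->", wants_classification_head)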
56 changes: 52 additions & 4 deletions server/lorax_server/models/flash_bert.py
@@ -18,6 +18,12 @@
 tracer = trace.get_tracer(__name__)
 
 
+def _format_prefix(prefix, name):
+    if prefix is None:
+        return name
+    return f"{prefix}.{name}"
+
+
 class BertEncoder:
     def __init__(self, prefix, weights, device, dtype, config: BertConfig):
         self.layers = [
@@ -31,10 +37,10 @@ def forward(self, hidden_states, cu_seqlens, max_s):
 
 
 class FlashBertModel(torch.nn.Module):
-    def __init__(self, weights, device, dtype, config: BertConfig):
+    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
         super().__init__()
-        self.embeddings = BertEmbeddings("embeddings", weights, device, dtype, config)
-        self.encoder = BertEncoder("encoder", weights, device, dtype, config)
+        self.embeddings = BertEmbeddings(_format_prefix(prefix, "embeddings"), weights, device, dtype, config)
+        self.encoder = BertEncoder(_format_prefix(prefix, "encoder"), weights, device, dtype, config)
 
     def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
         embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
@@ -43,12 +49,30 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
         return encoder_outputs[cu_seqlens[:-1]]
 
 
+class FlashBertModelForClassification(torch.nn.Module):
+    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
+        super().__init__()
+        self.embeddings = BertEmbeddings(_format_prefix(prefix, "embeddings"), weights, device, dtype, config)
+        self.encoder = BertEncoder(_format_prefix(prefix, "encoder"), weights, device, dtype, config)
+        self.classifier_weight = weights.get_tensor("classifier.weight").to(dtype).to(device)
+        self.classifier_bias = weights.get_tensor("classifier.bias").to(dtype).to(device)
+
+    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
+        embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
+        encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s)
+        batch_size = encoder_outputs.shape[0] // max_s
+        encoder_outputs = encoder_outputs.reshape(batch_size, max_s, -1)
+        logits = torch.nn.functional.linear(encoder_outputs, self.classifier_weight, self.classifier_bias)
+        return logits
+
+
 class FlashBert(Model):
     def __init__(
         self,
         model_id: str,
         revision: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        classification_head: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -71,9 +95,15 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-        model = FlashBertModel(weights, device, dtype, config)
+        prefix = None if (model_id == "WhereIsAI/UAE-Large-V1") else "bert"
+        if classification_head:
+            model = FlashBertModelForClassification(prefix, weights, device, dtype, config)
+        else:
+            model = FlashBertModel(prefix, weights, device, dtype, config)
 
+        self.classification_head_enabled = classification_head
         self.hidden_size = config.hidden_size
+        self.config = config
 
         super(FlashBert, self).__init__(
             model_id=model_id,
Expand All @@ -98,6 +128,10 @@ def supports_embeddings(self) -> bool:
def supports_text_generation(self) -> bool:
return False

@property
def supports_classification(self) -> bool:
return self.classification_head_enabled

def warmup(self, batch: FlashEmbeddingClassificationBatch, max_new_tokens: int) -> int | None:
# Note: This is meant to 1) preallocate the memory by doing a forward pass
# and then just returning the max seqlen since for embeddings we are never generating
@@ -125,3 +159,17 @@ def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
 
         cpu_results = embedding.cpu().tolist()
         return cpu_results
+
+    @tracer.start_as_current_span("classify")
+    def classify(self, batch: FlashEmbeddingClassificationBatch):
+        logits: torch.Tensor = self.model.forward(
+            input_ids=batch.input_ids,
+            token_type_ids=batch.token_type_ids,
+            position_ids=batch.position_ids,
+            cu_seqlens=batch.cu_seqlens,
+            max_s=batch.max_s,
+        )
+        probabilities = torch.nn.functional.softmax(logits, dim=2)
+        confidence_scores, predictions = torch.max(probabilities, dim=2)
+        predicted_token_class = [[self.config.id2label[t.item()] for t in prediction] for prediction in predictions]
+        return predicted_token_class, confidence_scores.cpu().tolist()
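To make the new classify path concrete: FlashBertModelForClassification.forward reshapes the packed (total_tokens, hidden_size) encoder output to (batch_size, max_s, hidden_size), which assumes every sequence in the batch occupies max_s slots, then applies the classifier head to get logits of shape (batch_size, max_s, num_labels). classify then softmaxes over the label dimension and maps each argmax through id2label. A self-contained sketch of that decode step on dummy data (the label map and logits below are invented stand-ins, not real checkpoint outputs):

# Sketch of the classify() decode step using dummy tensors. In the server,
# logits come from FlashBertModelForClassification.forward and id2label
# from the model's HF config; here both are hypothetical stand-ins.
import torch

id2label = {0: "O", 1: "B-PER", 2: "I-PER"}  # made-up 3-label NER scheme

batch_size, max_s, num_labels = 2, 4, len(id2label)
logits = torch.randn(batch_size, max_s, num_labels)

# Same ops as FlashBert.classify: per-token softmax over labels, then max.
probabilities = torch.nn.functional.softmax(logits, dim=2)
confidence_scores, predictions = torch.max(probabilities, dim=2)

predicted_token_class = [
    [id2label[t.item()] for t in prediction] for prediction in predictions
]
print(predicted_token_class)             # e.g. [['O', 'B-PER', 'O', 'I-PER'], ...]
print(confidence_scores.cpu().tolist())  # per-token max probabilities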
