
Commit

clean up
magdyksaleh committed Apr 4, 2024
1 parent 03664a1 commit b5ae3e8
Showing 1 changed file with 45 additions and 86 deletions.
131 changes: 45 additions & 86 deletions server/lorax_server/models/flash_bert.py
@@ -1,10 +1,8 @@
from lorax_server.models.types import GeneratedText
import torch

from pathlib import Path
from torch import nn
from typing import Type, List, Optional
from safetensors import safe_open
from transformers.activations import ACT2FN
from transformers.models.bert import BertConfig
from transformers import AutoTokenizer
@@ -40,39 +38,30 @@ class FlashBatch(ABC):
def __len__(self):
return self.size

def from_pb(self, *args, **kwargs):
return None


class BertEmbeddings:
def __init__(self, prefix, weights, device, dtype, config: BertConfig):
self.word_embeddings_weight = (
weights.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device)
)
self.word_embeddings_weight = weights.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device)
self.token_type_embeddings_weight = (
weights.get_tensor(f"{prefix}.token_type_embeddings.weight") .to(dtype) .to(device))
weights.get_tensor(f"{prefix}.token_type_embeddings.weight").to(dtype).to(device)
)

if config.position_embedding_type == "absolute":
self.position_embeddings_weight = (
weights.get_tensor(f"{prefix}.position_embeddings.weight")
.to(dtype)
.to(device)
weights.get_tensor(f"{prefix}.position_embeddings.weight").to(dtype).to(device)
)
else:
raise NotImplementedError(
"FlashBert only supports absolute position embeddings"
)
raise NotImplementedError("FlashBert only supports absolute position embeddings")

self.layer_norm = FastLayerNorm.load(
prefix=f"{prefix}.LayerNorm", weights=weights, eps=config.layer_norm_eps
)
self.layer_norm = FastLayerNorm.load(prefix=f"{prefix}.LayerNorm", weights=weights, eps=config.layer_norm_eps)

def forward(self, input_ids, token_type_ids, position_ids):
inputs_embeds = nn.functional.embedding(input_ids, self.word_embeddings_weight)
token_type_embeds = nn.functional.embedding(
token_type_ids, self.token_type_embeddings_weight
)
position_embeds = nn.functional.embedding(
position_ids, self.position_embeddings_weight
)
token_type_embeds = nn.functional.embedding(token_type_ids, self.token_type_embeddings_weight)
position_embeds = nn.functional.embedding(position_ids, self.position_embeddings_weight)

inputs_embeds += position_embeds

@@ -89,24 +78,14 @@ def __init__(self, prefix, weights, device, dtype, config: BertConfig):
value_weight = weights.get_tensor(f"{prefix}.self.value.weight")
value_bias = weights.get_tensor(f"{prefix}.self.value.bias")

self.qkv_weight = (
torch.cat([query_weight, key_weight, value_weight]).T.to(dtype).to(device)
)
self.qkv_bias = (
torch.cat([query_bias, key_bias, value_bias]).to(dtype).to(device)
)
self.qkv_weight = torch.cat([query_weight, key_weight, value_weight]).T.to(dtype).to(device)
self.qkv_bias = torch.cat([query_bias, key_bias, value_bias]).to(dtype).to(device)
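# fused [hidden_size, 3 * hidden_size] QKV weight so the query/key/value projections run as a single addmm in forward()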

self.dense_weight = (
weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
)
self.dense_bias = (
weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
)
self.dense_weight = weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
self.dense_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)

self.layer_norm = FastLayerNorm.load(
prefix=f"{prefix}.output.LayerNorm",
weights=weights,
eps=config.layer_norm_eps
prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps
)

self.head_size = config.hidden_size // config.num_attention_heads
@@ -117,9 +96,7 @@ def forward(self, hidden_states, cu_seqlens, max_s):
residual = hidden_states

qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight)
q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
self.num_heads, dim=1
)
q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(self.num_heads, dim=1)

attn_output = torch.empty_like(q)
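# varlen attention over the packed batch: cu_seqlens gives cumulative sequence lengths, max_s the longest sequence; the kernel writes into attn_output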
attention(q, k, v, attn_output, cu_seqlens, max_s, self.softmax_scale)
@@ -136,50 +113,32 @@ def forward(self, hidden_states, cu_seqlens, max_s):

class BertLayer:
def __init__(self, prefix, weights, device, dtype, config: BertConfig):
self.attention = BertAttention(
f"{prefix}.attention", weights, device, dtype, config
)
self.attention = BertAttention(f"{prefix}.attention", weights, device, dtype, config)

self.intermediate_weight = (
weights.get_tensor(f"{prefix}.intermediate.dense.weight")
.T.to(dtype)
.to(device)
)
self.intermediate_bias = (
weights.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device)
)
self.intermediate_weight = weights.get_tensor(f"{prefix}.intermediate.dense.weight").T.to(dtype).to(device)
self.intermediate_bias = weights.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device)

act = config.hidden_act
self.intermediate_act_fn = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none",
)
)

self.output_weight = (
weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
)
self.output_bias = (
weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
)
self.output_weight = weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
self.output_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
self.layer_norm = FastLayerNorm.load(
prefix=f"{prefix}.output.LayerNorm",
weights=weights,
eps=config.layer_norm_eps
prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps
)

def forward(self, hidden_states, cu_seqlens, max_s):
hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s)
residual = hidden_states
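# feed-forward block: intermediate projection and activation, followed by the output projection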

hidden_states = torch.addmm(
self.intermediate_bias, hidden_states, self.intermediate_weight
)
hidden_states = torch.addmm(self.intermediate_bias, hidden_states, self.intermediate_weight)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = torch.addmm(
self.output_bias,
@@ -193,8 +152,7 @@ def forward(self, hidden_states, cu_seqlens, max_s):
class BertEncoder:
def __init__(self, prefix, weights, device, dtype, config: BertConfig):
self.layers = [
BertLayer(f"{prefix}.layer.{i}", weights, device, dtype, config)
for i in range(config.num_hidden_layers)
BertLayer(f"{prefix}.layer.{i}", weights, device, dtype, config) for i in range(config.num_hidden_layers)
]

def forward(self, hidden_states, cu_seqlens, max_s):
@@ -218,11 +176,11 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):

class FlashBert(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
@@ -239,10 +197,10 @@ def __init__(
config = BertConfig.from_pretrained(model_id)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
)
model = FlashBertModel(weights, device, dtype, config)

@@ -262,14 +220,17 @@ def __init__(
@property
def batch_type(self) -> Type[FlashBatch]:
return FlashBatch


def warmup(self, batch: FlashBatch, max_new_tokens: int) -> int | None:
# return super().warmup(batch, max_new_tokens)
return 42  # placeholder; warmup is effectively a no-op for this embedding model

def generate_token(self, batch: FlashBatch) -> tuple[List[GeneratedText], Optional[FlashBatch]] | None:
return None

def forward(self, batch: FlashBatch):
return self.embed(batch)


@tracer.start_as_current_span("embed")
def embed(self, batch: FlashBatch) -> Embedding:
embedding = self.model.forward(
@@ -281,19 +242,17 @@ def embed(self, batch: FlashBatch) -> Embedding:
)
cpu_results = embedding.view(-1).tolist()
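# keep the first token's vector (the first hidden_size values of the flattened output), i.e. the [CLS] embedding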

return Embedding(
values=cpu_results[:self.hidden_size]
)

return Embedding(values=cpu_results[: self.hidden_size])

def tokenize_to_batch(self, inputs) -> FlashBatch:
tokens = self.tokenizer(inputs, return_token_type_ids=True)
num_tokens = len(tokens["input_ids"])
position_ids = range(num_tokens)
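# single-sequence batch: cu_seqlens marks the one interval [0, num_tokens)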
return FlashBatch(
input_ids=torch.tensor(tokens['input_ids'], dtype=torch.int32, device=self.device),
token_type_ids=torch.tensor(tokens['token_type_ids'], dtype=torch.int32, device=self.device),
input_ids=torch.tensor(tokens["input_ids"], dtype=torch.int32, device=self.device),
token_type_ids=torch.tensor(tokens["token_type_ids"], dtype=torch.int32, device=self.device),
position_ids=torch.tensor(position_ids, dtype=torch.int32, device=self.device),
cu_seqlens=torch.tensor([0, num_tokens], dtype=torch.int32, device=self.device),
max_s=num_tokens,
size=1,
)
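
A minimal usage sketch of the embedding path in this file, assuming a CUDA device with the flash-attention kernels available; the model id "bert-base-uncased" and the input string are only illustrative:

from lorax_server.models.flash_bert import FlashBert

model = FlashBert("bert-base-uncased")          # loads config, tokenizer and .safetensors weights
batch = model.tokenize_to_batch("hello world")  # single-sequence FlashBatch on the model's device
embedding = model.embed(batch)                  # Embedding holding hidden_size float values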
