
Commit

Added Gemma2 (#530)
tgaddair authored Jul 1, 2024
1 parent 2731478 commit c88fa9e
Showing 12 changed files with 727 additions and 66 deletions.
12 changes: 8 additions & 4 deletions server/lorax_server/adapters/lora.py
@@ -348,10 +348,14 @@ def load(
# save the first location of encountering a particular adapter index
idx_locs[segment_indices[idx]] = loc
# second, iterate over the adapter index for each token and find its location in the `indices` array
-        batch_indices = torch.tensor([
-            idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1
-            for idx in meta.adapter_indices.tolist()
-        ], dtype=torch.int64, device=device)
+        batch_indices = torch.tensor(
+            [
+                idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1
+                for idx in meta.adapter_indices.tolist()
+            ],
+            dtype=torch.int64,
+            device=device,
+        )

rank_data[rank] = RankSegments(
rank=rank,
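The reformatted comprehension above maps every token's adapter index to the location of that adapter's segment, falling back to -1 when the adapter is missing or was loaded with a different LoRA rank. A minimal, self-contained sketch of that mapping with toy values (illustrative only, not part of the commit):

import torch

device = "cpu"
rank = 16

# Toy stand-ins for the structures used in load(): idx_locs records the first
# segment location seen for each adapter index, and adapter_ranks stands in for
# adapter_weights[idx].lora_a_r.
idx_locs = {0: 0, 2: 1}
adapter_ranks = {0: 16, 2: 8}
adapter_indices = torch.tensor([0, 0, 2, 3])  # per-token adapter index (meta.adapter_indices)

batch_indices = torch.tensor(
    [
        idx_locs[idx] if idx in adapter_ranks and adapter_ranks[idx] == rank else -1
        for idx in adapter_indices.tolist()
    ],
    dtype=torch.int64,
    device=device,
)
print(batch_indices)  # tensor([ 0,  0, -1, -1])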
14 changes: 11 additions & 3 deletions server/lorax_server/layers/hqq.py
@@ -4,7 +4,9 @@
HAS_HQQ = True
try:
from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear

HQQLinear.set_backend(HQQBackend.ATEN)

class HQQLinearLayer(HQQLinear):
@property
def weight(self) -> torch.Tensor:
@@ -16,11 +18,17 @@ def weight(self) -> torch.Tensor:

def get_hqq_linear(quantize, weight, bias=None) -> HQQLinearLayer:
if quantize == "hqq-4bit":
-        quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16)
+        quant_config = BaseQuantizeConfig(
+            nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16
+        )
    elif quantize == "hqq-3bit":
-        quant_config = BaseQuantizeConfig(nbits=3, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16)
+        quant_config = BaseQuantizeConfig(
+            nbits=3, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16
+        )
    elif quantize == "hqq-2bit":
-        quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16)
+        quant_config = BaseQuantizeConfig(
+            nbits=2, group_size=16, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16
+        )

# init nn.linear from weight and bias
layer = nn.Linear(weight.shape[1], weight.shape[0], bias=bias is not None)
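For context, the three branches above differ only in the BaseQuantizeConfig they build (4-, 3-, or 2-bit, with a smaller group size for 2-bit). A hedged usage sketch of get_hqq_linear, assuming the hqq package is installed and a CUDA device is available for the ATEN backend (toy shapes, not part of the commit):

import torch

from lorax_server.layers.hqq import get_hqq_linear

# Toy weight/bias in bfloat16; shapes are [out_features, in_features] and [out_features].
weight = torch.randn(256, 128, dtype=torch.bfloat16)
bias = torch.zeros(256, dtype=torch.bfloat16)

# "hqq-4bit" selects nbits=4 / group_size=64; "hqq-3bit" and "hqq-2bit" pick the
# other configs shown in the hunk above.
layer = get_hqq_linear("hqq-4bit", weight, bias)

x = torch.randn(4, 128, dtype=torch.bfloat16)
# Depending on where the quantized layer places its buffers, x may need .to("cuda").
y = layer(x)  # quantized forward through HQQLinearLayer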
16 changes: 15 additions & 1 deletion server/lorax_server/models/__init__.py
@@ -96,7 +96,7 @@ def get_model(
from lorax_server.models.flash_bert import FlashBert

return FlashBert(model_id, revision=revision, dtype=dtype)

if model_type == "distilbert":
from lorax_server.models.flash_distilbert import FlashDistilBert

@@ -283,6 +283,20 @@ def get_model(
trust_remote_code=trust_remote_code,
)

if model_type == "gemma2":
from lorax_server.models.flash_gemma2 import FlashGemma2

return FlashGemma2(
model_id,
adapter_id,
adapter_source,
revision,
quantize=quantize,
compile=compile,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

if model_type == "cohere":
from lorax_server.models.flash_cohere import FlashCohere

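The new branch follows the same registration pattern as the other flash models: get_model reads model_type from the Hugging Face config and instantiates the matching class. A condensed, hypothetical sketch of that dispatch (load_lorax_model is a stand-in name; the real constructors take the full argument list shown above):

from transformers import AutoConfig


def load_lorax_model(model_id: str, **kwargs):
    # The config's model_type selects the flash-attention implementation;
    # "gemma2" now routes to FlashGemma2.
    model_type = AutoConfig.from_pretrained(model_id).model_type

    if model_type == "gemma2":
        from lorax_server.models.flash_gemma2 import FlashGemma2

        return FlashGemma2(model_id, **kwargs)
    if model_type == "cohere":
        from lorax_server.models.flash_cohere import FlashCohere

        return FlashCohere(model_id, **kwargs)
    raise NotImplementedError(f"model_type {model_type!r} is not covered in this sketch")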
@@ -12,7 +12,6 @@
# https://github.com/huggingface/text-embeddings-inference/blob/cb802a25d43fe6078c715b49652a3bc8a7d5aac8/backends/python/server/text_embeddings_server/models/flash_bert.py



class DistilBertEmbeddings:
def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig):
self.word_embeddings_weight = weights.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device)
@@ -47,9 +46,7 @@ def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig):
self.dense_weight = weights.get_tensor(f"{prefix}.attention.out_lin.weight").T.to(dtype).to(device)
self.dense_bias = weights.get_tensor(f"{prefix}.attention.out_lin.bias").to(dtype).to(device)

-        self.layer_norm = FastLayerNorm.load(
-            prefix=f"{prefix}.sa_layer_norm", weights=weights, eps=1e-12
-        )
+        self.layer_norm = FastLayerNorm.load(prefix=f"{prefix}.sa_layer_norm", weights=weights, eps=1e-12)

self.head_size = config.hidden_size // config.num_attention_heads
self.softmax_scale = self.head_size**-0.5
@@ -93,9 +90,7 @@ def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig):

self.output_weight = weights.get_tensor(f"{prefix}.ffn.lin2.weight").T.to(dtype).to(device)
self.output_bias = weights.get_tensor(f"{prefix}.ffn.lin2.bias").to(dtype).to(device)
-        self.layer_norm = FastLayerNorm.load(
-            prefix=f"{prefix}.output_layer_norm", weights=weights, eps=1e-12
-        )
+        self.layer_norm = FastLayerNorm.load(prefix=f"{prefix}.output_layer_norm", weights=weights, eps=1e-12)

def forward(self, hidden_states, cu_seqlens, max_s):
hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s)
@@ -216,4 +211,4 @@ def forward(self, hidden_states, cu_seqlens, max_s):
self.output_weight,
)
hidden_states, _ = self.layer_norm.forward(hidden_states, residual)
-        return hidden_states
+        return hidden_states
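The layer-norm loads collapsed to one line above feed the fused add-then-normalize calls used later in the file, e.g. `hidden_states, _ = self.layer_norm.forward(hidden_states, residual)`. A plain PyTorch reference sketch of that behaviour, assuming FastLayerNorm follows the usual convention of summing the residual before normalizing and returning both tensors (illustrative only, not code from this commit):

import torch


def add_layer_norm(hidden_states, residual, weight, bias, eps=1e-12):
    # Fold the residual into the hidden states before normalizing, as the fused kernel does.
    if residual is not None:
        hidden_states = hidden_states + residual
    mean = hidden_states.mean(-1, keepdim=True)
    var = hidden_states.var(-1, unbiased=False, keepdim=True)
    normed = (hidden_states - mean) / torch.sqrt(var + eps) * weight + bias
    # Return the normalized output and the pre-norm sum (the next block's residual).
    return normed, hidden_states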
