diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 07666bd15..f718c7fbf 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -348,10 +348,14 @@ def load( # save the first location of encountering a particular adapter index idx_locs[segment_indices[idx]] = loc # second, iterate over the adapter index for each token and find its location in the `indices` array - batch_indices = torch.tensor([ - idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1 - for idx in meta.adapter_indices.tolist() - ], dtype=torch.int64, device=device) + batch_indices = torch.tensor( + [ + idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1 + for idx in meta.adapter_indices.tolist() + ], + dtype=torch.int64, + device=device, + ) rank_data[rank] = RankSegments( rank=rank, diff --git a/server/lorax_server/layers/hqq.py b/server/lorax_server/layers/hqq.py index a2c99719e..9c0d83e11 100644 --- a/server/lorax_server/layers/hqq.py +++ b/server/lorax_server/layers/hqq.py @@ -4,7 +4,9 @@ HAS_HQQ = True try: from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear + HQQLinear.set_backend(HQQBackend.ATEN) + class HQQLinearLayer(HQQLinear): @property def weight(self) -> torch.Tensor: @@ -16,11 +18,17 @@ def weight(self) -> torch.Tensor: def get_hqq_linear(quantize, weight, bias=None) -> HQQLinearLayer: if quantize == "hqq-4bit": - quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16) + quant_config = BaseQuantizeConfig( + nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16 + ) elif quantize == "hqq-3bit": - quant_config = BaseQuantizeConfig(nbits=3, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16) + quant_config = BaseQuantizeConfig( + nbits=3, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16 + ) elif quantize == "hqq-2bit": - quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16) + quant_config = BaseQuantizeConfig( + nbits=2, group_size=16, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16 + ) # init nn.linear from weight and bias layer = nn.Linear(weight.shape[1], weight.shape[0], bias=bias is not None) diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index 802fb40f4..ced32acba 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -96,7 +96,7 @@ def get_model( from lorax_server.models.flash_bert import FlashBert return FlashBert(model_id, revision=revision, dtype=dtype) - + if model_type == "distilbert": from lorax_server.models.flash_distilbert import FlashDistilBert @@ -283,6 +283,20 @@ def get_model( trust_remote_code=trust_remote_code, ) + if model_type == "gemma2": + from lorax_server.models.flash_gemma2 import FlashGemma2 + + return FlashGemma2( + model_id, + adapter_id, + adapter_source, + revision, + quantize=quantize, + compile=compile, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if model_type == "cohere": from lorax_server.models.flash_cohere import FlashCohere diff --git a/server/lorax_server/models/custom_modeling/flash_bert_modeling.py 
b/server/lorax_server/models/custom_modeling/flash_bert_modeling.py index a1a2f935c..f69f9d681 100644 --- a/server/lorax_server/models/custom_modeling/flash_bert_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_bert_modeling.py @@ -12,7 +12,6 @@ # https://github.com/huggingface/text-embeddings-inference/blob/cb802a25d43fe6078c715b49652a3bc8a7d5aac8/backends/python/server/text_embeddings_server/models/flash_bert.py - class DistilBertEmbeddings: def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig): self.word_embeddings_weight = weights.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device) @@ -47,9 +46,7 @@ def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig): self.dense_weight = weights.get_tensor(f"{prefix}.attention.out_lin.weight").T.to(dtype).to(device) self.dense_bias = weights.get_tensor(f"{prefix}.attention.out_lin.bias").to(dtype).to(device) - self.layer_norm = FastLayerNorm.load( - prefix=f"{prefix}.sa_layer_norm", weights=weights, eps=1e-12 - ) + self.layer_norm = FastLayerNorm.load(prefix=f"{prefix}.sa_layer_norm", weights=weights, eps=1e-12) self.head_size = config.hidden_size // config.num_attention_heads self.softmax_scale = self.head_size**-0.5 @@ -93,9 +90,7 @@ def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig): self.output_weight = weights.get_tensor(f"{prefix}.ffn.lin2.weight").T.to(dtype).to(device) self.output_bias = weights.get_tensor(f"{prefix}.ffn.lin2.bias").to(dtype).to(device) - self.layer_norm = FastLayerNorm.load( - prefix=f"{prefix}.output_layer_norm", weights=weights, eps=1e-12 - ) + self.layer_norm = FastLayerNorm.load(prefix=f"{prefix}.output_layer_norm", weights=weights, eps=1e-12) def forward(self, hidden_states, cu_seqlens, max_s): hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s) @@ -216,4 +211,4 @@ def forward(self, hidden_states, cu_seqlens, max_s): self.output_weight, ) hidden_states, _ = self.layer_norm.forward(hidden_states, residual) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py new file mode 100644 index 000000000..6005ecaa9 --- /dev/null +++ b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py @@ -0,0 +1,535 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Tuple + +import torch +import torch.distributed +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from lorax_server.adapters.weights import AdapterBatchData +from lorax_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from lorax_server.layers.layernorm import ( + FastRMSNorm, +) +from lorax_server.layers.rotary import PositionRotaryEmbedding +from lorax_server.layers.tensor_parallel import TensorParallelHead +from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.layers import MultiAdapterHead, TensorParallelAdapterRowLinear, TensorParallelMultiAdapterLinear +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, +) + + +class Gemma2Config(PretrainedConfig): + def __init__( + self, + vocab_size=256128, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Gemma2FastRMSNorm(FastRMSNorm): + @classmethod + def load(cls, prefix, weights, eps=1e-6): + dtype = weights.dtype + weights.dtype = torch.float32 + weight = weights.get_tensor(f"{prefix}.weight") + 1 + weights.dtype = dtype + new = cls(weight, eps) + new.dtype = dtype + return new + + # perform the multiplication in full precision and downcast after + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * self.weight + return hidden_states.to(self.dtype), residual + + +def load_attention(config, prefix, weights, layer_id): + base_layer = load_attention_multi(config, prefix, weights) + head_size = config.head_dim + return TensorParallelMultiAdapterLinear.load( + base_layer, + layer_id, + [Q_PROJ, K_PROJ, V_PROJ], + sizes=[ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ], + 
process_group=weights.process_group, + ) + + +def load_attention_multi(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq", "marlin"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.head_dim + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize)) + + +class FlashGemma2Attention(torch.nn.Module): + def __init__(self, layer_id: int, prefix: str, config, weights, causal: bool, is_sliding: bool): + super().__init__() + self.num_heads = config.num_attention_heads + self.head_size = config.head_dim + self.causal = causal + if is_sliding: + self.window_size = config.sliding_window + else: + self.window_size = -1 + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + # self.softmax_scale = self.head_size**-0.5 + self.softmax_scale = config.query_pre_attn_scalar**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + + self.query_key_value = load_attention(config, prefix, weights, layer_id) + + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + O_PROJ, + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + adapter_data, + ): + qkv = self.query_key_value(hidden_states, adapter_data) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not 
None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + causal=self.causal, + window_size_left=self.window_size, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) + + +class Gemma2MLP(nn.Module): + def __init__(self, layer_id, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=("tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"), + ) + ) + # Fuse gate and up proj + gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + layer_id, + [GATE_PROJ, UP_PROJ], + sizes=[config.intermediate_size, config.intermediate_size], + process_group=weights.process_group, + ) + + self.down_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ), + layer_id, + DOWN_PROJ, + process_group=weights.process_group, + ) + self.intermediate_size = config.intermediate_size // weights.process_group.size() + + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data) + + +class FlashGemma2Layer(nn.Module): + def __init__(self, layer_id, prefix, config, weights, causal: bool, is_sliding: bool): + super().__init__() + self.self_attn = FlashGemma2Attention( + layer_id=layer_id, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + causal=causal, + is_sliding=is_sliding, + ) + self.mlp = Gemma2MLP(layer_id=layer_id, prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.pre_feedforward_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.post_feedforward_layernorm = Gemma2FastRMSNorm.load( + prefix=f"{prefix}.post_feedforward_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + adapter_data, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + adapter_data, + ) + + # faster post attention rms norm + normed_attn_res_output, _ = self.post_attention_layernorm(attn_output) + normed_attn_res_output = normed_attn_res_output + res + res = 
normed_attn_res_output + + pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output) + mlp_output = self.mlp(pre_normed, adapter_data) + post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output) + + return post_hidden_states, normed_attn_res_output + + +class FlashGemma2Model(torch.nn.Module): + def __init__(self, prefix, config, weights, causal: bool): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.layers = nn.ModuleList( + [ + FlashGemma2Layer( + layer_id, + prefix=f"{prefix}.layers.{layer_id}", + config=config, + weights=weights, + causal=causal, + is_sliding=layer_id % 2 == 0, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = Gemma2FastRMSNorm.load(prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + adapter_data: AdapterBatchData, + ) -> torch.Tensor: + hidden_states = inputs_embeds + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids, max_s, hidden_states.dtype) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + adapter_data, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashGemma2ForCausalLM(torch.nn.Module): + def __init__(self, prefix, config, weights, causal: bool): + super().__init__() + + embed_norm = config.hidden_size**0.5 + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.embed_tokens = TensorParallelEmbedding(prefix=f"{prefix}.embed_tokens", weights=weights) + self.embed_tokens.weight *= embed_norm + + self.model = FlashGemma2Model(prefix=prefix, config=config, weights=weights, causal=causal) + self.lm_head = MultiAdapterHead.load( + TensorParallelHead.load( + config, + prefix=(f"{prefix}.embed_tokens" if config.tie_word_embeddings else f"{prefix}.lm_head"), + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + adapter_data: AdapterBatchData, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + input_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( + input_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + adapter_data, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) + return logits, 
speculative_logits diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index 32b911237..36ce2d8f6 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -445,8 +445,6 @@ def __init__(self, config, weights): ) self.norm = GemmaRMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) - self.gradient_checkpointing = False - self.hidden_size = config.hidden_size self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index b86fab7f9..2678dcf03 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -43,10 +43,6 @@ SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None -if torch.cuda.is_available(): - fp8_supported = torch.cuda.get_device_capability()[0] >= 9 or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) -else: - fp8_supported = False tracer = trace.get_tracer(__name__) diff --git a/server/lorax_server/models/flash_gemma2.py b/server/lorax_server/models/flash_gemma2.py new file mode 100644 index 000000000..a0b266461 --- /dev/null +++ b/server/lorax_server/models/flash_gemma2.py @@ -0,0 +1,127 @@ +from typing import Dict, List, Optional, Tuple + +import torch +import torch.distributed +from opentelemetry import trace +from transformers import AutoTokenizer, PretrainedConfig + +from lorax_server.models import FlashCausalLM +from lorax_server.models.custom_modeling.flash_gemma2_modeling import FlashGemma2ForCausalLM +from lorax_server.utils import ( + Weights, + initialize_torch_distributed, + weight_files, +) +from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ + +tracer = trace.get_tracer(__name__) + + +ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD] +ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} + + +class FlashGemma2(FlashCausalLM): + def __init__( + self, + model_id: str, + adapter_id: str, + adapter_source: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + compile: bool = False, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashGemma2 is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = PretrainedConfig.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code) + config.quantize = quantize + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + ) + weights._set_config(model_id, config) + + prefix = "" + model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True) + + torch.distributed.barrier(group=self.process_group) + super(FlashGemma2, self).__init__(
model_id=model_id, + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + compile=compile, + adapter_id=adapter_id, + adapter_source=adapter_source, + trust_remote_code=trust_remote_code, + ) + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + for i, layer in enumerate(self.model.model.layers): + layer_weights[(i, Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) + + layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj) + layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj) + layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) + + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return [Q_PROJ, V_PROJ] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == LM_HEAD else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index 18accb3ff..2e6f76461 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -57,6 +57,7 @@ def to_pb(self) -> generate_pb2.GeneratedText: seed=self.seed, ) + @dataclass class PrefillTokens: token_ids: List[int] @@ -142,7 +143,7 @@ def __len__(self) -> int: @classmethod def from_pb( - self, + self, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, tokenizers: TokenizerManager, @@ -157,9 +158,9 @@ def from_pb( max_truncation = max(max_truncation, r.truncate) batch_inputs = tokenizer( - batch_inputs, - return_token_type_ids=True, - truncation=True, + batch_inputs, + return_token_type_ids=True, + truncation=True, max_length=max_truncation, ) @@ -173,8 +174,10 @@ def from_pb( max_s = 0 cumulative_length = 0 - - for i, (r, tokenized_input, token_type_ids) in enumerate(zip(pb.requests, batch_tokenized_inputs, batch_token_type_ids)): + + for i, (r, tokenized_input, token_type_ids) in enumerate( + zip(pb.requests, batch_tokenized_inputs, batch_token_type_ids) + ): tokenized_input = tokenized_input[-r.truncate :] token_type_ids = token_type_ids[-r.truncate :] all_input_ids.append(tokenized_input) @@ -189,7 +192,7 @@ def from_pb( position_ids.append(request_position_ids) cumulative_length += input_length - + if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) final_token_type_ids = np.concatenate(all_token_type_ids, dtype=np.int64) @@ -198,7 +201,7 @@ def from_pb( input_ids = all_input_ids[0] final_token_type_ids = all_token_type_ids[0] position_ids = position_ids[0] - + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) final_token_type_ids = torch.tensor(final_token_type_ids, dtype=torch.int64, device=device) 
position_ids = position_ids.to(device) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 63cc39c77..4adbf4338 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -98,7 +98,7 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): async def Embed(self, request: generate_pb2.EmbedRequest, context): if not self.model.supports_embeddings: raise ValueError("Model does not support embeddings") - + batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index dd4bd6b66..2c1a5f043 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -452,40 +452,22 @@ def static(cls, config, dim, base, device, dtype): **rope_scaling, ) elif rope_type == "su": - short_factor = torch.tensor( - rope_scaling["short_factor"], dtype=torch.float32, device=device - ) + short_factor = torch.tensor(rope_scaling["short_factor"], dtype=torch.float32, device=device) short_inv_freq = 1.0 / ( - short_factor - * base - ** ( - torch.arange(0, dim, 2, device=device, dtype=torch.float32) - / dim - ) - ) - long_factor = torch.tensor( - rope_scaling["long_factor"], dtype=torch.float32, device=device + short_factor * base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) ) + long_factor = torch.tensor(rope_scaling["long_factor"], dtype=torch.float32, device=device) long_inv_freq = 1.0 / ( - long_factor - * base - ** ( - torch.arange(0, dim, 2, device=device, dtype=torch.float32) - / dim - ) + long_factor * base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) ) - original_max_position_embeddings = ( - config.original_max_position_embeddings - ) + original_max_position_embeddings = config.original_max_position_embeddings max_position_embeddings = config.max_position_embeddings if max_position_embeddings <= original_max_position_embeddings: scaling_factor = 1.0 else: scale = max_position_embeddings / original_max_position_embeddings - scaling_factor = math.sqrt( - 1 + math.log(scale) / math.log(original_max_position_embeddings) - ) + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings)) return SuRotaryEmbedding( short_inv_freq=short_inv_freq, @@ -663,7 +645,7 @@ def yarn(self, device, scaling_factor): self.mscale = float( get_mscale(scaling_factor) * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation - + class SuRotaryEmbedding(PositionRotaryEmbedding): def __init__( self, @@ -687,11 +669,7 @@ def __init__( def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - ): + if seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype: self._seq_len_cached = seqlen if seqlen > self.original_max_position_embeddings: inv_freq = self.long_inv_freq diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index c4a1a28e2..7f2ca6cea 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -14,11 +14,14 @@ f"Could not import vllm paged attention. Make sure your installation is correct. 
Complete error: {e}" ) -if torch.cuda.is_available(): - # TODO(travis): fix for CUDA 8.9 (Lovelace) - fp8_supported = torch.cuda.get_device_capability()[0] >= 9 #or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) -else: - fp8_supported = False +# TODO(travis): fix for CUDA 8.9 (Lovelace) and 9.0 (Hopper) +# if torch.cuda.is_available(): +# fp8_supported = ( +# torch.cuda.get_device_capability()[0] >= 9 +# ) # or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) +# else: +fp8_supported = False + def reshape_and_cache( key: torch.Tensor,
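Note (not part of the diff): the paged_attention.py hunk above hard-codes fp8_supported = False until the fp8 path is fixed for Lovelace and Hopper. For reference, a minimal sketch of the capability gate that the commented-out code (and the block removed from flash_causal_lm.py) expresses; this is an assumption about how detection would be re-enabled, not current behavior, and the helper name is ours:

import torch

def fp8_kv_cache_supported() -> bool:
    # Same check as the removed flash_causal_lm.py block:
    # Hopper (SM >= 9.0) or Ada Lovelace (SM 8.9).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major >= 9 or (major == 8 and minor >= 9)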
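A second note, also outside the diff: two conventions in the new flash_gemma2_modeling.py are easy to miss. Gemma2FastRMSNorm.load adds 1 to the stored norm weight and runs the normalization in float32 before downcasting, and FlashGemma2Model alternates sliding-window and global attention by layer parity (is_sliding = layer_id % 2 == 0). A minimal standalone sketch of both behaviors, with helper names that are illustrative rather than taken from the repo:

import torch

def gemma2_rmsnorm(x: torch.Tensor, stored_weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma2 checkpoints store (weight - 1); Gemma2FastRMSNorm.load adds the 1 back at load time,
    # and the forward pass normalizes in float32 before casting back to the input dtype.
    weight = stored_weight.to(torch.float32) + 1
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(-1, keepdim=True)
    return (x32 * torch.rsqrt(variance + eps) * weight).to(x.dtype)

def attention_window(layer_id: int, sliding_window: int) -> int:
    # Even layers attend locally within the sliding window; odd layers attend globally.
    # -1 matches the "no window" value FlashGemma2Attention passes to flash_attn.
    return sliding_window if layer_id % 2 == 0 else -1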