From ddd1c2dcc72b68a5258db49242ce07ec65058e7e Mon Sep 17 00:00:00 2001 From: Arnav Garg <106701836+arnavgarg1@users.noreply.github.com> Date: Mon, 16 Dec 2024 10:16:43 -0800 Subject: [PATCH] Adds support for the Solar architecture (#719) --- router/src/config.rs | 1 + server/lorax_server/models/__init__.py | 11 + .../custom_modeling/flash_solar_modeling.py | 691 ++++++++++++++++++ server/lorax_server/models/flash_causal_lm.py | 41 +- server/lorax_server/models/flash_solar.py | 93 +++ 5 files changed, 818 insertions(+), 19 deletions(-) create mode 100644 server/lorax_server/models/custom_modeling/flash_solar_modeling.py create mode 100644 server/lorax_server/models/flash_solar.py diff --git a/router/src/config.rs b/router/src/config.rs index cf1eb7b78..4f73faf68 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -161,6 +161,7 @@ pub enum Config { PhiMsft, Phi3, Llama, + Solar, Baichuan, Paligemma(Paligemma), Gemma, diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index ea9197fa1..e5300e31c 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -278,6 +278,17 @@ def get_model( **flash_causal_lm_kwargs, ) + if model_type == "solar": + from lorax_server.models.flash_solar import FlashSolar + + return FlashSolar( + model_id, + adapter_id, + adapter_source, + revision, + **flash_causal_lm_kwargs, + ) + if model_type == "gemma": from lorax_server.models.flash_gemma import FlashGemma diff --git a/server/lorax_server/models/custom_modeling/flash_solar_modeling.py b/server/lorax_server/models/custom_modeling/flash_solar_modeling.py new file mode 100644 index 000000000..de9e1a416 --- /dev/null +++ b/server/lorax_server/models/custom_modeling/flash_solar_modeling.py @@ -0,0 +1,691 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
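+#
+# Note on the Solar-specific config fields used below: bskcn_1 and bskcn_2
+# list the layer indices whose hidden states are captured during the forward
+# pass, bskcn_3 and bskcn_4 list the layer indices where those captured
+# states are blended back into the hidden states, and bskcn_tv holds the two
+# mixing weights (index 0 is used during training, index 1 during inference).
+# See SolarModel.forward for the blending step itself.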
+ +from typing import List, Optional, Tuple + +# Flash attention imports +import dropout_layer_norm +import torch +import torch.distributed +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from lorax_server.adapters import AdapterBatchData +from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen +from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA +from lorax_server.utils.layers import ( + MultiAdapterHead, + PositionRotaryEmbedding, + TensorParallelAdapterRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelHead, + TensorParallelMultiAdapterLinear, + TensorParallelRowLinear, + get_linear, +) +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, +) +from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized + +if not HAS_FLASH_ATTN_V2_CUDA: + raise ImportError("Solar model requires flash attn v2") + + +class SolarConfig(PretrainedConfig): + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_scaling=None, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + sliding_window=2047, + bskcn_1=[12, 20, 32, 44], + bskcn_2=[20, 32], + bskcn_3=[16, 24, 36, 48], + bskcn_4=[28, 40], + bskcn_tv=[0.9, 0.8], + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.bskcn_1 = bskcn_1 + self.bskcn_2 = bskcn_2 + self.bskcn_3 = bskcn_3 + self.bskcn_4 = bskcn_4 + self.bskcn_tv = bskcn_tv + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class SolarRMSNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + SolarRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into 
half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + else: + # faster post attention rms norm + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + + +def load_attention(config, prefix, weights, layer_id, head_size): + base_layer = load_attention_multi(config, prefix, weights, head_size) + return TensorParallelMultiAdapterLinear.load( + base_layer, + layer_id, + [Q_PROJ, K_PROJ, V_PROJ], + sizes=[ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ], + process_group=weights.process_group, + ) + + +def load_attention_multi(config, prefix, weights, head_size): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights, head_size) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights, head_size): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + input_scale, weight_scale = None, None + if isinstance(weight, tuple): + weight, input_scale, weight_scale = weight + + if not is_quantized(config.quantize): + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear( + weight, + bias=None, + quantize=config.quantize, + weight_scale=weight_scale, + input_scale=input_scale, + ) + ) + + +class SolarAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + layer_id: int, + ): + super().__init__() + + self.max_past = config.sliding_window if config.sliding_window is not None else -1 + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + + if hasattr(config, "head_dim"): + self.head_size = config.head_dim + else: + self.head_size = self.hidden_size // self.num_heads + + # TODO(Arnav): Solar can also support linear and dynamic ntk rotary embeddings + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + dtype=weights.dtype, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads 
= config.num_key_value_heads // weights.process_group.size() + + if is_fp8_kv(config.quantize): + self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() + self.fp8_kv = True + else: + self.k_scale = 1.0 + self.v_scale = 1.0 + self.fp8_kv = False + + self.query_key_value = load_attention(config, prefix, weights, layer_id, self.head_size) + + self.o_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ), + layer_id, + O_PROJ, + process_group=weights.process_group, + ) + + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def get_query_key_value_weights(self, clone=True): + """Gets the query, key, and value weights from the attention layer. + + If `clone`, then the weights are cloned before being returned. + + NOTE: if not `clone`, then the weights are returned as views, meaning + that changes to the weights will be reflected in the attention layer. + """ + query, key, value = self.query_key_value.base_layer.linear.weight.split( + [ + self.head_size * self.num_heads, + self.head_size * self.num_key_value_heads, + self.head_size * self.num_key_value_heads, + ], + dim=0, + ) + + if clone: + return query.clone(), key.clone(), value.clone() + return query, key, value + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states, adapter_data) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + paged_attention.reshape_and_cache( + kv_to_cache[:, 0], + kv_to_cache[:, 1], + kv_cache[0], + kv_cache[1], + slots, + self.k_scale, + self.v_scale, + self.fp8_kv, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], + cu_seqlen_prefill, + max_s, + self.softmax_scale, + window_size_left=self.max_past, + k_scale=self.k_scale, + v_scale=self.v_scale, + fp8_kv=self.fp8_kv, + ) + # Decode + else: + attn_output = paged_attention.attention( + query, + kv_cache[0], + kv_cache[1], + self.num_key_value_heads, + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + k_scale=self.k_scale, + v_scale=self.v_scale, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) + + +class SolarMLP(nn.Module): + def __init__(self, prefix, config, weights, layer_id): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", + ) + ) + # Fuse gate and up proj + gate_up_proj = 
TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + layer_id, + [GATE_PROJ, UP_PROJ], + sizes=[ + config.intermediate_size, + config.intermediate_size, + ], + process_group=weights.process_group, + ) + + self.down_proj = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ), + layer_id, + DOWN_PROJ, + process_group=weights.process_group, + ) + self.intermediate_size = config.intermediate_size // weights.process_group.size() + + def forward(self, hidden_states, adapter_data): + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data) + + +class SolarLayer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + self.self_attn = SolarAttention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = SolarMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) + + self.input_layernorm = SolarRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = SolarRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + prefill_cache_indices, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + prefill_cache_indices, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm(attn_output, res) + + mlp_output = self.mlp(normed_attn_res_output, adapter_data) + + return mlp_output, attn_res + + +class SolarModel(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.config = config + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.layers = nn.ModuleList( + [ + SolarLayer( + f"model.layers.{layer_id}" if not prefix else f"{prefix}.model.layers.{layer_id}", + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + + self.norm = SolarRMSNorm( + prefix="model.norm" if not prefix else f"{prefix}.model.norm", + weights=weights, + eps=config.rms_norm_eps, + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + adapter_data: AdapterBatchData, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = inputs_embeds + + # Get 
rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids, max_s, hidden_states.dtype) + + residual = None + + bskcn_1 = None + bskcn_2 = None + # Note, we use index 1 instead of index 0 since index 0 is used when training is enabled + bskcn_tv = self.config.bskcn_tv[1] + for i, layer in enumerate(self.layers): + if i in self.config.bskcn_1: + bskcn_1 = hidden_states + if i in self.config.bskcn_2: + bskcn_2 = hidden_states + if i in self.config.bskcn_3: + hidden_states = (bskcn_1 * bskcn_tv).to(hidden_states.device) + hidden_states * (1 - bskcn_tv) + if i in self.config.bskcn_4: + hidden_states = (bskcn_2 * bskcn_tv).to(hidden_states.device) + hidden_states * (1 - bskcn_tv) + + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + adapter_data, + prefill_cache_indices, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashSolarForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.config = config + + self.embed_tokens = TensorParallelEmbedding( + prefix=("model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"), + weights=weights, + ) + + self.model = SolarModel( + prefix=prefix, + config=config, + weights=weights, + ) + + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = MultiAdapterHead.load( + TensorParallelHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ), + 0, + LM_HEAD, + process_group=weights.process_group, + ) + + self.max_past = config.sliding_window + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + adapter_data: AdapterBatchData, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + skip_lm_head: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif self.max_past is not None: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + max_s = min(self.max_past, max_s) + seqlen = seqlen.clamp(max=self.max_past) + + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + prefill_cache_indices, + ) + + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + + if skip_lm_head: + return hidden_states, None + + logits, speculative_logits = self.lm_head(hidden_states, adapter_data) + return logits, speculative_logits diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 4b1f27ef4..9fdd80d1b 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -106,7 +106,7 @@ class FlashCausalLMBatch(Batch): prefilling_mask_tensor: Optional[torch.Tensor] # Prefill metadata tensors to 
efficiently compute logprobs - # tensor of length b+1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + # tensor of length b+1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill # noqa: E501 cu_seqlen_prefill: Optional[torch.Tensor] # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers # as we only keep SLIDING_WINDOW values instead of the whole tensor @@ -186,13 +186,16 @@ def from_pb( if batch_tokenized_inputs is None: batch_inputs = [] max_truncation = 0 - for r in pb.requests: - inputs = tokenizers.get_inputs(r, tokenizer) + for request in pb.requests: + inputs = tokenizers.get_inputs(request, tokenizer) batch_inputs.append(inputs) - max_truncation = max(max_truncation, r.truncate) + max_truncation = max(max_truncation, request.truncate) - if all(r.HasField("tokenized_inputs") and len(r.tokenized_inputs.ids) > 0 for r in pb.requests): - batch_tokenized_inputs = [r.tokenized_inputs.ids[-max_truncation:] for r in pb.requests] + if all( + request.HasField("tokenized_inputs") and len(request.tokenized_inputs.ids) > 0 + for request in pb.requests + ): + batch_tokenized_inputs = [request.tokenized_inputs.ids[-max_truncation:] for request in pb.requests] else: batch_tokenized_inputs = tokenizer(batch_inputs, truncation=True, max_length=max_truncation)[ "input_ids" @@ -225,24 +228,24 @@ def from_pb( block_tables_ragged = [] # Parse batch - for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): + for i, (request, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): # request id -> idx in list mapping - requests_idx_mapping[r.id] = i + requests_idx_mapping[request.id] = i - tokenized_input = tokenized_input[-r.truncate :] + tokenized_input = tokenized_input[-request.truncate :] prompt_length = len(tokenized_input) prompt_lengths.append(prompt_length) - cache_length = r.cache_len + cache_length = request.cache_len assert cache_length <= prompt_length, f"Prefix {cache_length} vs input {prompt_length}" if cache_length == prompt_length: assert False, "unreachable" # `chunk_len` is an optional field in the protobuf # It is only set if the model support chunking - if r.HasField("chunk_len"): - input_length = r.chunk_len + if request.HasField("chunk_len"): + input_length = request.chunk_len if cache_length + input_length < prompt_length: # FIXME: speculate is not supported for context chunking at the moment @@ -265,9 +268,9 @@ def from_pb( all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) - next_token_chooser_parameters.append(r.parameters) + next_token_chooser_parameters.append(request.parameters) - stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) + stopping_criteria = StoppingCriteria.from_pb(request.stopping_parameters, tokenizer) max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) @@ -279,13 +282,13 @@ def from_pb( block_tokens = prompt_length + max_new_tokens - 1 + speculative_tokens # blocks and slots can be empty (for example in warmup) - if not r.blocks: + if not request.blocks: needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [b for b in range(num_blocks, num_blocks + needed_blocks)] request_slots = [s for b in request_blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] else: - request_blocks = r.blocks - request_slots = r.slots + request_blocks = request.blocks + 
request_slots = request.slots block_tables.append(request_blocks) block_tables_ragged.extend(request_blocks) @@ -1745,7 +1748,8 @@ def generate_token( # Only save tokens if we are done prefilling for this request batch.all_input_ids_tensor[ i, - batch.cache_lengths_tensor[i] + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + batch.cache_lengths_tensor[i] + + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + batch.input_lengths[i] + accepted_ids[i], ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] @@ -1988,7 +1992,6 @@ def generate_token( ) all_alternative_tokens.append(alternative_tokens) - stop, reason = stopping_criteria( next_token_id, next_token_text, diff --git a/server/lorax_server/models/flash_solar.py b/server/lorax_server/models/flash_solar.py new file mode 100644 index 000000000..28a163e26 --- /dev/null +++ b/server/lorax_server/models/flash_solar.py @@ -0,0 +1,93 @@ +from typing import Dict, List, Optional, Tuple + +import torch +import torch.distributed +from opentelemetry import trace + +from lorax_server.models import FlashCausalLM +from lorax_server.models.custom_modeling.flash_solar_modeling import ( + FlashSolarForCausalLM, + SolarConfig, +) +from lorax_server.utils.lora import ( + DOWN_PROJ, + GATE_PROJ, + K_PROJ, + LM_HEAD, + O_PROJ, + Q_PROJ, + UP_PROJ, + V_PROJ, +) + +tracer = trace.get_tracer(__name__) + + +ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ] +ROW_PARALLEL = {O_PROJ, DOWN_PROJ} + + +class FlashSolar(FlashCausalLM): + def __init__( + self, + model_id: str, + adapter_id: str, + adapter_source: str, + revision: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + **kwargs, + ): + super().__init__( + model_id=model_id, + model_cls=FlashSolarForCausalLM, + dtype=dtype, + revision=revision, + adapter_id=adapter_id, + adapter_source=adapter_source, + config_cls=SolarConfig, + **kwargs, + ) + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + for i, layer in enumerate(self.model.model.layers): + layer_weights[(i, Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj) + + layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj) + layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj) + layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) + + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return [Q_PROJ, V_PROJ] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == LM_HEAD else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL
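
For illustration, the block-skip mixing that SolarModel.forward applies at the
bskcn_3 / bskcn_4 layers reduces to the sketch below. This is a minimal,
self-contained example, not part of the patch: the tensor shapes and names are
placeholders, and bskcn_tv uses the SolarConfig defaults with index 1 (the
inference-time mixing weight, as in the patched forward pass).

import torch

bskcn_tv = [0.9, 0.8]   # SolarConfig defaults; index 0 -> training, index 1 -> inference
tv = bskcn_tv[1]

hidden_states = torch.randn(4, 4096)   # hidden states entering layer i (illustrative shape)
saved = torch.randn(4, 4096)           # snapshot taken earlier at a bskcn_1 / bskcn_2 layer

# At a layer listed in bskcn_3 (or bskcn_4), the saved snapshot is blended back in,
# mirroring: hidden_states = (bskcn_1 * bskcn_tv) + hidden_states * (1 - bskcn_tv)
hidden_states = saved * tv + hidden_states * (1 - tv)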