diff --git a/server/lorax_server/models/bloom.py b/server/lorax_server/models/bloom.py index 09b6c829a..c7912d9a1 100644 --- a/server/lorax_server/models/bloom.py +++ b/server/lorax_server/models/bloom.py @@ -1,7 +1,7 @@ import torch import torch.distributed -from typing import Optional, Type +from typing import Dict, List, Optional, Tuple, Type from transformers import ( AutoTokenizer, @@ -10,6 +10,11 @@ ) from lorax_server.models.custom_modeling.bloom_modeling import ( + ATTN_DENSE, + ATTN_QKV, + LM_HEAD, + MLP_DENSE_4H_TO_H, + MLP_DENSE_H_TO_4H, BloomForCausalLM, ) from lorax_server.models import CausalLM @@ -21,6 +26,10 @@ Weights, ) from lorax_server.utils.tokenizer import TokenizerManager +from lorax_server.utils.lora import AdapterBatchData + +ADAPTER_LAYERS = [ATTN_QKV, ATTN_DENSE, MLP_DENSE_H_TO_4H, MLP_DENSE_4H_TO_H] +ROW_PARALLEL = {ATTN_DENSE, MLP_DENSE_4H_TO_H} class BloomCausalLMBatch(CausalLMBatch): @@ -89,6 +98,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, @@ -98,12 +108,24 @@ def __init__( world_size=world_size, ) + self.dynamic_adapter_loading_enabled = True + + @property def batch_type(self) -> Type[CausalLMBatch]: return BloomCausalLMBatch + + @property + def has_adapter_data(self) -> bool: + return True def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + self, + input_ids, + attention_mask, + position_ids, + past_key_values: Optional = None, + adapter_data: Optional[AdapterBatchData] = None ): outputs = self.model.forward( input_ids=input_ids, @@ -111,7 +133,37 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, use_cache=True, + adapter_data=adapter_data, ) logits = outputs.logits return logits, outputs.past_key_values + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "transformer.h" + for i, layer in enumerate(self.model.transformer.h): + layer_weights[(i, ATTN_QKV)] = (f"{prefix}.{i}.self_attention.query_key_value", layer.self_attention.query_key_value) + layer_weights[(i, ATTN_DENSE)] = (f"{prefix}.{i}.self_attention.dense", layer.self_attention.dense) + + layer_weights[(i, MLP_DENSE_H_TO_4H)] = (f"{prefix}.{i}.mlp.dense_h_to_4h", layer.mlp.dense_h_to_4h) + layer_weights[(i, MLP_DENSE_4H_TO_H)] = (f"{prefix}.{i}.mlp.dense_4h_to_h", layer.mlp.dense_4h_to_h) + + # TODO: make Embedding layers adapter-compatible + # layer_weights[(0, LM_HEAD)] = ("transformer.wte", self.model.transformer.wte) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 6658281c3..927e04c50 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -1,3 +1,4 @@ +from collections import defaultdict import json import torch import inspect @@ -17,6 +18,8 @@ from lorax_server.pb import generate_pb2 from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling from lorax_server.utils.tokenizer import TokenizerManager +from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights +from lorax_server.utils.segments import SegmentConcatBuilder, find_segments tracer = trace.get_tracer(__name__) @@ -46,7 +49,7 @@ class CausalLMBatch(Batch): stopping_criterias: List[StoppingCriteria] # Adapter metadata for each request - adapter_indices: torch.Tensor + adapter_meta: AdapterBatchMetadata # Metadata used for padding max_input_length: int @@ -87,6 +90,7 @@ def from_pb( padding_right_offset = 0 max_decode_tokens = 0 adapter_indices_list = [] + adapter_set = set() for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i req_inputs = tokenizers.get_inputs(r, tokenizer) @@ -102,6 +106,7 @@ def from_pb( padding_right_offset, stopping_criteria.max_new_tokens ) adapter_indices_list.append(r.adapter_index) + adapter_set.add(r.adapter_index) adapter_indices = torch.tensor(adapter_indices_list, dtype=torch.int64, device=device) @@ -135,6 +140,9 @@ def from_pb( max_tokens = len(inputs) * (max_input_length + max_decode_tokens) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) + return cls( batch_id=pb.id, requests=pb.requests, @@ -152,7 +160,12 @@ def from_pb( max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, max_tokens=max_tokens, - adapter_indices=adapter_indices, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), ) @tracer.start_as_current_span("filter") @@ -173,7 +186,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: all_input_ids = [] max_input_length = 0 - # TODO(travis): adapter indices + adapter_set = set() next_token_choosers = [] stopping_criterias = [] @@ -206,9 +219,12 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: new_padding_right_offset, remaining_decode_tokens ) + adapter_set.add(self.requests[idx].adapter_index) + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached input_ids = self.input_ids[keep_indices] position_ids = self.position_ids[keep_indices] + adapter_indices = self.adapter_meta.adapter_indices[keep_indices] self.attention_mask = self.attention_mask[ keep_indices, -(self.padding_right_offset + max_input_length) : ( @@ -239,6 +255,10 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens + device = self.input_ids.device + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) + self.requests = requests self.requests_idx_mapping = requests_idx_mapping self.input_ids = input_ids @@ -252,6 +272,12 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: self.max_input_length = max_input_length self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return self @@ -285,6 +311,12 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": past_key_values = [] adapter_indices = None + total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) + adapter_set = set() + adapter_segment_builder = SegmentConcatBuilder() + cumulative_adapter_indices_size = 0 + # Used for slicing correctly inside the tensors # Equivalent to a cumsum on batch sizes start_index = 0 @@ -319,10 +351,15 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": # Copy to correct indices input_ids[start_index:end_index] = batch.input_ids - # Create adapter indices - if adapter_indices is None: - adapter_indices = batch.adapter_indices.new_empty((total_batch_size,)) - adapter_indices[start_index:end_index] = batch.adapter_indices + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] + adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + + # Update adapter segments + adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices) # Create padded tensor if attention_mask is None: @@ -444,6 +481,8 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": past_key_values.append([padded_past_keys, padded_past_values]) + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + return cls( batch_id=batches[0].batch_id, requests=requests, @@ -462,7 +501,12 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": padding_right_offset=padding_right_offset, keys_head_dim_last=batches[0].keys_head_dim_last, max_tokens=max_tokens, - adapter_indices=adapter_indices, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), ) def __len__(self): @@ -523,6 +567,7 @@ def __init__( tokenizer.add_special_tokens({"pad_token": "[PAD]"}) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, @@ -530,9 +575,15 @@ def __init__( device=device, ) + self.dynamic_adapter_loading_enabled = False + @property def batch_type(self) -> Type[CausalLMBatch]: return CausalLMBatch + + @property + def has_adapter_data(self) -> bool: + return False def decode(self, generated_ids: List[int]) -> str: return self.tokenizer.decode( @@ -540,7 +591,12 @@ def decode(self, generated_ids: List[int]) -> str: ) def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + self, + input_ids, + attention_mask, + position_ids, + past_key_values: Optional = None, + adapter_data: Optional[AdapterBatchData] = None ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward kwargs = { @@ -552,6 +608,8 @@ def forward( } if self.has_position_ids: kwargs["position_ids"] = position_ids + if self.has_adapter_data: + kwargs["adapter_data"] = adapter_data outputs = self.model.forward(**kwargs) return outputs.logits, outputs.past_key_values @@ -563,11 +621,16 @@ def generate_token( # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + # Assign pointers to LoRA weights + # TODO(travis): don't update this if indices haven't changed + adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.batched_lora_weights) + logits, past = self.forward( batch.input_ids, attention_mask, batch.position_ids, batch.past_key_values, + adapter_data, ) # Results @@ -586,6 +649,8 @@ def generate_token( batch.all_input_ids, ) + next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) + # For each member of the batch for i, ( request, @@ -684,6 +749,7 @@ def generate_token( batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.max_input_length = max(batch.max_input_length, new_input_length) + next_adapter_indices[i] = request.adapter_index # We finished all generations in the batch; there is no next batch if stopped: @@ -703,4 +769,6 @@ def generate_token( # Update past key values batch.past_key_values = past + batch.adapter_meta.adapter_indices = next_adapter_indices + return generations, batch diff --git a/server/lorax_server/models/custom_modeling/bloom_modeling.py b/server/lorax_server/models/custom_modeling/bloom_modeling.py index c83436ed3..69224be29 100644 --- a/server/lorax_server/models/custom_modeling/bloom_modeling.py +++ b/server/lorax_server/models/custom_modeling/bloom_modeling.py @@ -33,11 +33,14 @@ from transformers import BloomConfig, PreTrainedModel from lorax_server.utils.layers import ( + TensorParallelAdapterRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, + TensorParallelMultiAdapterLinear, TensorParallelRowLinear, TensorParallelHead, ) +from lorax_server.utils.lora import AdapterBatchData CUSTOM_KERNELS_ENABLED = False if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True": @@ -61,6 +64,12 @@ "bigscience/bloom", ] +ATTN_QKV = "attn.query_key_value" +ATTN_DENSE = "attn.dense" +MLP_DENSE_H_TO_4H = "mlp.dense_h_to_4h" +MLP_DENSE_4H_TO_H = "mlp.dense_4h_to_h" +LM_HEAD = "lm_head" + def _make_causal_mask( input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int @@ -231,7 +240,7 @@ def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor class BloomAttention(nn.Module): - def __init__(self, prefix, config: BloomConfig, weights): + def __init__(self, prefix, config: BloomConfig, weights, layer_id): super().__init__() self.pretraining_tp = config.pretraining_tp @@ -262,14 +271,25 @@ def __init__(self, prefix, config: BloomConfig, weights): f"and `num_shards`: {process_group.size()}" ) self.num_heads = self.num_heads // process_group.size() - self.query_key_value = TensorParallelColumnLinear.load( - config=config, - prefix=f"{prefix}.query_key_value", - weights=weights, - bias=True, - ) - self.dense = TensorParallelRowLinear.load( - config=config, prefix=f"{prefix}.dense", weights=weights, bias=True + self.query_key_value = TensorParallelMultiAdapterLinear.load( + TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.query_key_value", + weights=weights, + bias=True, + ), + layer_id, + [ATTN_QKV], + sizes=None, + process_group=weights.process_group, + ) + self.dense = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.dense", weights=weights, bias=True + ), + layer_id, + ATTN_DENSE, + process_group=weights.process_group, ) self.attention_dropout = nn.Dropout(config.attention_dropout) @@ -360,10 +380,11 @@ def forward( layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, + adapter_data: Optional[AdapterBatchData] = None, output_attentions: bool = False, ): fused_qkv = self.query_key_value( - hidden_states + hidden_states, adapter_data, ) # [batch_size, seq_length, 3 x hidden_size] batch_size, q_length, _ = fused_qkv.shape @@ -417,7 +438,7 @@ def forward( self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: - output_tensor = self.dense(context_layer) + output_tensor = self.dense(context_layer, adapter_data) # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) output_tensor += residual @@ -430,26 +451,42 @@ def forward( class BloomMLP(nn.Module): - def __init__(self, prefix, config: BloomConfig, weights): + def __init__(self, prefix, config: BloomConfig, weights, layer_id): super().__init__() self.pretraining_tp = config.pretraining_tp self.slow_but_exact = config.slow_but_exact - self.dense_h_to_4h = TensorParallelColumnLinear.load( - config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True - ) - self.dense_4h_to_h = TensorParallelRowLinear.load( - config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True + self.dense_h_to_4h = TensorParallelMultiAdapterLinear.load( + TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ), + layer_id, + [MLP_DENSE_H_TO_4H], + sizes=None, + process_group=weights.process_group, + ) + self.dense_4h_to_h = TensorParallelAdapterRowLinear.load( + TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True + ), + layer_id, + MLP_DENSE_4H_TO_H, + process_group=weights.process_group, ) self.gelu_impl = torch.nn.GELU(approximate="tanh") self.hidden_dropout = config.hidden_dropout def forward( - self, hidden_states: torch.Tensor, residual: torch.Tensor + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + adapter_data: Optional[AdapterBatchData] = None, ) -> torch.Tensor: - hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) + hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states, adapter_data)) - if self.pretraining_tp > 1 and self.slow_but_exact: + if self.pretraining_tp > 1 and self.slow_but_exact and ( + adapter_data is None or adapter_data.max_rank == 0 + ): intermediate_output = torch.zeros_like(residual) slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp for i in range(self.pretraining_tp): @@ -460,7 +497,7 @@ def forward( ], ) else: - intermediate_output = self.dense_4h_to_h(hidden_states) + intermediate_output = self.dense_4h_to_h(hidden_states, adapter_data) # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) intermediate_output += residual @@ -480,7 +517,7 @@ def __init__(self, layer_id: int, config: BloomConfig, weights): ) self.num_heads = config.n_head self.self_attention = BloomAttention( - prefix=f"{prefix}.self_attention", config=config, weights=weights + prefix=f"{prefix}.self_attention", config=config, weights=weights, layer_id=layer_id, ) self.post_attention_layernorm = LayerNorm.load( prefix=f"{prefix}.post_attention_layernorm", @@ -488,7 +525,7 @@ def __init__(self, layer_id: int, config: BloomConfig, weights): eps=config.layer_norm_epsilon, ) - self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id) self.apply_residual_connection_post_layernorm = ( config.apply_residual_connection_post_layernorm ) @@ -502,6 +539,7 @@ def forward( layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, + adapter_data: Optional[AdapterBatchData] = None, output_attentions: bool = False, ): # hidden_states: [batch_size, seq_length, hidden_size] @@ -524,6 +562,7 @@ def forward( alibi=alibi, head_mask=head_mask, use_cache=use_cache, + adapter_data=adapter_data, output_attentions=output_attentions, ) @@ -540,7 +579,7 @@ def forward( residual = attention_output # MLP. - output = self.mlp(layernorm_output, residual) + output = self.mlp(layernorm_output, residual, adapter_data=adapter_data) if use_cache: outputs = (output,) + outputs @@ -669,6 +708,7 @@ def forward( head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + adapter_data: Optional[AdapterBatchData] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -773,6 +813,7 @@ def forward( attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, + adapter_data=adapter_data, output_attentions=output_attentions, alibi=alibi, ) @@ -817,10 +858,12 @@ def __init__(self, config, weights): super().__init__(config) self.transformer = BloomModel(config, weights) - self.lm_head = TensorParallelHead.load( - config, - prefix="word_embeddings", - weights=weights, + self.lm_head = TensorParallelAdapterRowLinear.load( + TensorParallelHead.load( + config, + prefix="word_embeddings", + weights=weights, + ), 0, LM_HEAD, process_group=weights.process_group ) def prepare_inputs_for_generation( @@ -863,6 +906,7 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, + adapter_data: Optional[AdapterBatchData] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -895,13 +939,14 @@ def forward( head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, + adapter_data=adapter_data, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) + lm_logits = self.lm_head(hidden_states, adapter_data) loss = None if not return_dict: diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 40f3119ac..7cc3112c8 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -671,6 +671,7 @@ def __len__(self): class FlashCausalLM(Model): def __init__( self, + model_id: str, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, num_layers: int, @@ -682,16 +683,15 @@ def __init__( world_size: int = 1, sliding_window: Optional[int] = None, compile: bool = False, + adapter_id: str = BASE_MODEL_ADAPTER_ID, + dynamic_adapter_loading_enabled: bool = True, ): self.num_layers = num_layers self.num_kv_heads = num_kv_heads self.head_size = head_size - # This may be set to False in the subclass constructor - self.dynamic_adapter_loading_enabled = True - self.batched_lora_weights: Dict[str, BatchedLoraWeights] = defaultdict(BatchedLoraWeights) - super(FlashCausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=False, @@ -700,162 +700,13 @@ def __init__( rank=rank, world_size=world_size, sliding_window=sliding_window, + adapter_id=adapter_id, + dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, ) self.compile = compile self.model_graph_wrapper: GraphCache = None - self.target_to_layer = self.adapter_target_to_layer() - - @property - def supports_adapter_loading(self) -> bool: - return False - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - return {} - - @property - def adapter_layers(self) -> List[str]: - return [] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 0 - - def is_row_parallel(self, layer_type: str) -> bool: - return False - - def load_adapter(self, adapter_id, adapter_source, adapter_index): - """Physically loads the adapter weights into the model. - - adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded - into model. Otherwise, the adapter weights are merged into the model - weights on the fly. - """ - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if not self.dynamic_adapter_loading_enabled: - if adapter_id == BASE_MODEL_ADAPTER_ID: - return - else: - raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") - - # If we are doing dynamic adapter loading, then we need to reset the weights - if adapter_id == self.adapter_id: - return - elif adapter_id != BASE_MODEL_ADAPTER_ID: - logger.info(f"Loading adapter weights into model: {adapter_id}") - weight_names = tuple([v[0] for v in self.target_to_layer.values()]) - module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map( - self.model_id, adapter_id, adapter_source, weight_names - ) - - unused_weight_names = adapter_weight_names.copy() - for layer_name in self.adapter_layers: - self.load_batched_adapter_weights( - module_map, adapter_config, adapter_index, layer_name, unused_weight_names - ) - - if len(unused_weight_names) > 0: - logger.warning(f"{adapter_id} unused adapter weights: {unused_weight_names}") - - if adapter_tokenizer is not None: - self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) - - self.adapter_id = adapter_id - - def shard_lora_weights( - self, - weights_a: List[torch.Tensor], - weights_b: List[torch.Tensor], - layer_type: str, - ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - # [hidden_size, r] - split_dim = 0 if self.is_row_parallel(layer_type) else 1 - weights_a = [ - shard_on_dim(w, dim=split_dim, process_group=self.process_group) - for w in weights_a - ] - - # [r, hidden_size] - weights_b = [ - shard_on_dim(w, dim=1, process_group=self.process_group) - for w in weights_b - ] - - return weights_a, weights_b - - def load_batched_adapter_weights( - self, - module_map: Dict[str, Dict], - adapter_config: LoraConfig, - adapter_index: int, - layer_type: str, - unused_weight_names: Set[str], - ): - nlayers = self.get_num_layers_for_type(layer_type) - lora_a_list = [None] * nlayers - lora_b_list = [None] * nlayers - - for layer_id in range(nlayers): - key = (layer_id, layer_type) - weight_name, layer = self.target_to_layer[key] - - base_weight = layer.base_layer.linear.weight - base_device = base_weight.device - - if weight_name not in module_map: - # There is no LoRA weight for this layer type in the adapter - return - - lora_a, lora_a_name = module_map[weight_name]["lora_A"] - lora_a = lora_a.to(base_device, self.dtype) - - lora_b, lora_b_name = module_map[weight_name]["lora_B"] - lora_b = lora_b.to(base_device, self.dtype) - - scale = adapter_config.lora_alpha / adapter_config.r - - unused_weight_names.discard(lora_a_name) - unused_weight_names.discard(lora_b_name) - - # Merge scaling factor into lora_b due to associativity of matrix multiplication: - # (A * B) * C = A * (B * C) - lora_a_list[layer_id] = lora_a.transpose(0, 1) - lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale - - q_lora_merged = MergedLoraWeights( - *self.shard_lora_weights(lora_a_list, lora_b_list, layer_type), adapter_config, - ) - q_lora_weights = self.batched_lora_weights[layer_type] - q_lora_weights.add_adapter(adapter_index, q_lora_merged) - - def offload_adapter(self, adapter_id, adapter_source, adapter_index): - """Offloads the adapter weights from GPU to CPU or disk.""" - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if not self.dynamic_adapter_loading_enabled: - if adapter_id == BASE_MODEL_ADAPTER_ID: - return - else: - raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") - - if adapter_id == BASE_MODEL_ADAPTER_ID: - return - else: - for layer_name in self.adapter_layers: - if layer_name in self.batched_lora_weights: - self.batched_lora_weights[layer_name].remove_adapter(adapter_index) - - self.adapter_id = BASE_MODEL_ADAPTER_ID - @property def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch diff --git a/server/lorax_server/models/flash_gpt2.py b/server/lorax_server/models/flash_gpt2.py index 1aea62437..fa10dce28 100644 --- a/server/lorax_server/models/flash_gpt2.py +++ b/server/lorax_server/models/flash_gpt2.py @@ -75,16 +75,17 @@ def __init__( # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None - self.dynamic_adapter_loading_enabled = True - self.adapter_id = BASE_MODEL_ADAPTER_ID + dynamic_adapter_loading_enabled = True if len(adapter_id) > 0: logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source ) - self.dynamic_adapter_loading_enabled = False - self.adapter_id = adapter_id + dynamic_adapter_loading_enabled = False + adapter_id = adapter_id + else: + adapter_id = BASE_MODEL_ADAPTER_ID weights = Weights( filenames, @@ -97,11 +98,11 @@ def __init__( if config.quantize == "gptq": weights._set_gptq_params(model_id) - self.model_id = model_id model = FlashGPT2ForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) super(FlashGPT2, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.transformer.h), @@ -112,6 +113,8 @@ def __init__( rank=rank, world_size=world_size, compile=compile, + adapter_id=adapter_id, + dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, ) @property diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index e67d22336..3ea9832c0 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -68,16 +68,17 @@ def __init__( # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None - self.dynamic_adapter_loading_enabled = True - self.adapter_id = BASE_MODEL_ADAPTER_ID + dynamic_adapter_loading_enabled = True if len(adapter_id) > 0: logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source ) - self.dynamic_adapter_loading_enabled = False - self.adapter_id = adapter_id + dynamic_adapter_loading_enabled = False + adapter_id = adapter_id + else: + adapter_id = BASE_MODEL_ADAPTER_ID weights = Weights( filenames, @@ -90,11 +91,11 @@ def __init__( if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id) - self.model_id = model_id model = FlashLlamaForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) super(FlashLlama, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), @@ -105,6 +106,8 @@ def __init__( rank=rank, world_size=world_size, compile=compile, + adapter_id=adapter_id, + dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, ) @property diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index bdf0a33a9..9f4d72994 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -354,16 +354,17 @@ def __init__( # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None - self.dynamic_adapter_loading_enabled = True - self.adapter_id = BASE_MODEL_ADAPTER_ID + dynamic_adapter_loading_enabled = True if len(adapter_id) > 0: logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source ) - self.dynamic_adapter_loading_enabled = False - self.adapter_id = adapter_id + dynamic_adapter_loading_enabled = False + adapter_id = adapter_id + else: + adapter_id = BASE_MODEL_ADAPTER_ID weights = Weights( filenames, @@ -376,11 +377,11 @@ def __init__( if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id) - self.model_id = model_id model = FlashMistralForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) super(FlashMistral, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), @@ -392,6 +393,8 @@ def __init__( world_size=world_size, sliding_window=config.sliding_window, compile=compile, + adapter_id=adapter_id, + dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, ) @property diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index f1f507f52..5ec54bb10 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -361,16 +361,17 @@ def __init__( # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None - self.dynamic_adapter_loading_enabled = True - self.adapter_id = BASE_MODEL_ADAPTER_ID + dynamic_adapter_loading_enabled = True if len(adapter_id) > 0: logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source ) - self.dynamic_adapter_loading_enabled = False - self.adapter_id = adapter_id + dynamic_adapter_loading_enabled = False + adapter_id = adapter_id + else: + adapter_id = BASE_MODEL_ADAPTER_ID weights = Weights( filenames, @@ -383,11 +384,11 @@ def __init__( if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id) - self.model_id = model_id model = FlashMixtralForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) super(FlashMixtral, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), @@ -399,6 +400,8 @@ def __init__( world_size=world_size, sliding_window=config.sliding_window, compile=compile, + adapter_id=adapter_id, + dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, ) @property diff --git a/server/lorax_server/models/flash_neox.py b/server/lorax_server/models/flash_neox.py index 6989d9f15..cba987a69 100644 --- a/server/lorax_server/models/flash_neox.py +++ b/server/lorax_server/models/flash_neox.py @@ -60,6 +60,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) super(FlashNeoXSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.gpt_neox.layers), diff --git a/server/lorax_server/models/flash_phi.py b/server/lorax_server/models/flash_phi.py index 8e9c1d9d6..23a0425f8 100644 --- a/server/lorax_server/models/flash_phi.py +++ b/server/lorax_server/models/flash_phi.py @@ -72,16 +72,17 @@ def __init__( # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None - self.dynamic_adapter_loading_enabled = True - self.adapter_id = BASE_MODEL_ADAPTER_ID + dynamic_adapter_loading_enabled = True if len(adapter_id) > 0: logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source ) - self.dynamic_adapter_loading_enabled = False - self.adapter_id = adapter_id + dynamic_adapter_loading_enabled = False + adapter_id = adapter_id + else: + adapter_id = BASE_MODEL_ADAPTER_ID weights = Weights( filenames, @@ -94,12 +95,12 @@ def __init__( if config.quantize == "gptq": weights._set_gptq_params(model_id) - self.model_id = model_id model = FlashPhiForCausalLM(config, weights) self.config = config torch.distributed.barrier(group=self.process_group) super(FlashPhi, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.transformer.h), @@ -110,6 +111,8 @@ def __init__( rank=rank, world_size=world_size, compile=compile, + adapter_id=adapter_id, + dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, ) @property diff --git a/server/lorax_server/models/flash_qwen.py b/server/lorax_server/models/flash_qwen.py index 15679965d..6679be800 100644 --- a/server/lorax_server/models/flash_qwen.py +++ b/server/lorax_server/models/flash_qwen.py @@ -73,16 +73,17 @@ def __init__( # the adapter weights with the model weights. This also disables dynamic # adapter loading, since the model is now itself initialized with an adapter. merged_weight_filenames = None - self.dynamic_adapter_loading_enabled = True - self.adapter_id = BASE_MODEL_ADAPTER_ID + dynamic_adapter_loading_enabled = True if len(adapter_id) > 0: logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") # Need to pass the adapter source here merged_weight_filenames = create_merged_weight_files( adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source ) - self.dynamic_adapter_loading_enabled = False - self.adapter_id = adapter_id + dynamic_adapter_loading_enabled = False + adapter_id = adapter_id + else: + adapter_id = BASE_MODEL_ADAPTER_ID weights = Weights( filenames, @@ -95,12 +96,12 @@ def __init__( if config.quantize == "gptq": weights._set_gptq_params(model_id) - self.model_id = model_id model = FlashQwenForCausalLM(config, weights) self.config = config torch.distributed.barrier(group=self.process_group) super(FlashQwen, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.transformer.h), @@ -111,6 +112,8 @@ def __init__( rank=rank, world_size=world_size, compile=compile, + adapter_id=adapter_id, + dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, ) @property diff --git a/server/lorax_server/models/flash_rw.py b/server/lorax_server/models/flash_rw.py index 055887e06..3b4176612 100644 --- a/server/lorax_server/models/flash_rw.py +++ b/server/lorax_server/models/flash_rw.py @@ -66,6 +66,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) super(FlashRWSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.transformer.h), diff --git a/server/lorax_server/models/flash_santacoder.py b/server/lorax_server/models/flash_santacoder.py index 88c3a75cb..fd69a4411 100644 --- a/server/lorax_server/models/flash_santacoder.py +++ b/server/lorax_server/models/flash_santacoder.py @@ -70,6 +70,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) super(FlashSantacoderSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.transformer.h), diff --git a/server/lorax_server/models/galactica.py b/server/lorax_server/models/galactica.py index 6d8fb8933..f1072d664 100644 --- a/server/lorax_server/models/galactica.py +++ b/server/lorax_server/models/galactica.py @@ -206,6 +206,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/lorax_server/models/gpt_neox.py b/server/lorax_server/models/gpt_neox.py index ae3a56862..e37b27c91 100644 --- a/server/lorax_server/models/gpt_neox.py +++ b/server/lorax_server/models/gpt_neox.py @@ -67,6 +67,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index f782bf294..36d49fe84 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -1,14 +1,19 @@ +from collections import defaultdict import inspect import torch from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, TypeVar, Type +from loguru import logger +from peft import LoraConfig +from typing import Dict, List, Set, Tuple, Optional, TypeVar, Type from transformers import PreTrainedTokenizerBase from lorax_server.models.types import Batch, GeneratedText from lorax_server.pb.generate_pb2 import InfoResponse -from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID +from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map from lorax_server.utils.tokenizer import TokenizerManager +from lorax_server.utils.lora import BatchedLoraWeights, MergedLoraWeights +from lorax_server.utils.weights import shard_on_dim B = TypeVar("B", bound=Batch) @@ -16,6 +21,7 @@ class Model(ABC): def __init__( self, + model_id: str, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, requires_padding: bool, @@ -24,7 +30,10 @@ def __init__( rank: int = 0, world_size: int = 1, sliding_window: Optional[int] = None, + adapter_id: str = BASE_MODEL_ADAPTER_ID, + dynamic_adapter_loading_enabled: bool = True, ): + self.model_id = model_id self.model = model.eval() self.tokenizer = tokenizer self.tokenizers = TokenizerManager() @@ -36,6 +45,12 @@ def __init__( self.world_size = world_size self.sliding_window = sliding_window + # This may be set to False in the subclass constructor + self.adapter_id = adapter_id + self.dynamic_adapter_loading_enabled = dynamic_adapter_loading_enabled + self.batched_lora_weights: Dict[str, BatchedLoraWeights] = defaultdict(BatchedLoraWeights) + self.target_to_layer = self.adapter_target_to_layer() + self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) is not None @@ -105,7 +120,151 @@ def check_initialized(self): f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) + @property + def supports_adapter_loading(self) -> bool: + return False + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + return {} + + @property + def adapter_layers(self) -> List[str]: + return [] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 0 + + def is_row_parallel(self, layer_type: str) -> bool: + return False + def load_adapter(self, adapter_id, adapter_source, adapter_index): + """Physically loads the adapter weights into the model. + + adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded + into model. Otherwise, the adapter weights are merged into the model + weights on the fly. + """ + if not self.supports_adapter_loading: + raise ValueError("This model does not support adapter loading.") + + if not self.dynamic_adapter_loading_enabled: + if adapter_id == BASE_MODEL_ADAPTER_ID: + return + else: + raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature.") + + # If we are doing dynamic adapter loading, then we need to reset the weights + if adapter_id == self.adapter_id: + return + elif adapter_id != BASE_MODEL_ADAPTER_ID: + logger.info(f"Loading adapter weights into model: {adapter_id}") + weight_names = tuple([v[0] for v in self.target_to_layer.values()]) + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map( + self.model_id, adapter_id, adapter_source, weight_names + ) + + unused_weight_names = adapter_weight_names.copy() + for layer_name in self.adapter_layers: + self.load_batched_adapter_weights( + module_map, adapter_config, adapter_index, layer_name, unused_weight_names + ) + + if len(unused_weight_names) > 0: + logger.warning(f"{adapter_id} unused adapter weights: {unused_weight_names}") + + if adapter_tokenizer is not None: + self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) + + self.adapter_id = adapter_id + + def shard_lora_weights( + self, + weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + layer_type: str, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + # [hidden_size, r] + split_dim = 0 if self.is_row_parallel(layer_type) else 1 + weights_a = [ + shard_on_dim(w, dim=split_dim, process_group=self.process_group) + for w in weights_a + ] + + # [r, hidden_size] + weights_b = [ + shard_on_dim(w, dim=1, process_group=self.process_group) + for w in weights_b + ] + + return weights_a, weights_b + + def load_batched_adapter_weights( + self, + module_map: Dict[str, Dict], + adapter_config: LoraConfig, + adapter_index: int, + layer_type: str, + unused_weight_names: Set[str], + ): + nlayers = self.get_num_layers_for_type(layer_type) + lora_a_list = [None] * nlayers + lora_b_list = [None] * nlayers + + for layer_id in range(nlayers): + key = (layer_id, layer_type) + weight_name, layer = self.target_to_layer[key] + + base_weight = layer.base_layer.linear.weight + base_device = base_weight.device + + if weight_name not in module_map: + # There is no LoRA weight for this layer type in the adapter + return + + lora_a, lora_a_name = module_map[weight_name]["lora_A"] + lora_a = lora_a.to(base_device, self.dtype) + + lora_b, lora_b_name = module_map[weight_name]["lora_B"] + lora_b = lora_b.to(base_device, self.dtype) + + scale = adapter_config.lora_alpha / adapter_config.r + + unused_weight_names.discard(lora_a_name) + unused_weight_names.discard(lora_b_name) + + # Merge scaling factor into lora_b due to associativity of matrix multiplication: + # (A * B) * C = A * (B * C) + lora_a_list[layer_id] = lora_a.transpose(0, 1) + lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale + + q_lora_merged = MergedLoraWeights( + *self.shard_lora_weights(lora_a_list, lora_b_list, layer_type), adapter_config, + ) + q_lora_weights = self.batched_lora_weights[layer_type] + q_lora_weights.add_adapter(adapter_index, q_lora_merged) + + def offload_adapter(self, adapter_id, adapter_source, adapter_index): + """Offloads the adapter weights from GPU to CPU or disk.""" + if not self.supports_adapter_loading: + raise ValueError("This model does not support adapter loading.") + + if not self.dynamic_adapter_loading_enabled: + if adapter_id == BASE_MODEL_ADAPTER_ID: + return + else: + raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature.") + if adapter_id == BASE_MODEL_ADAPTER_ID: return - raise ValueError("This model does not support adapter loading.") + else: + for layer_name in self.adapter_layers: + if layer_name in self.batched_lora_weights: + self.batched_lora_weights[layer_name].remove_adapter(adapter_index) + + self.adapter_id = BASE_MODEL_ADAPTER_ID diff --git a/server/lorax_server/models/mpt.py b/server/lorax_server/models/mpt.py index a2388eb68..cfca9a968 100644 --- a/server/lorax_server/models/mpt.py +++ b/server/lorax_server/models/mpt.py @@ -92,6 +92,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=False, diff --git a/server/lorax_server/models/opt.py b/server/lorax_server/models/opt.py index 86755c363..6898cc580 100644 --- a/server/lorax_server/models/opt.py +++ b/server/lorax_server/models/opt.py @@ -65,6 +65,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/lorax_server/models/rw.py b/server/lorax_server/models/rw.py index d8c5f6bfe..4d80f255c 100644 --- a/server/lorax_server/models/rw.py +++ b/server/lorax_server/models/rw.py @@ -60,6 +60,7 @@ def __init__( tokenizer.add_special_tokens({"pad_token": "[PAD]"}) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/lorax_server/models/santacoder.py b/server/lorax_server/models/santacoder.py index 67c68fe12..ec1a0a9be 100644 --- a/server/lorax_server/models/santacoder.py +++ b/server/lorax_server/models/santacoder.py @@ -65,6 +65,7 @@ def __init__( ) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index 250b8d628..00db4dd2e 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ -548,6 +548,7 @@ def __init__( tokenizer.bos_token_id = model.config.decoder_start_token_id super(Seq2SeqLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/lorax_server/models/t5.py b/server/lorax_server/models/t5.py index 327dfa3a0..a68118206 100644 --- a/server/lorax_server/models/t5.py +++ b/server/lorax_server/models/t5.py @@ -75,6 +75,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) super(Seq2SeqLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 2e19b3ff7..449c368fe 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -485,15 +485,31 @@ def load(cls, base_layer, layer_id, layer_names, sizes, process_group): def forward(self, input: torch.Tensor, adapter_data: AdapterBatchData) -> torch.Tensor: result = self.base_layer(input) + + # handle models like Bloom that have inputs of shape + # (batch_size, sequence_length, hidden_size) + # we need to reshape them to (batch_size * sequence_length, hidden_size) + # for the LoRA computation, then reshape back + prev_shape = result.shape + is_3d = len(input.shape) >= 3 + if is_3d: + input = input.reshape(-1, input.shape[-1]) + result = result.reshape(-1, result.shape[-1]) offset = 0 for i, layer_name in enumerate(self.layer_names): start_idx = offset // self.process_group.size() - offset += self.sizes[i] - end_idx = offset // self.process_group.size() + if self.sizes is not None: + offset += self.sizes[i] + end_idx = offset // self.process_group.size() + else: + end_idx = result.shape[1] result = self.forward_layer_type(result, input, adapter_data, layer_name, start_idx, end_idx) + + if is_3d: + result = result.reshape(prev_shape) return result diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py index a6f3fb282..8055a1bb5 100644 --- a/server/tests/models/test_model.py +++ b/server/tests/models/test_model.py @@ -14,10 +14,11 @@ def batch_type(self): def generate_token(self, batch): raise NotImplementedError - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + model_id = "meta-llama/Llama-2-7b-hf" + tokenizer = AutoTokenizer.from_pretrained(model_id) model = TestModel( - torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu") + model_id, torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu") ) return model