diff --git a/server/lorax_server/adapters/config.py b/server/lorax_server/adapters/config.py
index 6bfcf8645..ec696b420 100644
--- a/server/lorax_server/adapters/config.py
+++ b/server/lorax_server/adapters/config.py
@@ -22,6 +22,7 @@ def map_weights_for_model(
         self,
         adapter_weights: Dict,
         weight_names: Tuple[str],
+        embedding_weight_name: str,
     ) -> Tuple[ModuleMap, Set[str]]:
         pass
diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py
index 07666bd15..aef75fe24 100644
--- a/server/lorax_server/adapters/lora.py
+++ b/server/lorax_server/adapters/lora.py
@@ -9,6 +9,7 @@
 from lorax_server.adapters.config import AdapterConfig, ModuleMap
 from lorax_server.adapters.types import LORA
 from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
+from lorax_server.utils.lora import EMBED_TOKENS
 from lorax_server.utils.sgmv import (
     BGMV_MAX_RANK,
     MAX_RANK_CUSTOM,
@@ -36,15 +37,22 @@ def map_weights_for_model(
         self,
         adapter_weights: Dict,
         weight_names: Tuple[str],
+        embedding_weight_name: str,
     ) -> Tuple[ModuleMap, Set[str]]:
         adapter_weight_names = set()
         module_map = {}
         for weight_name in weight_names:
-            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
-            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
+            if embedding_weight_name in weight_name:
+                lora_a_name = f"base_model.model.{weight_name}.lora_embedding_A"
+                lora_b_name = f"base_model.model.{weight_name}.lora_embedding_B"
+            else:
+                lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
+                lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
             if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
                 continue
+            # note(ajinkya): popping the weights so that we know which weights
+            # can be used as lora weights (supported) and which cannot
             module_map[weight_name] = {
                 "lora_A": (adapter_weights[lora_a_name], lora_a_name),
                 "lora_B": (adapter_weights[lora_b_name], lora_b_name),
@@ -90,6 +98,7 @@ def __init__(
         weights_a: List[torch.Tensor],
         weights_b: List[torch.Tensor],
         adapter_config: LoraConfig,
+        is_embed: bool = False,
     ):
         self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
         self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
@@ -98,7 +107,8 @@
         self._is_transposed = False

         # [num_layers, hidden_size, r]
-        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+        if not is_embed:
+            weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
         self._weights_a = torch.stack(weights_a)

         # [num_layers, r, hidden_size]
@@ -158,8 +168,10 @@ def load(
         key = (layer_id, layer_type)
         weight_name, layer = model.target_to_layer[key]
-        base_weight = layer.base_layer.linear.weight
-        base_device = base_weight.device
+        if EMBED_TOKENS in layer_type:
+            base_device = layer.base_layer.weight.device
+        else:
+            base_device = layer.base_layer.linear.weight.device

         if weight_name not in module_map:
             # There is no LoRA weight for this layer type in the adapter
@@ -196,13 +208,15 @@ def load(
         return LoraWeights(
             *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
-            config,
+            adapter_config=config,
+            is_embed=(layer_type == EMBED_TOKENS),
         )


 @dataclass
 class RankSegments:
     rank: int
+    adapter_index_map: int

     lora_a_ptr: torch.Tensor
     lora_b_ptr: torch.Tensor
@@ -242,6 +256,7 @@ def load(
         meta: AdapterBatchMetadata,
         prefill: bool,
         prefill_head_indices: Optional[torch.Tensor],
+        is_embed: bool,
     ) -> Optional["BatchLoraWeights"]:
         adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
         adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}
@@ -252,68 +267,36 @@ def load(
         device = first_weights.weights_a.device
         segment_indices = meta.segment_indices

-        lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
-        lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}
-
-        segment_ranks = [adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights]
-        if not segment_ranks:
+        lora_a, lora_b, adapter_index_configs = {}, {}, {}
+        max_rank, rank_indices = 0, defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx in adapter_weights:
+                adapter_weight = adapter_weights[adapter_idx]
+                adapter_index_configs[adapter_idx] = adapter_weight.adapter_config
+                max_rank = max(max_rank, adapter_weight.lora_a_r)
+                rank_indices[adapter_weight.lora_a_r].append(segment_idx)
+                lora_a[adapter_idx] = adapter_weight.weights_a
+                lora_b[adapter_idx] = adapter_weight.weights_b
+
+        if not max_rank:
             return None

-        max_rank = max(segment_ranks)
-        if prefill or max_rank > BGMV_MAX_RANK:
-            use_sgmv = True
-            lora_a_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-            lora_b_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-        else:
-            use_sgmv = False
-            lora_a_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-            lora_b_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-
-        adapter_index_configs = {
-            idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
-        }
-
-        rank_indices = defaultdict(list)
-        for segment_idx, adapter_idx in enumerate(segment_indices):
-            if adapter_idx not in adapter_weights:
-                continue
-            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+        use_sgmv = prefill or max_rank > BGMV_MAX_RANK

         if prefill_head_indices is not None:
             j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            # prefill_head_indices is used to slice the tokens in the batch such
+            # that we only forward the last token of each request through lm_head;
+            # there can be multiple head indices associated with each adapter segment
             for head_index in prefill_head_indices:
-                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                # j cannot go out of bounds as that would mean there are tokens without segments
                 if head_index < meta.adapter_segments[j]:
+                    # head_index is part of the current adapter,
+                    # so increment the current segment end
                     prefill_head_segment_ends[-1] += 1
                 else:
+                    # head_index is not part of the current adapter, so
+                    # close the previous segment and start a new one
                     prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
                     prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
                     j += 1
@@ -325,40 +308,54 @@
             segment_starts = None
             segment_ends = None
             batch_indices = None
+            lora_a_ptr_indices = []
+            lora_b_ptr_indices = []
             if use_sgmv:
-                lora_a_ptr_indices = lora_a_ptr[indices]
-                tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device)
+                for segment_idx in indices:
+                    adapter_weight = adapter_weights[segment_indices[segment_idx]]
+                    lora_a_ptr_indices.append(adapter_weight.weights_a.data_ptr())
+                    lora_b_ptr_indices.append(adapter_weight.weights_b.data_ptr())
+                tmp_shrink, tmp_expand = get_tmp_tensors(len(lora_a_ptr_indices), rank, device)
                 segment_starts = meta.adapter_segments[indices]
                 segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
                 if prefill_head_indices is not None:
-                    for i, segment_index in enumerate(indices):
-                        segment_starts[i] = prefill_head_segment_starts[segment_index]
-                        segment_ends[i] = prefill_head_segment_ends[segment_index]
+                    # since prefill_head_indices is present, the segment starts and ends need
+                    # to be adjusted according to the number of head tokens in each segment
+                    for i, segment_idx in enumerate(indices):
+                        segment_starts[i] = prefill_head_segment_starts[segment_idx]
+                        segment_ends[i] = prefill_head_segment_ends[segment_idx]
             else:
-                # `indices` indexes the `segment_indices` which contains segment wise adapter index
-                # `lora_a_ptr` contains segment wise pointers to lora weights
-                # lengths of `lora_a_ptr` and `segment_indices` must be same
-                # `indices` will be used to slice the `lora_a_ptr` tensor
-                # first, find the mapping between adapter index and its location in the `indices` array
-                idx_locs = {}
-                for loc, idx in enumerate(indices):
-                    # use the idx to find the adapter index
-                    if segment_indices[idx] not in idx_locs:
-                        # save the first location of encountering a particular adapter index
-                        idx_locs[segment_indices[idx]] = loc
-                # second, iterate over the adapter index for each token and find its location in the `indices` array
+                adapter_idx_to_pointer_idx = {}
+                # find out which adapters are present in the segments for this rank:
+                # iterate over each segment index and use it to look up the adapter index and weights
+                for segment_idx in indices:
+                    adapter_idx = segment_indices[segment_idx]
+                    adapter_weight = adapter_weights[adapter_idx]
+                    # if the adapter hasn't been seen before, append its weight pointers
+                    # and save the index of the just-added pointers for later
+                    if adapter_idx not in adapter_idx_to_pointer_idx:
+                        lora_a_ptr_indices.append(
+                            (adapter_weight.weights_a if is_embed else adapter_weight.weights_a_t).data_ptr()
+                        )
+                        lora_b_ptr_indices.append(adapter_weight.weights_b_t.data_ptr())
+                        adapter_idx_to_pointer_idx[adapter_idx] = len(lora_a_ptr_indices) - 1
+                # for each token in the batch, see if its adapter is present in the segments for this rank;
+                # if present, store the index of its weight pointers, otherwise store -1
                 batch_indices = torch.tensor([
-                    idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1
-                    for idx in meta.adapter_indices.tolist()
+                    adapter_idx_to_pointer_idx.get(adapter_idx, -1) for adapter_idx in meta.adapter_indices.tolist()
                 ], dtype=torch.int64, device=device)

+            lora_a_ptr_indices = torch.tensor(lora_a_ptr_indices, dtype=torch.int64, device=device)
+            lora_b_ptr_indices = torch.tensor(lora_b_ptr_indices, dtype=torch.int64, device=device)
+
             rank_data[rank] = RankSegments(
                 rank=rank,
+                adapter_index_map=indices,
                 tmp_shrink=tmp_shrink,
                 tmp_expand=tmp_expand,
-                lora_a_ptr=lora_a_ptr[indices],
-                lora_b_ptr=lora_b_ptr[indices],
+                lora_a_ptr=lora_a_ptr_indices,
+                lora_b_ptr=lora_b_ptr_indices,
                 segment_starts=segment_starts,
                 segment_ends=segment_ends,
                 indices=batch_indices,
diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py
index 476437128..bd65475a9 100644
--- a/server/lorax_server/adapters/medusa.py
+++ b/server/lorax_server/adapters/medusa.py
@@ -36,6 +36,7 @@ def map_weights_for_model(
         self,
         adapter_weights: Dict,
         weight_names: Tuple[str],
+        embedding_weight_name: str,
     ) -> Tuple[ModuleMap, Set[str]]:
         # TODO(travis): this isn't technically the ModuleMap structure, make this more generic
         return adapter_weights, set(weight_names)
diff --git a/server/lorax_server/adapters/medusa_lora.py b/server/lorax_server/adapters/medusa_lora.py
index 833af0999..e9434f853 100644
--- a/server/lorax_server/adapters/medusa_lora.py
+++ b/server/lorax_server/adapters/medusa_lora.py
@@ -29,9 +29,18 @@ def map_weights_for_model(
         self,
         adapter_weights: Dict,
         weight_names: Tuple[str],
+        embedding_weight_name: str,
     ) -> Tuple[MedusaLoraModuleMap, Set[str]]:
-        lora_module_map, weight_names = self.lora_config.map_weights_for_model(adapter_weights, weight_names)
-        medusa_module_map, _ = self.medusa_config.map_weights_for_model(adapter_weights, weight_names)
+        lora_module_map, weight_names = self.lora_config.map_weights_for_model(
+            adapter_weights,
+            weight_names,
+            embedding_weight_name
+        )
+        medusa_module_map, _ = self.medusa_config.map_weights_for_model(
+            adapter_weights,
+            weight_names,
+            embedding_weight_name
+        )
         return MedusaLoraModuleMap(lora_module_map, medusa_module_map), weight_names

     def load_batched_adapter_weights(
diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py
index 0468baaa8..ed4095636 100644
--- a/server/lorax_server/adapters/weights.py
+++ b/server/lorax_server/adapters/weights.py
@@ -6,7 +6,7 @@
 import torch

 from lorax_server.adapters.types import LORA
-from lorax_server.utils.lora import LM_HEAD
+from lorax_server.utils.lora import EMBED_TOKENS, LM_HEAD


 @dataclass
@@ -82,6 +82,7 @@ def get_data(
         meta: AdapterBatchMetadata,
         prefill: bool,
         prefill_head_indices: Optional[torch.Tensor],
+        is_embed: bool,
     ) -> Dict[str, BatchAdapterWeights]:
         # bucket adapters by batch class
         adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict)
@@ -91,7 +92,7 @@ def get_data(
         batch_data = {}
         for batch_type, adapter_weights in adapter_batch_types.items():
-            batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices)
+            batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices, is_embed)
             if batched_weights is not None:
                 batch_data[batch_type.key()] = batched_weights
         return batch_data
@@ -117,7 +118,7 @@ def from_meta(
         for k, v in weights.items():
             if v.is_empty():
                 continue
-            data[k] = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None)
+            data[k] = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None, k == EMBED_TOKENS)
         return AdapterBatchData(meta=meta, data=data, prefill=prefill)

     def ranks(self) -> Set[int]:
diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
index 5f30a4676..af9c0755a 100644
--- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
@@ -33,6 +33,7 @@
 from lorax_server.utils.layers import (
     MultiAdapterHead,
     PositionRotaryEmbedding,
+    TensorParallelAdapterRowEmbedding,
     TensorParallelAdapterRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -43,6 +44,7 @@
 )
 from lorax_server.utils.lora import (
     DOWN_PROJ,
+    EMBED_TOKENS,
     GATE_PROJ,
     K_PROJ,
     LM_HEAD,
@@ -457,7 +459,12 @@ def __init__(self, config, weights):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
+        self.embed_tokens = TensorParallelAdapterRowEmbedding(
+            base_layer=TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights),
+            layer_id=0,
+            layer_name=EMBED_TOKENS,
+            process_group=process_group,
+        )
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(
@@ -488,7 +495,7 @@ def forward(
         max_s: int,
         adapter_data: AdapterBatchData,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        hidden_states = self.embed_tokens(input_ids, adapter_data)

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py
index 337169b6b..1a3d61927 100644
--- a/server/lorax_server/models/flash_llama.py
+++ b/server/lorax_server/models/flash_llama.py
@@ -17,6 +17,7 @@
 )
 from lorax_server.utils.lora import (
     DOWN_PROJ,
+    EMBED_TOKENS,
     GATE_PROJ,
     K_PROJ,
     LM_HEAD,
@@ -29,8 +30,7 @@
 tracer = trace.get_tracer(__name__)

-# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
-ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ]  # LM_HEAD
+ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD, EMBED_TOKENS]
 ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}

@@ -136,18 +136,23 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
             layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)

         layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, EMBED_TOKENS)] = ("model.embed_tokens", self.model.model.embed_tokens)
         return layer_weights

     @property
     def adapter_layers(self) -> List[str]:
         return ADAPTER_LAYERS

+    @property
+    def embedding_weight_name(self) -> str:
+        return EMBED_TOKENS
+
     @property
     def default_traced_adapter_layers(self) -> List[str]:
         return [Q_PROJ, V_PROJ]

     def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == LM_HEAD or layer_type == EMBED_TOKENS else len(self.model.model.layers)

     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py
index 3ee5bd55a..d66ed67f5 100644
--- a/server/lorax_server/models/model.py
+++ b/server/lorax_server/models/model.py
@@ -164,6 +164,11 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
     def adapter_layers(self) -> List[str]:
         return []

+    @property
+    def embedding_weight_name(self) -> str:
+        # setting this to '' would match any weight name, so use a placeholder until the subclass overrides it
+        return 'placeholder value to be initialized by the subclass'
+
     @property
     def default_traced_adapter_layers(self) -> List[str]:
         return []
@@ -224,6 +229,7 @@ def load_adapter(
             adapter_index,
             weight_names,
             api_token,
+            self.embedding_weight_name,
             self.trust_remote_code,
         )
diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py
index 97322154d..ad66449bd 100644
--- a/server/lorax_server/utils/adapter.py
+++ b/server/lorax_server/utils/adapter.py
@@ -3,6 +3,7 @@
 from functools import lru_cache
 from typing import TYPE_CHECKING, Set, Tuple

+from loguru import logger
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

@@ -40,11 +41,18 @@ def load_and_merge_adapters(
     adapter_index: int,
     weight_names: Tuple[str],
     api_token: str,
+    embedding_weight_name: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     if len(adapter_parameters.adapter_ids) == 1:
         return load_module_map(
-            model_id, adapter_parameters.adapter_ids[0], adapter_source, weight_names, api_token, trust_remote_code
+            model_id,
+            adapter_parameters.adapter_ids[0],
+            adapter_source,
+            weight_names,
+            api_token,
+            embedding_weight_name,
+            trust_remote_code,
         )

     adapter_params = AdapterParametersContainer(adapter_parameters, adapter_source, adapter_index)
@@ -132,6 +140,7 @@ def load_module_map(
     adapter_source: str,
     weight_names: Tuple[str],
     api_token: str,
+    embedding_weight_name: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     # TODO(geoffrey): refactor this and merge parts of this function with
@@ -157,5 +166,18 @@ def load_module_map(
         adapter_weights.update(load_file(filename))

     # map the model weights to the relevant adapter weights (LoRA A and B matrices)
-    module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names)
+    module_map, adapter_weight_names = adapter_config.map_weights_for_model(
+        adapter_weights,
+        weight_names,
+        embedding_weight_name,
+    )
+
+    # note(ajinkya): adapter weights are consumed during the mapping above; any weights left over are
+    # ones we do not support. That should arguably be an error, but for now we just log it.
+    if len(set(adapter_weights.keys()) - set(adapter_weight_names)) > 0:
+        logger.warning(
+            f"Adapter {adapter_id} for the model {model_id}" + \
+            f" contains unsupported weights: {', '.join(adapter_weights.keys())}"
+        )
+
     return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py
index a6e1f3c97..12e6cbd70 100644
--- a/server/lorax_server/utils/graph.py
+++ b/server/lorax_server/utils/graph.py
@@ -109,6 +109,7 @@ def get_max_graph_state(
             rank_data={
                 MAX_RANK: RankSegments(
                     rank=MAX_RANK,
+                    adapter_index_map=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
                     lora_a_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
                     lora_b_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
                     indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
@@ -193,6 +194,7 @@ def trace(
                     {
                         max_rank: RankSegments(
                             rank=max_rank,
+                            adapter_index_map=weight_data.rank_data[MAX_RANK].adapter_index_map[:batch_size],
                             lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:segment_size],
                             lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:segment_size],
                             indices=weight_data.rank_data[MAX_RANK].indices[:batch_size],
diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index dd4bd6b66..3456d9336 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -349,6 +349,152 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return out


+class LoraEmbedding(nn.Module):
+    def __init__(self, layer_id, process_group):
+        super().__init__()
+        self.layer_id = layer_id
+        self.process_group = process_group
+
+    def forward_layer_type(
+        self,
+        result: torch.Tensor,
+        input: torch.Tensor,
+        adapter_data: "AdapterBatchData",
+        layer_type: str,
+        start_idx: int,
+        end_idx: int,
+    ) -> torch.Tensor:
+        data = adapter_data.data.get(layer_type)
+        data: Optional["BatchLoraWeights"] = data.get(LORA) if data is not None else None
+
+        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+            if end_idx - start_idx != result.shape[1]:
+                proj = torch.zeros_like(result[:, start_idx:end_idx])
+            else:
+                proj = result
+
+            for r, rank_segments in data.rank_data.items():
+                lora_a_ptr = rank_segments.lora_a_ptr
+                lora_b_ptr = rank_segments.lora_b_ptr
+
+                if data.use_sgmv:
+                    # Use SGMV for prefill
+                    if lora_a_ptr is not None and lora_b_ptr is not None:
+                        # note(ajinkya): loop through all segments for this rank
+                        # and look up embeddings in each lora `A` matrix
+                        v = torch.zeros_like(result[:, :r])
+                        for i in range(len(rank_segments.segment_starts)):
+                            v[rank_segments.segment_starts[i]:rank_segments.segment_ends[i], :] = (
+                                torch.nn.functional.embedding(
+                                    input[rank_segments.segment_starts[i]:rank_segments.segment_ends[i]],
+                                    data.lora_a[rank_segments.adapter_index_map[i]][self.layer_id],
+                                )
+                            )
+
+                        if self.process_group.size() > 1:
+                            v = self.collect_lora_a(v)
+
+                        lora_b_sgmv_cutlass(
+                            proj,
+                            v,
+                            rank_segments.tmp_expand,
+                            lora_b_ptr,
+                            rank_segments.segment_starts,
+                            rank_segments.segment_ends,
+                            self.layer_id,
+                        )
+                else:
+                    # Use BGMV for decode
+                    if lora_a_ptr is not None and lora_b_ptr is not None:
+                        # note(ajinkya): there's no segmentation in the batch, so just loop
+                        # through each sample in the batch, get the corresponding lora `A`
+                        # matrix, and look up embeddings
+                        v = torch.zeros_like(result[:, :r])
+                        for i in range(input.shape[0]):
+                            v[i, :] = torch.nn.functional.embedding(
+                                input[i],
+                                data.lora_a[rank_segments.indices[i].item()][self.layer_id]
+                            )
+
+                        if self.process_group.size() > 1:
+                            v = self.collect_lora_a(v)
+
+                        add_lora_b_bgmv(
+                            proj,
+                            v,
+                            lora_b_ptr,
+                            rank_segments.indices,
+                            self.layer_id,
+                        )
+
+            if end_idx - start_idx != result.shape[1]:
+                result[:, start_idx:end_idx] += proj
+        else:
+            for adapter_index in adapter_data.meta.adapter_set:
+                if data is not None and data.has_adapter(adapter_index):
+                    adapter_mask = (adapter_data.meta.adapter_indices == adapter_index).to(input.dtype).view(-1, 1)
+                    layer_result = self.forward_lora(input, data, adapter_index, adapter_mask)
+                    result[:, start_idx:end_idx] += layer_result
+
+        return result
+
+    def forward_lora(
+        self,
+        input: torch.Tensor,
+        data: "BatchLoraWeights",
+        adapter_index: int,
+        adapter_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
+        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
+
+        lora_a = orient_for_rank(lora_a, lora_b.size(0))
+
+        a_out = torch.nn.functional.embedding(input, lora_a)
+        if self.process_group.size() > 1:
+            a_out = self.collect_lora_a(a_out)
+
+        result = (a_out @ lora_b) * adapter_mask
+        return result
+
+    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError("Implemented in subclasses")
+
+
+class TensorParallelAdapterRowEmbedding(LoraEmbedding):
+    def __init__(self, base_layer, layer_id, layer_name, process_group):
+        super().__init__(layer_id, process_group)
+        self.base_layer = base_layer
+        self.layer_name = layer_name
+
+    @classmethod
+    def load(cls, base_layer, layer_id, layer_name, process_group):
+        return cls(base_layer, layer_id, layer_name, process_group)
+
+    def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> torch.Tensor:
+        result = self.base_layer(input)
+
+        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
+        stride = result.shape[-1] // self.process_group.size()
+        start_idx = self.process_group.rank() * stride
+        end_idx = (self.process_group.rank() + 1) * stride
+
+        self.forward_layer_type(result, input, adapter_data, self.layer_name, start_idx, end_idx)
+
+        return result
+
+    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
+        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
+        #
+        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
+        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
+        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
+        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
+        torch.distributed.all_reduce(a_out, group=self.process_group)
+        return a_out
+
+
 try:
     import dropout_layer_norm
diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py
index f3d3e6f16..effff93c6 100644
--- a/server/lorax_server/utils/lora.py
+++ b/server/lorax_server/utils/lora.py
@@ -9,3 +9,4 @@
 DOWN_PROJ = "down_proj"

 LM_HEAD = "lm_head"
+EMBED_TOKENS = "embed_tokens"
diff --git a/server/tests/adapters/test_medusa.py b/server/tests/adapters/test_medusa.py
index ab0abe822..c6fd06f4d 100644
--- a/server/tests/adapters/test_medusa.py
+++ b/server/tests/adapters/test_medusa.py
@@ -16,7 +16,7 @@ def test_batched_medusa_weights(default_causal_lm: CausalLM):
     download_adapter(adapter_id, HUB)

     module_map, medusa_config, _, _ = load_module_map(
-        model_id, adapter_id, HUB, tuple(), None
+        model_id, adapter_id, HUB, tuple(), None, default_causal_lm.embedding_weight_name
     )

     assert isinstance(medusa_config, MedusaConfig)
diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py
index de6949711..d30c51c6e 100644
--- a/server/tests/utils/test_lora.py
+++ b/server/tests/utils/test_lora.py
@@ -9,6 +9,7 @@
 from lorax_server.adapters.types import LORA
 from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights
 from lorax_server.utils.sgmv import MIN_RANK_CUSTOM
+from lorax_server.utils.segments import find_segments


 class FakeAdapterWeights(AdapterWeights):
@@ -37,21 +38,43 @@ def load(
         meta: "AdapterBatchMetadata",
         prefill: bool,
         prefill_head_indices: torch.Tensor,
+        is_embed: bool,
     ) -> Optional["BatchAdapterWeights"]:
         return None


 @pytest.mark.parametrize(
-    "lora_ranks",
+    "lora_ranks,adapter_indices,expected",
     [
-        [8, 16],
-        [32, 64],
+        (
+            [8, 8, 16],  # ranks of adapters
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],  # adapter indices for each token in the batch
+            {
+                8: (  # rank
+                    [0, 2, 4, 6],  # expected segment starts
+                    [2, 4, 6, 8],  # expected segment ends
+                    [0, 1, 0, 1],  # expected adapter indices
+                ),
+                16: ([8], [10], [2]),
+            }
+        ),
+        (
+            [4, 8, 16],
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
+            {
+                4: ([0, 4], [2, 6], [0, 0]),
+                8: ([2, 6], [4, 8], [1, 1]),
+                16: ([8], [10], [2]),
+            }
+        ),
     ],
 )
-def test_batched_lora_weights(lora_ranks: List[int]):
-    # batch meta is hardcoded with this assumption below
-    assert len(lora_ranks) == 2
-
+def test_batched_lora_weights(
+    lora_ranks: List[int],
+    adapter_indices: List[int],
+    expected: Dict[int, Tuple[List[int], Tuple[int], Tuple[int]]]
+):
+    num_adapters = len(lora_ranks)
     batched_weights = LayerAdapterWeights()
     assert batched_weights.is_empty()
@@ -68,59 +91,73 @@
         batched_weights.add_adapter(idx, weights)

     assert not batched_weights.is_empty()
-    assert len(batched_weights.adapter_weights) == 2
+    assert len(batched_weights.adapter_weights) == num_adapters
+
+    segments, segment_indices = find_segments(adapter_indices)

     meta = AdapterBatchMetadata(
-        adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64),
-        adapter_set={0, 1},
-        adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64),
-        segment_indices=[0, 1, 0, 1],
+        adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64),
+        adapter_set=set(adapter_indices),
+        adapter_segments=torch.tensor(segments, dtype=torch.int64),
+        segment_indices=segment_indices,
     )

     with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
-        data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA)
+        data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None, is_embed=False).get(LORA)

-    assert len(data.lora_a) == 2
+    assert len(data.lora_a) == num_adapters
+    assert len(data.lora_b) == num_adapters
     assert data.lora_a.keys() == meta.adapter_set
-    assert data.lora_a[0].shape == ((1, h, lora_ranks[0]) if lora_ranks[0] < MIN_RANK_CUSTOM else (1, lora_ranks[0], h))
-    assert data.lora_a[1].shape == ((1, h, lora_ranks[1]) if lora_ranks[1] < MIN_RANK_CUSTOM else (1, lora_ranks[1], h))
-
-    assert len(data.lora_b) == 2
     assert data.lora_b.keys() == meta.adapter_set
-    assert data.lora_b[0].shape == (1, lora_ranks[0], h)
-    assert data.lora_b[1].shape == (1, lora_ranks[1], h)
+    for i in range(num_adapters):
+        assert data.lora_a[i].shape == (
+            (1, h, lora_ranks[i]) if lora_ranks[i] < MIN_RANK_CUSTOM else (1, lora_ranks[i], h)
+        )
+        assert data.lora_b[i].shape == (1, lora_ranks[i], h)

-    assert len(data.rank_data) == 2
-    assert data.rank_data.keys() == set(lora_ranks)
     for lora_rank, rd in data.rank_data.items():
         assert rd.rank == lora_rank
-
-        # shape in all cases is the number of segments with this rank
-        assert rd.lora_a_ptr.shape == (2,)
-        assert rd.lora_b_ptr.shape == (2,)
-        assert rd.segment_starts.shape == (2,)
-        assert rd.segment_ends.shape == (2,)
-
+        expected_lora_a_ptr = []
+        expected_lora_b_ptr = []
+        for adapter_idx in expected[lora_rank][2]:
+            expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a.data_ptr())
+            expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b.data_ptr())
+        expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device)
+        expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device)
+        assert all(rd.lora_a_ptr == expected_lora_a_ptr)
+        assert all(rd.lora_b_ptr == expected_lora_b_ptr)
+
+        expected_segment_starts = torch.tensor(
+            expected[lora_rank][0], dtype=rd.segment_starts.dtype, device=rd.segment_starts.device
+        )
+        expected_segment_ends = torch.tensor(
+            expected[lora_rank][1], dtype=rd.segment_ends.dtype, device=rd.segment_ends.device
+        )
+        assert all(rd.segment_ends == expected_segment_ends)
+        assert all(rd.segment_starts == expected_segment_starts)


 @pytest.mark.parametrize(
     "lora_ranks,adapter_indices,expected",
     [
         (
-            [8, 8, 16],
-            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
+            [8, 8, 16],  # ranks of adapters
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],  # adapter indices for each token in the batch
             {
-                8: (4, [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]),
-                16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0])
+                8: (  # rank
+                    [0, 1],  # expected adapter indices
+                    [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]  # expected indices
+                ),
+                16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
             }
         ),
         (
             [4, 8, 16],
             [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
             {
-                4: (2, [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]),
-                8: (2, [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]),
-                16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
+                4: ([0], [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]),
+                8: ([1], [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]),
+                16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
             }
         ),
     ],
 )
@@ -128,9 +165,8 @@ def test_batched_lora_weights_decode(
     lora_ranks: List[int],
     adapter_indices: List[int],
-    expected: Dict[int, Tuple[int, List[int]]]
+    expected: Dict[int, Tuple[List[int], List[int]]]
 ):
-    from lorax_server.utils.segments import find_segments

     batched_weights = LayerAdapterWeights()
     assert batched_weights.is_empty()
@@ -153,13 +189,22 @@ def test_batched_lora_weights_decode(
     )

     with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
-        data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None).get(LORA)
+        data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None, is_embed=False).get(LORA)

     for lora_rank, rd in data.rank_data.items():
+        expected_lora_a_ptr = []
+        expected_lora_b_ptr = []
+        for adapter_idx in expected[lora_rank][0]:
+            expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a_t.data_ptr())
+            expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b_t.data_ptr())
+        expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device)
+        expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device)
+        assert all(rd.lora_a_ptr == expected_lora_a_ptr)
+        assert all(rd.lora_b_ptr == expected_lora_b_ptr)
+
         expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device)
-        assert rd.lora_a_ptr.shape == (expected[lora_rank][0],)
-        assert rd.lora_b_ptr.shape == (expected[lora_rank][0],)
         assert all(rd.indices == expected_indices)
+
         assert rd.segment_starts == None
         assert rd.segment_ends == None
         assert rd.tmp_shrink == None
@@ -197,6 +242,6 @@ def test_batched_lora_weights_no_segments():
     )

     with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
-        data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA)
+        data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None, is_embed=False).get(LORA)

     print(data)
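Note (not part of the patch): a minimal sketch of the math that the new embedding-LoRA path computes for a single adapter, mirroring LoraEmbedding.forward_lora but without tensor parallelism, adapter masking, or the SGMV/BGMV kernels. The variable names and toy shapes below are illustrative assumptions only; in particular, the [vocab_size, r] layout of lora_a is chosen so it can be indexed like an embedding table and does not necessarily match the on-disk PEFT layout of lora_embedding_A/lora_embedding_B.

import torch

# Assumed toy dimensions for illustration.
vocab_size, hidden_size, rank = 32, 16, 4
base_embed = torch.nn.Embedding(vocab_size, hidden_size)

# lora_a plays the role of lora_embedding_A ([vocab_size, r] here so it can be
# used directly as an embedding table); lora_b maps the rank dim back to hidden_size.
lora_a = torch.randn(vocab_size, rank)
lora_b = torch.randn(rank, hidden_size)

input_ids = torch.tensor([1, 5, 7, 2])
base_out = base_embed(input_ids)                          # [num_tokens, hidden_size]
a_out = torch.nn.functional.embedding(input_ids, lora_a)  # [num_tokens, r]
result = base_out + a_out @ lora_b                        # adapter delta added to the base lookup
# (alpha / r scaling and per-token adapter masks are omitted for brevity)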