From e4bad4da0d0d89af3d0afb9c9fede2eec43550ac Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Tue, 26 Mar 2024 21:32:06 -0700
Subject: [PATCH] Refactor adapter interface to support adapters other than
 LoRA (e.g., speculative decoding) (#359)

---
 server/lorax_server/adapters/__init__.py      |  33 ++
 server/lorax_server/adapters/config.py        |  34 ++
 server/lorax_server/adapters/lora.py          | 278 +++++++++++++++
 server/lorax_server/adapters/medusa.py        |  17 +
 server/lorax_server/adapters/types.py         |   2 +
 server/lorax_server/adapters/weights.py       | 100 ++++++
 server/lorax_server/models/bloom.py           |   2 +-
 server/lorax_server/models/causal_lm.py       |   2 +-
 .../models/custom_modeling/bloom_modeling.py  |   2 +-
 .../custom_modeling/flash_gemma_modeling.py   |   3 +-
 .../custom_modeling/flash_gpt2_modeling.py    |   2 +-
 .../custom_modeling/flash_llama_modeling.py   |   3 +-
 .../custom_modeling/flash_mistral_modeling.py |   8 +-
 .../custom_modeling/flash_mixtral_modeling.py |   3 +-
 .../custom_modeling/flash_phi_modeling.py     |   3 +-
 .../custom_modeling/flash_qwen2_modeling.py   |   3 +-
 .../custom_modeling/flash_qwen_modeling.py    |   3 +-
 server/lorax_server/models/flash_causal_lm.py |   2 +-
 server/lorax_server/models/flash_mistral.py   | 325 +-----------------
 server/lorax_server/models/flash_mixtral.py   |   3 +-
 server/lorax_server/models/model.py           |  81 +----
 server/lorax_server/server.py                 |   9 +-
 server/lorax_server/utils/adapter.py          |  85 ++---
 server/lorax_server/utils/graph.py            |  49 +--
 server/lorax_server/utils/layers.py           |  19 +-
 server/lorax_server/utils/lora.py             | 209 +----------
 .../lorax_server/utils/merges/strategies.py   |  10 +-
 server/lorax_server/utils/sources/hub.py      |  14 +-
 server/lorax_server/utils/sources/local.py    |  20 +-
 server/lorax_server/utils/sources/s3.py       |  14 +
 server/lorax_server/utils/sources/source.py   |  31 +-
 server/tests/utils/test_adapter.py            |   4 +-
 server/tests/utils/test_lora.py               |  14 +-
 33 files changed, 658 insertions(+), 729 deletions(-)
 create mode 100644 server/lorax_server/adapters/__init__.py
 create mode 100644 server/lorax_server/adapters/config.py
 create mode 100644 server/lorax_server/adapters/lora.py
 create mode 100644 server/lorax_server/adapters/medusa.py
 create mode 100644 server/lorax_server/adapters/types.py
 create mode 100644 server/lorax_server/adapters/weights.py

diff --git a/server/lorax_server/adapters/__init__.py b/server/lorax_server/adapters/__init__.py
new file mode 100644
index 000000000..22c6188cd
--- /dev/null
+++ b/server/lorax_server/adapters/__init__.py
@@ -0,0 +1,33 @@
+import json
+from pathlib import Path
+from typing import Optional
+
+from lorax_server.adapters.config import AdapterConfig
+from lorax_server.adapters.lora import LoraConfig
+# from lorax_server.adapters.medusa import MedusaConfig
+from lorax_server.adapters.weights import AdapterBatchData, AdapterBatchMetadata
+
+
+def load_adapter_config(
+    config_path: Optional[Path],
+    adapter_config_path: Optional[Path],
+    api_token: str,
+) -> AdapterConfig:
+    if adapter_config_path is not None and adapter_config_path.exists():
+        return LoraConfig.load(str(adapter_config_path.parent), api_token)
+
+    # TODO(travis): medusa
+    # if config_path is not None and config_path.exists():
+    #     config = json.load(config_path.open())
+    #     if "medusa_num_heads" in config:
+    #         return MedusaConfig.load(config)
+
+    raise ValueError(f"No valid adapter config file found: "
+                     f"tried {adapter_config_path} and {config_path}")
+
+
+__all__ = [
+    "AdapterBatchData",
+    "AdapterBatchMetadata",
+    "load_adapter_config",
+]
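The new adapters/__init__.py above exposes load_adapter_config as the single entry point for resolving an adapter's configuration. A minimal usage sketch follows (not part of the patch); it assumes a locally downloaded PEFT-style adapter directory, and the directory path and empty API token below are hypothetical:

    from pathlib import Path

    from lorax_server.adapters import load_adapter_config

    adapter_dir = Path("/data/adapters/my-lora")  # hypothetical local adapter directory
    adapter_config = load_adapter_config(
        config_path=adapter_dir / "config.json",                  # base model config, may be absent
        adapter_config_path=adapter_dir / "adapter_config.json",  # PEFT adapter config
        api_token="",                                             # hypothetical token
    )
    # With a LoRA adapter present, this returns a LoraConfig (an AdapterConfig subclass);
    # per-layer weights are later materialized via
    # adapter_config.load_batched_adapter_weights(...).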
diff --git a/server/lorax_server/adapters/config.py b/server/lorax_server/adapters/config.py
new file mode 100644
index 000000000..df5a99a1e
--- /dev/null
+++ b/server/lorax_server/adapters/config.py
@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+
+import torch
+
+from lorax_server.adapters.weights import AdapterWeights
+
+if TYPE_CHECKING:
+    from server.lorax_server.models.model import Model
+
+
+ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]
+
+
+@dataclass
+class AdapterConfig(ABC):
+    base_model_name_or_path: str
+
+    @abstractmethod
+    def map_weights_for_model(
+        self, adapter_weights: Dict, weight_names: Tuple[str],
+    ) -> Tuple[ModuleMap, Set[str]]:
+        pass
+
+    @abstractmethod
+    def load_batched_adapter_weights(
+        self,
+        model: "Model",
+        module_map: Dict[str, Dict],
+        layer_type: str,
+        unused_weight_names: Set[str],
+    ) -> Optional[AdapterWeights]:
+        pass
diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py
new file mode 100644
index 000000000..584e7e1e4
--- /dev/null
+++ b/server/lorax_server/adapters/lora.py
@@ -0,0 +1,278 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+
+import torch
+from peft import LoraConfig as _LoraConfig
+from torch.distributed import ProcessGroup
+
+from lorax_server.adapters.config import AdapterConfig, ModuleMap
+from lorax_server.adapters.types import LORA
+from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
+from lorax_server.utils.sgmv import MAX_RANK_CUSTOM, get_tmp_tensors, orient_for_rank, pad_rank
+
+if TYPE_CHECKING:
+    from lorax_server.models.model import Model
+
+EMPTY_TENSOR = torch.tensor([])
+
+
+@dataclass
+class LoraConfig(AdapterConfig):
+    r: int
+    target_modules: Optional[Union[List[str], str]]
+    fan_in_fan_out: bool
+    lora_alpha: int
+    use_rslora: bool
+
+    def map_weights_for_model(
+        self, adapter_weights: Dict, weight_names: Tuple[str],
+    ) -> Tuple[ModuleMap, Set[str]]:
+        adapter_weight_names = set()
+        module_map = {}
+        for weight_name in weight_names:
+            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
+            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
+            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
+                continue
+
+            module_map[weight_name] = {
+                "lora_A": (adapter_weights[lora_a_name], lora_a_name),
+                "lora_B": (adapter_weights[lora_b_name], lora_b_name),
+            }
+            adapter_weight_names.add(lora_a_name)
+            adapter_weight_names.add(lora_b_name)
+        return module_map, adapter_weight_names
+
+    def load_batched_adapter_weights(
+        self,
+        model: "Model",
+        module_map: Dict[str, Dict],
+        layer_type: str,
+        unused_weight_names: Set[str],
+    ) -> Optional[AdapterWeights]:
+        return LoraWeights.load(
+            self,
+            model,
+            module_map,
+            layer_type,
+            unused_weight_names,
+        )
+
+    @classmethod
+    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
+        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
+        return cls(
+            base_model_name_or_path=hf_config.base_model_name_or_path,
+            r=hf_config.r,
+            target_modules=hf_config.target_modules,
+            fan_in_fan_out=hf_config.fan_in_fan_out,
+            lora_alpha=hf_config.lora_alpha,
+            use_rslora=hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False,
+        )
+
+
+class LoraWeights(AdapterWeights):
+    """LoRA weights for a single adapter merged across all layers."""
+
+    def __init__(
+        self,
+
weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + adapter_config: LoraConfig, + ): + self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 + self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 + + # [num_layers, hidden_size, r] + weights_a = [ + orient_for_rank(w, w.size(1)).contiguous() + for w in weights_a + ] + self.weights_a = torch.stack(weights_a) + + # [num_layers, r, hidden_size] + self.weights_b = torch.stack(weights_b) + + self.adapter_config = adapter_config + + @classmethod + def get_batch_type(cls) -> BatchAdapterWeights: + return BatchLoraWeights + + @classmethod + def load( + cls, + config: LoraConfig, + model: "Model", + module_map: Dict[str, Dict], + layer_type: str, + unused_weight_names: Set[str], + ) -> Optional[AdapterWeights]: + nlayers = model.get_num_layers_for_type(layer_type) + lora_a_list = [None] * nlayers + lora_b_list = [None] * nlayers + + for layer_id in range(nlayers): + key = (layer_id, layer_type) + weight_name, layer = model.target_to_layer[key] + + base_weight = layer.base_layer.linear.weight + base_device = base_weight.device + + if weight_name not in module_map: + # There is no LoRA weight for this layer type in the adapter + return None + + lora_a, lora_a_name = module_map[weight_name]["lora_A"] + lora_a = lora_a.to(base_device, model.dtype) + + lora_b, lora_b_name = module_map[weight_name]["lora_B"] + lora_b = lora_b.to(base_device, model.dtype) + + scale = get_scaling_factor( + config.lora_alpha, + config.r, + uses_rslora=config.use_rslora, + ) + + unused_weight_names.discard(lora_a_name) + unused_weight_names.discard(lora_b_name) + + # Merge scaling factor into lora_b due to associativity of matrix multiplication: + # (A * B) * C = A * (B * C) + lora_a_list[layer_id] = lora_a.transpose(0, 1) + lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale + + # pad lora ranks to be compatible with sgmv + lora_a_list = [pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list] + lora_b_list = [pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list] + + if lora_a_list: + # update rank if it was padded + padded_rank = lora_a_list[0].size(1) + config.r = padded_rank + + return LoraWeights( + *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type), config, + ) + + +@dataclass +class RankSegments: + rank: int + tmp_shrink: torch.Tensor + tmp_expand: torch.Tensor + lora_a_ptr: torch.Tensor + lora_b_ptr: torch.Tensor + segment_starts: torch.Tensor + segment_ends: torch.Tensor + + +@dataclass +class BatchLoraWeights(BatchAdapterWeights): + lora_a: Dict[int, torch.Tensor] + lora_b: Dict[int, torch.Tensor] + adapter_index_configs: Dict[int, LoraConfig] + rank_data: Dict[int, RankSegments] + + def has_adapter(self, adapter_index: int) -> bool: + return adapter_index in self.adapter_index_configs + + def can_vectorize(self, pg: ProcessGroup) -> bool: + return all( + rank_data.rank // pg.size() <= MAX_RANK_CUSTOM + for rank_data in self.rank_data.values() + ) + + @classmethod + def key(cls) -> str: + return LORA + + @classmethod + def load(self, adapter_weights: Dict[int, LoraWeights], meta: AdapterBatchMetadata) -> "BatchLoraWeights": + first_weights = list(adapter_weights.values())[0] + device = first_weights.weights_a.device + segment_indices = meta.segment_indices + + lora_a = { + idx: adapter_weights[idx].weights_a + for idx in segment_indices + if idx in adapter_weights + } + lora_a_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_a.data_ptr() + if idx in adapter_weights + else 
EMPTY_TENSOR.data_ptr() + ) for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b = { + idx: adapter_weights[idx].weights_b + for idx in segment_indices + if idx in adapter_weights + } + lora_b_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_b.data_ptr() + if idx in adapter_weights + else EMPTY_TENSOR.data_ptr() + ) for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + + adapter_index_configs = { + idx: adapter_weights[idx].adapter_config + for idx in segment_indices + if idx in adapter_weights + } + + rank_indices = defaultdict(list) + for segment_idx, adapter_idx in enumerate(segment_indices): + if adapter_idx not in adapter_weights: + continue + rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx) + + rank_data = {} + for rank, indices in rank_indices.items(): + lora_a_ptr_indices = lora_a_ptr[indices] + tmp_shrink, tmp_expand = get_tmp_tensors( + lora_a_ptr_indices.size(0), + rank, + device + ) + + rank_data[rank] = RankSegments( + rank=rank, + tmp_shrink=tmp_shrink, + tmp_expand=tmp_expand, + lora_a_ptr=lora_a_ptr_indices, + lora_b_ptr=lora_b_ptr[indices], + segment_starts=meta.adapter_segments[indices], + segment_ends=meta.adapter_segments[[i+1 for i in indices]], + ) + + return BatchLoraWeights( + lora_a=lora_a, + lora_b=lora_b, + adapter_index_configs=adapter_index_configs, + rank_data=rank_data, + ) + + +def get_scaling_factor( + lora_alpha: int, + r: int, + uses_rslora: bool = False, +) -> float: + """Computes the scaling factor for the lora weights.""" + if uses_rslora: + return lora_alpha / (r ** 0.5) + return lora_alpha / r diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py new file mode 100644 index 000000000..93e25e5bb --- /dev/null +++ b/server/lorax_server/adapters/medusa.py @@ -0,0 +1,17 @@ +# @dataclass +# class MedusaConfig(AdapterConfig): +# medusa_num_heads: int +# medusa_num_layers: int + +# def map_weights_for_model( +# self, adapter_weights: Dict, weight_names: Tuple[str], +# ) -> Tuple[ModuleMap, Set[str]]: +# return adapter_weights, set(weight_names) + +# @classmethod +# def load(cls, config: dict) -> "MedusaConfig": +# return cls( +# base_model_name_or_path=config["base_model_name_or_path"], +# medusa_num_heads=config["medusa_num_heads"], +# medusa_num_layers=config["medusa_num_layers"], +# ) diff --git a/server/lorax_server/adapters/types.py b/server/lorax_server/adapters/types.py new file mode 100644 index 000000000..346db10d7 --- /dev/null +++ b/server/lorax_server/adapters/types.py @@ -0,0 +1,2 @@ +LORA = "lora" +# MEDUSA = "medusa" diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py new file mode 100644 index 000000000..97db46f56 --- /dev/null +++ b/server/lorax_server/adapters/weights.py @@ -0,0 +1,100 @@ +from abc import ABC, abstractclassmethod +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Set, Type + +import torch + +from lorax_server.adapters.types import LORA + + +@dataclass +class AdapterBatchMetadata: + # [batch_size] + adapter_indices: torch.Tensor + + # [num_adapters] + adapter_set: Set[int] + + # [num_segments + 1] + adapter_segments: torch.Tensor + + # [num_segments] + # maps from segment index to adapter index, i.e.: + # segment_indices[s] == adapter_indices[i] + segment_indices: List[int] + + +class AdapterWeights(ABC): + @abstractclassmethod + def get_batch_type(cls) -> "BatchAdapterWeights": + pass + + +class 
BatchAdapterWeights(ABC): + @abstractclassmethod + def key(self) -> str: + pass + + @abstractclassmethod + def load(self, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchAdapterWeights": + pass + + +class LayerAdapterWeights: + """Adapter weights that apply to a particular layer.""" + + def __init__(self): + self.adapter_weights: Dict[int, AdapterWeights] = {} + + def add_adapter(self, adapter_idx: int, weights: AdapterWeights): + self.adapter_weights[adapter_idx] = weights + + def remove_adapter(self, adapter_idx: int): + if adapter_idx not in self.adapter_weights: + return + del self.adapter_weights[adapter_idx] + + def is_empty(self) -> bool: + return len(self.adapter_weights) == 0 + + def get_data(self, meta: AdapterBatchMetadata) -> Dict[str, BatchAdapterWeights]: + # bucket adapters by batch class + adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict) + for adapter_index, adapter_weights in self.adapter_weights.items(): + adapter_batch_types[adapter_weights.get_batch_type()][adapter_index] = adapter_weights + + batch_data = {} + for batch_type, adapter_weights in adapter_batch_types.items(): + batch_data[batch_type.key()] = batch_type.load(adapter_weights, meta) + return batch_data + + +@dataclass +class AdapterBatchData: + meta: AdapterBatchMetadata + + # layer type -> adapter type -> batch weight data + data: Dict[str, Dict[str, BatchAdapterWeights]] + + @staticmethod + def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights]) -> "AdapterBatchData": + data = {} + for k, v in weights.items(): + if v.is_empty(): + continue + data[k] = v.get_data(meta) + return AdapterBatchData(meta=meta, data=data) + + def ranks(self) -> Set[int]: + # TODO(travis): refactor to be less coupled to lora implementation + return set( + rank_data.rank + for layer_data in self.data.values() + for rank_data in layer_data.get(LORA, []).rank_data.values() + ) + + @property + def max_rank(self) -> int: + ranks = self.ranks() + return max(ranks) if len(ranks) > 0 else 0 diff --git a/server/lorax_server/models/bloom.py b/server/lorax_server/models/bloom.py index 179acb045..020492415 100644 --- a/server/lorax_server/models/bloom.py +++ b/server/lorax_server/models/bloom.py @@ -26,7 +26,7 @@ Weights, ) from lorax_server.utils.tokenizer import TokenizerManager -from lorax_server.utils.lora import AdapterBatchData +from lorax_server.adapters import AdapterBatchData ADAPTER_LAYERS = [ATTN_QKV, ATTN_DENSE, MLP_DENSE_H_TO_4H, MLP_DENSE_4H_TO_H] ROW_PARALLEL = {ATTN_DENSE, MLP_DENSE_4H_TO_H} diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 2a543fd33..6a38e2e90 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -15,7 +15,7 @@ ) from lorax_server.pb import generate_pb2 from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling -from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata +from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.tokenizer import TokenizerManager diff --git a/server/lorax_server/models/custom_modeling/bloom_modeling.py b/server/lorax_server/models/custom_modeling/bloom_modeling.py index 69224be29..22d9a9fc3 100644 --- a/server/lorax_server/models/custom_modeling/bloom_modeling.py +++ 
b/server/lorax_server/models/custom_modeling/bloom_modeling.py @@ -40,7 +40,7 @@ TensorParallelRowLinear, TensorParallelHead, ) -from lorax_server.utils.lora import AdapterBatchData +from lorax_server.adapters import AdapterBatchData CUSTOM_KERNELS_ENABLED = False if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True": diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index eae841837..b114dcadb 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -24,6 +24,7 @@ # Flash attention imports import dropout_layer_norm +from lorax_server.adapters import AdapterBatchData from lorax_server.utils import flash_attn from lorax_server.utils import paged_attn from lorax_server.utils.layers import ( @@ -36,7 +37,7 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData +from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ class GemmaConfig(PretrainedConfig): diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index cbe6bcd85..ad738afde 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -41,7 +41,7 @@ PositionRotaryEmbedding, get_linear, ) -from lorax_server.utils.lora import AdapterBatchData +from lorax_server.adapters import AdapterBatchData ATTN_C_ATTN = "attn.c_attn" ATTN_C_PROJ = "attn.c_proj" diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index 67ffb7deb..c975f44c0 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -29,6 +29,7 @@ # Flash attention imports import dropout_layer_norm +from lorax_server.adapters import AdapterBatchData from lorax_server.utils import flash_attn from lorax_server.utils import paged_attn from lorax_server.utils.layers import ( @@ -41,7 +42,7 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData +from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ class LlamaConfig(PretrainedConfig): diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 9b51e6443..b43b6a7cd 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -29,6 +29,7 @@ # Flash attention imports import dropout_layer_norm +from lorax_server.adapters import AdapterBatchData from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2 from lorax_server.utils import flash_attn from lorax_server.utils import paged_attn @@ -42,7 +43,7 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData +from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, 
V_PROJ if not HAS_FLASH_ATTN_V2: raise ImportError("Mistral model requires flash attn v2") @@ -535,6 +536,7 @@ def __init__(self, config, weights): prefix="lm_head", weights=weights, ), 0, LM_HEAD, process_group=weights.process_group) + self.max_past = config.sliding_window if self.max_past is None: raise ValueError("max_past cannot be None") @@ -552,7 +554,7 @@ def forward( adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] @@ -577,4 +579,4 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states, adapter_data) - return logits \ No newline at end of file + return logits diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index 6a3d8b1f0..75b5971e8 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -45,7 +45,8 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import AdapterBatchData, LM_HEAD +from lorax_server.adapters import AdapterBatchData +from lorax_server.utils.lora import LM_HEAD if not HAS_FLASH_ATTN_V2: raise ImportError("Mixtral model requires flash attn v2") diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index d11921179..290d7bfff 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -17,6 +17,7 @@ from transformers.models.phi import PhiConfig from typing import Optional, List, Tuple +from lorax_server.adapters import AdapterBatchData from lorax_server.utils import flash_attn from lorax_server.utils import paged_attn from lorax_server.utils.layers import ( @@ -29,7 +30,7 @@ PositionRotaryEmbedding, TensorParallelHead, ) -from lorax_server.utils.lora import LM_HEAD, AdapterBatchData +from lorax_server.utils.lora import LM_HEAD ATTN_Q_PROJ = "self_attn.q_proj" diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index bf694abef..115b71dc8 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -15,6 +15,7 @@ # Flash attention imports import dropout_layer_norm +from lorax_server.adapters import AdapterBatchData from lorax_server.utils import flash_attn from lorax_server.utils import paged_attn from lorax_server.utils.layers import ( @@ -27,7 +28,7 @@ TensorParallelHead, get_linear, ) -from lorax_server.utils.lora import LM_HEAD, AdapterBatchData +from lorax_server.utils.lora import LM_HEAD ATTN_Q_PROJ = "self_attn.q_proj" diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index e7162049e..091b9da5b 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -15,6 +15,7 @@ # Flash attention imports import 
dropout_layer_norm +from lorax_server.adapters import AdapterBatchData from lorax_server.utils import flash_attn from lorax_server.utils import paged_attn from lorax_server.utils.layers import ( @@ -26,7 +27,7 @@ PositionRotaryEmbedding, TensorParallelHead, ) -from lorax_server.utils.lora import LM_HEAD, AdapterBatchData +from lorax_server.utils.lora import LM_HEAD ATTN_C_ATTN = "attn.c_attn" diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 7f8355483..9b7112d16 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -30,7 +30,7 @@ from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID from lorax_server.utils.dist import MEMORY_FRACTION from lorax_server.utils.graph import GraphCache -from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata +from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.state import warmup_mode from lorax_server.utils.tokenizer import TokenizerManager diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index aabda653d..d480697e3 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -1,23 +1,12 @@ -import json -import math import torch import torch.distributed -import numpy as np - -from dataclasses import dataclass from loguru import logger from opentelemetry import trace -from transformers import PreTrainedTokenizerBase from transformers.models.llama import LlamaTokenizerFast -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple -from lorax_server.pb import generate_pb2 from lorax_server.models import FlashCausalLM -from lorax_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE -from lorax_server.models.cache_manager import ( - get_cache_manager, -) from lorax_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, MistralConfig, @@ -27,287 +16,16 @@ initialize_torch_distributed, weight_files, Weights, - HeterogeneousNextTokenChooser, - StoppingCriteria, ) from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID -from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData, AdapterBatchMetadata -from lorax_server.utils.segments import find_segments -from lorax_server.utils.tokenizer import TokenizerManager +from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ tracer = trace.get_tracer(__name__) -# Will be set in init -SLIDING_WINDOW: Optional[int] = None -SLIDING_WINDOW_BLOCKS: Optional[int] = None - ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD] ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} -# Adds windowing logic to FlashCausalLMBatch -@dataclass -class FlashMistralBatch(FlashCausalLMBatch): - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] = None - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - tokenizers: TokenizerManager, - dtype: torch.dtype, - device: torch.device, - ) -> "FlashCausalLMBatch": - global 
SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS - - batch_inputs = [] - max_truncation = 0 - for r in pb.requests: - inputs = tokenizers.get_inputs(r, tokenizer) - batch_inputs.append(inputs) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, truncation=True, max_length=max_truncation - )["input_ids"] - - position_ids = [] - cu_seqlen_prefill = [0] - needed_blocks_slots = [] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - requests_idx_mapping = {} - - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - - next_token_chooser_parameters = [] - stopping_criterias = [] - # TODO(geoffrey): re-add top_n_tokens functionality in a separate PR - # top_n_tokens = [] - - adapter_indices_list = [] - adapter_set = set() - - # Cumulative length - cumulative_length = 0 - cumulative_max_length = 0 - prefill_out_cumulative_length = 0 - - blocks = 0 - max_seqlen = 0 - max_length = 0 - max_blocks = 0 - - # Parse batch - for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) - ): - # request id -> idx in list mapping - requests_idx_mapping[r.id] = i - - tokenized_input = tokenized_input[-r.truncate :] - - input_length = len(tokenized_input) - input_lengths.append(input_length) - - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) - - all_input_ids.append(tokenized_input) - - # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) - - next_token_chooser_parameters.append(r.parameters) - - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - max_new_tokens = stopping_criteria.max_new_tokens - stopping_criterias.append(stopping_criteria) - # top_n_tokens.append(r.top_n_tokens) - - adapter_indices_list.append(torch.full((input_length,), r.adapter_index)) - adapter_set.add(r.adapter_index) - - # Paged attention - # Remove one as the first token des not have a past - total_tokens = input_length + max_new_tokens - 1 - - # Needed blocks can not go over SLIDING_WINDOW_BLOCKS - needed_blocks = min( - math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS - ) - blocks += needed_blocks - - needed_blocks_slots.append((needed_blocks, total_tokens)) - start_slots.append(cumulative_max_length) - - request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - SLIDING_WINDOW), - cumulative_length + input_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += 
input_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 - - # Update - cumulative_length += input_length - cumulative_max_length += total_tokens - max_seqlen = max(max_seqlen, input_length) - max_blocks = max(max_blocks, needed_blocks) - max_length = max(max_length, input_length + max_new_tokens) - - adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) - - request_tokenizers = [ - tokenizers.get_tokenizer(r.adapter_index, tokenizer) - for r in pb.requests - ] - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, request_tokenizers, dtype, device - ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) - - # Padded all_input_ids_tensor - all_input_ids_tensor = np.zeros( - (len(all_input_ids), max_length), dtype=np.int64 - ) - for i, input_ids in enumerate(all_input_ids): - all_input_ids_tensor[i, : len(input_ids)] = input_ids - - # Create tensors on device - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) - - if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - input_ids = all_input_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - - position_ids = position_ids.to(device) - slot_indices = slot_indices.to(device) - prefill_cache_indices = prefill_cache_indices.to(device) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) - # top_n_tokens_tensor = torch.tensor( - # top_n_tokens, device=device, dtype=torch.int64 - # ) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - start_slots=start_slots, - slot_indices=slot_indices, - needed_blocks_slots=needed_blocks_slots, - block_tables=None, - block_tables_tensor=None, - slots=None, - max_seqlen=max_seqlen, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, - input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - all_input_ids=all_input_ids, - all_input_ids_tensor=all_input_ids_tensor, - 
next_token_chooser=next_token_chooser, - stopping_criterias=stopping_criterias, - # top_n_tokens=top_n_tokens, - # top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, - max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), - prefill_cache_indices=prefill_cache_indices, - ) - - class FlashMistral(FlashCausalLM): def __init__( self, @@ -320,9 +38,6 @@ def __init__( dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - global SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS - self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") @@ -346,10 +61,6 @@ def __init__( if config.sliding_window is None: config.sliding_window = config.max_position_embeddings - # Set context windows - SLIDING_WINDOW = config.sliding_window - SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE) - torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") @@ -404,38 +115,6 @@ def __init__( @property def supports_adapter_loading(self) -> bool: return True - - @property - def batch_type(self) -> Type[FlashMistralBatch]: - return FlashMistralBatch - - def forward(self, batch: FlashMistralBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: - prefill = batch.cu_seqlen_prefill is not None - model = self.model - if ( - self.model_graph_wrapper is not None and - not prefill and - self.model_graph_wrapper.can_use_graph(batch, adapter_data) - ): - model = self.model_graph_wrapper - - # Model Forward - logits = model.forward( - input_ids=batch.input_ids, - position_ids=batch.position_ids, - cu_seqlen_prefill=batch.cu_seqlen_prefill, - kv_cache=get_cache_manager().kv_cache, - block_tables=batch.block_tables_tensor, - slots=batch.slots[batch.slot_indices], - input_lengths=batch.input_lengths_tensor, - max_s=batch.max_seqlen, - adapter_data=adapter_data, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=batch.prefill_head_indices, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index eef21c749..76772ce90 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -13,6 +13,7 @@ from typing import Dict, List, Optional, Tuple, Type from lorax_server.pb import generate_pb2 +from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.models import FlashCausalLM from lorax_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE from lorax_server.models.cache_manager import ( @@ -38,7 +39,7 @@ StoppingCriteria, ) from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID -from lorax_server.utils.lora import LM_HEAD, AdapterBatchData, AdapterBatchMetadata +from lorax_server.utils.lora import LM_HEAD from lorax_server.utils.segments import find_segments from lorax_server.utils.tokenizer import TokenizerManager diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 1c611afcf..48505a4ed 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ 
-4,21 +4,17 @@ from abc import ABC, abstractmethod from loguru import logger -from peft import LoraConfig -from typing import Dict, List, Set, Tuple, Optional, TypeVar, Type +from typing import Dict, List, Tuple, Optional, TypeVar, Type from transformers import PreTrainedTokenizerBase from lorax_server.models.types import Batch, GeneratedText from lorax_server.pb.generate_pb2 import AdapterParameters, AdapterSource, InfoResponse from lorax_server.utils.adapter import ( BASE_MODEL_ADAPTER_ID, - get_scaling_factor, load_and_merge_adapters, - uses_rslora, ) from lorax_server.utils.tokenizer import TokenizerManager -from lorax_server.utils.lora import BatchedLoraWeights, MergedLoraWeights -from lorax_server.utils.sgmv import pad_rank +from lorax_server.adapters.weights import LayerAdapterWeights from lorax_server.utils.weights import shard_on_dim B = TypeVar("B", bound=Batch) @@ -53,7 +49,7 @@ def __init__( # This may be set to False in the subclass constructor self.dynamic_adapter_loading_enabled = dynamic_adapter_loading_enabled - self.batched_lora_weights: Dict[str, BatchedLoraWeights] = defaultdict(BatchedLoraWeights) + self.batched_lora_weights: Dict[str, LayerAdapterWeights] = defaultdict(LayerAdapterWeights) self.target_to_layer = self.adapter_target_to_layer() self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -181,9 +177,18 @@ def load_adapter( unused_weight_names = adapter_weight_names.copy() for layer_name in self.adapter_layers: - self.load_batched_adapter_weights( - module_map, adapter_config, adapter_index, layer_name, unused_weight_names + adapter_weights = adapter_config.load_batched_adapter_weights( + self, + module_map, + layer_name, + unused_weight_names ) + + if adapter_weights is None: + continue + + batched_weights = self.batched_lora_weights[layer_name] + batched_weights.add_adapter(adapter_index, adapter_weights) if len(unused_weight_names) > 0: logger.warning(f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}") @@ -213,64 +218,6 @@ def shard_lora_weights( ] return weights_a, weights_b - - def load_batched_adapter_weights( - self, - module_map: Dict[str, Dict], - adapter_config: LoraConfig, - adapter_index: int, - layer_type: str, - unused_weight_names: Set[str], - ): - nlayers = self.get_num_layers_for_type(layer_type) - lora_a_list = [None] * nlayers - lora_b_list = [None] * nlayers - - for layer_id in range(nlayers): - key = (layer_id, layer_type) - weight_name, layer = self.target_to_layer[key] - - base_weight = layer.base_layer.linear.weight - base_device = base_weight.device - - if weight_name not in module_map: - # There is no LoRA weight for this layer type in the adapter - return - - lora_a, lora_a_name = module_map[weight_name]["lora_A"] - lora_a = lora_a.to(base_device, self.dtype) - - lora_b, lora_b_name = module_map[weight_name]["lora_B"] - lora_b = lora_b.to(base_device, self.dtype) - - scale = get_scaling_factor( - adapter_config.lora_alpha, - adapter_config.r, - uses_rslora=uses_rslora(adapter_config), - ) - - unused_weight_names.discard(lora_a_name) - unused_weight_names.discard(lora_b_name) - - # Merge scaling factor into lora_b due to associativity of matrix multiplication: - # (A * B) * C = A * (B * C) - lora_a_list[layer_id] = lora_a.transpose(0, 1) - lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale - - # pad lora ranks to be compatible with sgmv - lora_a_list = [pad_rank(w, dim=1, world_size=self.world_size) for w in lora_a_list] - lora_b_list = [pad_rank(w, dim=0, 
world_size=self.world_size) for w in lora_b_list] - - if lora_a_list: - # update rank if it was padded - padded_rank = lora_a_list[0].size(1) - adapter_config.r = padded_rank - - q_lora_merged = MergedLoraWeights( - *self.shard_lora_weights(lora_a_list, lora_b_list, layer_type), adapter_config, - ) - q_lora_weights = self.batched_lora_weights[layer_type] - q_lora_weights.add_adapter(adapter_index, q_lora_merged) def offload_adapter( self, diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index a9506c813..6aff3f372 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -150,18 +150,15 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, co # Quick auth check on the repo against the token HfApi(token=api_token).model_info(adapter_id, revision=None) - # fail fast if ID is not an adapter (i.e. it is a full model) - # TODO(geoffrey): do this for S3– can't do it this way because the - # files are not yet downloaded locally at this point. - config_path = get_config_path(adapter_id, adapter_source) - PeftConfig.from_pretrained(config_path, token=api_token) + # fail fast if ID is not an adapter (i.e. it is a full model) + source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) + source.load_config() _download_weights( adapter_id, source=adapter_source, api_token=api_token ) # Calculate size of adapter to be loaded - source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) adapter_bytes += source.get_weight_bytes() adapter_memory_size = self.model.adapter_memory_size() diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index 89cb51549..3c3b7b4c6 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -3,13 +3,12 @@ from collections import defaultdict from functools import lru_cache from pathlib import Path -from typing import List, Dict, Set, Tuple +from typing import TYPE_CHECKING, List, Dict, Set, Tuple import warnings import torch from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from loguru import logger -from peft import LoraConfig from peft.utils import transpose from safetensors.torch import load_file, save_file from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer @@ -18,13 +17,14 @@ from lorax_server.pb import generate_pb2 from lorax_server.utils.sources import get_model_source, get_config_path, weight_files -from lorax_server.utils.merges.strategies import merge_adapters +from lorax_server.utils.merges.strategies import merge_adapters +from lorax_server.adapters.lora import get_scaling_factor - -BASE_MODEL_ADAPTER_ID = "__base_model__" +if TYPE_CHECKING: + from lorax_server.adapters.config import AdapterConfig, ModuleMap -ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]] +BASE_MODEL_ADAPTER_ID = "__base_model__" @dataclass @@ -50,7 +50,7 @@ def load_and_merge_adapters( adapter_index: int, weight_names: Tuple[str], api_token: str, -) -> Tuple[ModuleMap, LoraConfig, Set[str], PreTrainedTokenizer]: +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: if len(adapter_parameters.adapter_ids) == 1: return load_module_map( model_id, adapter_parameters.adapter_ids[0], adapter_source, weight_names, api_token @@ -66,7 +66,7 @@ def _load_and_merge( adapter_params: AdapterParametersContainer, weight_names: Tuple[str], api_token: str, -) -> Tuple[ModuleMap, LoraConfig, Set[str], PreTrainedTokenizer]: +) 
-> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: params = adapter_params.adapter_parameters adapters_to_merge = [] @@ -92,7 +92,7 @@ def _load_and_merge( return module_map, adapter_config, merged_weight_names, tokenizer -def check_architectures(model_id: str, adapter_id: str, adapter_config: LoraConfig, api_token: str): +def check_architectures(model_id: str, adapter_id: str, adapter_config: "AdapterConfig", api_token: str): try: expected_config = AutoConfig.from_pretrained(model_id, token=api_token) model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path, token=api_token) @@ -121,12 +121,12 @@ def load_module_map( adapter_source: str, weight_names: Tuple[str], api_token: str, -) -> Tuple[ModuleMap, LoraConfig, Set[str], PreTrainedTokenizer]: +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: # TODO(geoffrey): refactor this and merge parts of this function with # lorax_server/utils/adapter.py::create_merged_weight_files source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) config_path = get_config_path(adapter_id, adapter_source) - adapter_config = LoraConfig.from_pretrained(config_path, token=api_token) + adapter_config = source.load_config() if adapter_config.base_model_name_or_path != model_id: check_architectures(model_id, adapter_id, adapter_config, api_token) @@ -143,20 +143,7 @@ def load_module_map( adapter_weights.update(load_file(filename)) # map the model weights to the relevant adapter weights (LoRA A and B matrices) - adapter_weight_names = set() - module_map = {} - for weight_name in weight_names: - lora_a_name = f"base_model.model.{weight_name}.lora_A.weight" - lora_b_name = f"base_model.model.{weight_name}.lora_B.weight" - if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights: - continue - - module_map[weight_name] = { - "lora_A": (adapter_weights[lora_a_name], lora_a_name), - "lora_B": (adapter_weights[lora_b_name], lora_b_name), - } - adapter_weight_names.add(lora_a_name) - adapter_weight_names.add(lora_b_name) + module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names) return module_map, adapter_config, adapter_weight_names, adapter_tokenizer @@ -182,7 +169,7 @@ def compute_delta_weight( def merge_adapter_weights( model_weights: Dict[str, torch.Tensor], adapter_weights: Dict[str, torch.Tensor], - adapter_config: LoraConfig + adapter_config: "AdapterConfig" ) -> Tuple[Dict[str, torch.Tensor], Set[str]]: """ Merges the adapter weights into the model weights. @@ -190,11 +177,16 @@ def merge_adapter_weights( Args: model_weights (Dict[str, torch.Tensor]): The weights of the base model. adapter_weights (Dict[str, torch.Tensor]): The weights of the adapters. - adapter_config (LoraConfig): The configuration for the LoRA adapter. + adapter_config (AdapterConfig): The configuration for the adapter. Returns: Tuple[Dict[str, torch.Tensor], Set[str]]: A tuple containing the merged weights and the set of processed adapter weight names. 
""" + from lorax_server.adapters.lora import LoraConfig + + if not isinstance(adapter_config, LoraConfig): + raise ValueError(f"Unsupported adapter config type: {type(adapter_config)}") + module_mapping = defaultdict(dict) processed_adapter_weight_names = set() @@ -225,7 +217,7 @@ def merge_adapter_weights( adapter_config.fan_in_fan_out, adapter_config.lora_alpha, adapter_config.r, - uses_rslora=uses_rslora(adapter_config), + uses_rslora=adapter_config.use_rslora, ) # transpose delta weight if necessary @@ -245,12 +237,11 @@ def create_merged_weight_files( adapter_source: str = "hub", ) -> List[Path]: """Creates merged weight files for the given adapter ID and filenames.""" - source = get_model_source(adapter_source, adapter_id) + api_token = None # TODO(travis): add support for API token + source = get_model_source(adapter_source, adapter_id, api_token=api_token) adapter_filenames = source.weight_files() - adapter_path = get_config_path(adapter_id, adapter_source) - api_token = None # TODO(travis): add support for API token - adapter_config = LoraConfig.from_pretrained(adapter_path, token=api_token) + adapter_config = source.load_config() if adapter_config.base_model_name_or_path != model_id: expected_config = AutoConfig.from_pretrained(model_id) model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path) @@ -308,33 +299,3 @@ def create_merged_weight_files( logger.info( f"Finished merging adapter weights. Merged weight files saved to: {merged_weight_directory}") return merged_weight_filenames - - -def uses_rslora(adapter_config: LoraConfig) -> bool: - """ Returns True if the adapter uses RSLora for scaling the delta weights. """ - return adapter_config.use_rslora if hasattr(adapter_config, "use_rslora") else False - - -def get_scaling_factor( - lora_alpha: int, - r: int, - uses_rslora: bool = False, -) -> float: - """Computes the scaling factor for the lora weights.""" - if uses_rslora: - return lora_alpha / (r ** 0.5) - return lora_alpha / r - - -def main(): - adapter_id = "arnavgrg/codealpaca-qlora" - adapter_config = LoraConfig.from_pretrained(adapter_id) - model_id = adapter_config.base_model_name_or_path - model_weight_filenames = weight_files(model_id, extension=".safetensors") - - merged_adapter_filenames = create_merged_weight_files(adapter_id, model_id, model_weight_filenames) - print(merged_adapter_filenames) - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 9a53fe1e0..956346e39 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -11,7 +11,9 @@ from torch import nn from tqdm import tqdm -from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, AdapterWeightData, RankSegments +from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata +from lorax_server.adapters.lora import BatchLoraWeights, RankSegments +from lorax_server.adapters.types import LORA from lorax_server.models.cache_manager import get_cache_manager, BLOCK_SIZE from lorax_server.utils.sgmv import get_tmp_expand_size, get_tmp_tensors, use_cutlass_shrink @@ -89,7 +91,7 @@ def get_max_graph_state(device: torch.device, adapter_layers: Tuple[str]) -> Gra adapter_weight_data = {} for layer_name in adapter_layers: - adapter_weight_data[layer_name] = AdapterWeightData( + adapter_weight_data[layer_name] = BatchLoraWeights( lora_a={}, lora_b={}, adapter_index_configs={}, @@ -166,22 +168,24 @@ def trace( # cutlass shrink uses 
a custom temp buffer per rank tmp_shrink = tmp_shrink[:tmp_expand_size] - adapter_weight_data[layer_name] = AdapterWeightData( - lora_a={}, - lora_b={}, - adapter_index_configs={}, - rank_data={ - max_rank: RankSegments( - rank=max_rank, - tmp_shrink=tmp_shrink, - tmp_expand=weight_data.rank_data[MAX_RANK].tmp_expand[:tmp_expand_size], - lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:segment_size], - lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:segment_size], - segment_starts=weight_data.rank_data[MAX_RANK].segment_starts[:segment_size], - segment_ends=weight_data.rank_data[MAX_RANK].segment_ends[:segment_size], - ), - } if max_rank > 0 else {}, - ) + adapter_weight_data[layer_name] = { + LORA: BatchLoraWeights( + lora_a={}, + lora_b={}, + adapter_index_configs={}, + rank_data={ + max_rank: RankSegments( + rank=max_rank, + tmp_shrink=tmp_shrink, + tmp_expand=weight_data.rank_data[MAX_RANK].tmp_expand[:tmp_expand_size], + lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:segment_size], + lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:segment_size], + segment_starts=weight_data.rank_data[MAX_RANK].segment_starts[:segment_size], + segment_ends=weight_data.rank_data[MAX_RANK].segment_ends[:segment_size], + ), + } if max_rank > 0 else {}, + ) + } input_state = GraphState( input_ids=max_input_state.input_ids[:batch_size], @@ -245,15 +249,16 @@ def forward( self.input_state.block_tables[:block_tables.shape[0], :block_tables.shape[1]] = block_tables for layer_name, weight_data in self.input_state.adapter_data.data.items(): + lora_data = weight_data[LORA] if layer_name not in adapter_data.data: # zero out all the segments - for rank_data in weight_data.rank_data.values(): + for rank_data in lora_data.rank_data.values(): rank_data.segment_starts.fill_(SEGMENT_PAD_VALUE) rank_data.segment_ends.fill_(SEGMENT_PAD_VALUE) continue source_data = adapter_data.data[layer_name] - dest_data = weight_data + dest_data = lora_data for rank, source_rank_data in source_data.rank_data.items(): dest_rank_data = dest_data.rank_data[rank] @@ -291,6 +296,9 @@ def can_use_graph( batch_size = batch.input_ids.shape[0] max_s = batch.max_seqlen + # Only allow LoRA adapters for now + adapter_keys = set(adapter_data.data.keys()) + # TODO(travis): allow using CUDA graphs with multi-rank batches return ( torch.cuda.is_available() @@ -299,6 +307,7 @@ def can_use_graph( and max_rank <= MAX_RANK and nranks <= 1 and max_rank in _allowed_ranks + and all(k == LORA for k in adapter_keys) ) def get_estimated_cache_memory(self) -> int: diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 962338c99..16128b031 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -5,7 +5,8 @@ from torch import nn from torch.nn import functional as F -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union + HAS_BITS_AND_BYTES = True try: @@ -42,8 +43,11 @@ def weight(self) -> torch.Tensor: from accelerate import init_empty_weights +from lorax_server.adapters.lora import BatchLoraWeights +from lorax_server.adapters.types import LORA from lorax_server.utils.gptq.quant_linear import QuantLinear -from lorax_server.utils.sgmv import add_lora_sgmv_cutlass, lora_a_sgmv_cutlass, lora_b_sgmv_cutlass, has_sgmv, orient_for_rank +from lorax_server.adapters import AdapterBatchData +from lorax_server.utils.sgmv import lora_a_sgmv_cutlass, lora_b_sgmv_cutlass, has_sgmv, orient_for_rank from lorax_server.utils.state import 
is_warmup HAS_EXLLAMA = True @@ -54,8 +58,6 @@ def weight(self) -> torch.Tensor: except ImportError: HAS_EXLLAMA = False -from lorax_server.utils.lora import AdapterBatchData, AdapterWeightData - # Monkey patching @classmethod @@ -502,7 +504,7 @@ def load_multi( return cls(linear) -class TensorParallelAdapterLinear(nn.Module): +class LoraLinear(nn.Module): def __init__(self, base_layer, layer_id, process_group): super().__init__() self.base_layer = base_layer @@ -519,6 +521,7 @@ def forward_layer_type( end_idx: int, ) -> torch.Tensor: data = adapter_data.data.get(layer_type) + data: BatchLoraWeights = data.get(LORA) if data is not None else None if has_sgmv() and data is not None and data.can_vectorize(self.process_group): if end_idx - start_idx != result.shape[1]: @@ -567,7 +570,7 @@ def forward_layer_type( def forward_lora( self, input: torch.Tensor, - data: AdapterWeightData, + data: BatchLoraWeights, adapter_index: int, adapter_mask: torch.Tensor, ) -> torch.Tensor: @@ -587,7 +590,7 @@ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: raise NotImplementedError("Implemented in subclasses") -class TensorParallelMultiAdapterLinear(TensorParallelAdapterLinear): +class TensorParallelMultiAdapterLinear(LoraLinear): def __init__(self, base_layer, layer_id, layer_names, sizes, process_group): super().__init__(base_layer, layer_id, process_group) self.layer_names = layer_names @@ -642,7 +645,7 @@ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: return torch.cat(gathered_tensors, dim=1) -class TensorParallelAdapterRowLinear(TensorParallelAdapterLinear): +class TensorParallelAdapterRowLinear(LoraLinear): def __init__(self, base_layer, layer_id, layer_name, process_group): super().__init__(base_layer, layer_id, process_group) self.layer_name = layer_name diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py index 48cb2b30e..629d2fbf5 100644 --- a/server/lorax_server/utils/lora.py +++ b/server/lorax_server/utils/lora.py @@ -1,12 +1,10 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Dict, List, Set +from typing import Dict, List, Set, Type import torch -from peft import LoraConfig -from torch.distributed import ProcessGroup -from lorax_server.utils.sgmv import MAX_RANK_CUSTOM, get_tmp_tensors, orient_for_rank +from lorax_server.adapters.weights import AdapterWeights, BatchAdapterWeights # Constants @@ -20,206 +18,3 @@ DOWN_PROJ = "down_proj" LM_HEAD = "lm_head" - -EMPTY_TENSOR = torch.tensor([]) - - -@dataclass -class RankSegments: - rank: int - tmp_shrink: torch.Tensor - tmp_expand: torch.Tensor - lora_a_ptr: torch.Tensor - lora_b_ptr: torch.Tensor - segment_starts: torch.Tensor - segment_ends: torch.Tensor - - -@dataclass -class AdapterWeightData: - lora_a: Dict[int, torch.Tensor] - lora_b: Dict[int, torch.Tensor] - adapter_index_configs: Dict[int, LoraConfig] - rank_data: Dict[int, RankSegments] - - def has_adapter(self, adapter_index: int) -> bool: - return adapter_index in self.adapter_index_configs - - def can_vectorize(self, pg: ProcessGroup) -> bool: - return all( - rank_data.rank // pg.size() <= MAX_RANK_CUSTOM - for rank_data in self.rank_data.values() - ) - - -@dataclass -class AdapterBatchMetadata: - # [batch_size] - adapter_indices: torch.Tensor - - # [num_adapters] - adapter_set: Set[int] - - # [num_segments + 1] - adapter_segments: torch.Tensor - - # [num_segments] - # maps from segment index to adapter index, i.e.: - # segment_indices[s] == adapter_indices[i] - segment_indices: 
diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py
index 48cb2b30e..629d2fbf5 100644
--- a/server/lorax_server/utils/lora.py
+++ b/server/lorax_server/utils/lora.py
@@ -1,12 +1,10 @@
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, List, Set
+from typing import Dict, List, Set, Type
 
 import torch
-from peft import LoraConfig
-from torch.distributed import ProcessGroup
 
-from lorax_server.utils.sgmv import MAX_RANK_CUSTOM, get_tmp_tensors, orient_for_rank
+from lorax_server.adapters.weights import AdapterWeights, BatchAdapterWeights
 
 # Constants
@@ -20,206 +18,3 @@
 DOWN_PROJ = "down_proj"
 
 LM_HEAD = "lm_head"
-
-EMPTY_TENSOR = torch.tensor([])
-
-
-@dataclass
-class RankSegments:
-    rank: int
-    tmp_shrink: torch.Tensor
-    tmp_expand: torch.Tensor
-    lora_a_ptr: torch.Tensor
-    lora_b_ptr: torch.Tensor
-    segment_starts: torch.Tensor
-    segment_ends: torch.Tensor
-
-
-@dataclass
-class AdapterWeightData:
-    lora_a: Dict[int, torch.Tensor]
-    lora_b: Dict[int, torch.Tensor]
-    adapter_index_configs: Dict[int, LoraConfig]
-    rank_data: Dict[int, RankSegments]
-
-    def has_adapter(self, adapter_index: int) -> bool:
-        return adapter_index in self.adapter_index_configs
-
-    def can_vectorize(self, pg: ProcessGroup) -> bool:
-        return all(
-            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
-            for rank_data in self.rank_data.values()
-        )
-
-
-@dataclass
-class AdapterBatchMetadata:
-    # [batch_size]
-    adapter_indices: torch.Tensor
-
-    # [num_adapters]
-    adapter_set: Set[int]
-
-    # [num_segments + 1]
-    adapter_segments: torch.Tensor
-
-    # [num_segments]
-    # maps from segment index to adapter index, i.e.:
-    # segment_indices[s] == adapter_indices[i]
-    segment_indices: List[int]
-
-
-@dataclass
-class AdapterBatchData:
-    meta: AdapterBatchMetadata
-    data: Dict[str, AdapterWeightData]
-
-    @staticmethod
-    def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, "BatchedLoraWeights"]) -> "AdapterBatchData":
-        data = {}
-        for k, v in weights.items():
-            if v.is_empty():
-                continue
-            data[k] = v.get_data(meta)
-        return AdapterBatchData(meta=meta, data=data)
-
-    def ranks(self) -> Set[int]:
-        return set(
-            rank_data.rank
-            for layer in self.data.values()
-            for rank_data in layer.rank_data.values()
-        )
-
-    @property
-    def max_rank(self) -> int:
-        ranks = self.ranks()
-        return max(ranks) if len(ranks) > 0 else 0
-
-
-class MergedLoraWeights:
-    """LoRA weights for a single adapter merged across all layers."""
-
-    def __init__(
-        self,
-        weights_a: List[torch.Tensor],
-        weights_b: List[torch.Tensor],
-        adapter_config: LoraConfig,
-    ):
-        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
-        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
-
-        # [num_layers, hidden_size, r]
-        weights_a = [
-            orient_for_rank(w, w.size(1)).contiguous()
-            for w in weights_a
-        ]
-        self.weights_a = torch.stack(weights_a)
-
-        # [num_layers, r, hidden_size]
-        self.weights_b = torch.stack(weights_b)
-
-        self.adapter_config = adapter_config
-
-
-class BatchedLoraWeights:
-    """LoRA weights for multiple adapters."""
-
-    def __init__(self):
-        self.lora_weights: Dict[int, MergedLoraWeights] = {}
-
-    def add_adapter(self, adapter_idx: int, weights: MergedLoraWeights):
-        self.lora_weights[adapter_idx] = weights
-
-    def remove_adapter(self, adapter_idx: int):
-        if adapter_idx not in self.lora_weights:
-            return
-        del self.lora_weights[adapter_idx]
-
-    def is_empty(self) -> bool:
-        return len(self.lora_weights) == 0
-
-    def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData:
-        """
-        Get the adapter weight data for a given metadata.
-
-        Args:
-            meta (AdapterBatchMetadata): The metadata for the adapter batch.
-
-        Returns:
-            AdapterWeightData: The adapter weight data.
- - """ - first_weights = list(self.lora_weights.values())[0] - device = first_weights.weights_a.device - segment_indices = meta.segment_indices - - lora_a = { - idx: self.lora_weights[idx].weights_a - for idx in segment_indices - if idx in self.lora_weights - } - lora_a_ptr = torch.tensor( - [ - ( - self.lora_weights[idx].weights_a.data_ptr() - if idx in self.lora_weights - else EMPTY_TENSOR.data_ptr() - ) for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - lora_b = { - idx: self.lora_weights[idx].weights_b - for idx in segment_indices - if idx in self.lora_weights - } - lora_b_ptr = torch.tensor( - [ - ( - self.lora_weights[idx].weights_b.data_ptr() - if idx in self.lora_weights - else EMPTY_TENSOR.data_ptr() - ) for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - - adapter_index_configs = { - idx: self.lora_weights[idx].adapter_config - for idx in segment_indices - if idx in self.lora_weights - } - - rank_indices = defaultdict(list) - for segment_idx, adapter_idx in enumerate(segment_indices): - if adapter_idx not in self.lora_weights: - continue - rank_indices[self.lora_weights[adapter_idx].lora_a_r].append(segment_idx) - - rank_data = {} - for rank, indices in rank_indices.items(): - lora_a_ptr_indices = lora_a_ptr[indices] - tmp_shrink, tmp_expand = get_tmp_tensors( - lora_a_ptr_indices.size(0), - rank, - device - ) - - rank_data[rank] = RankSegments( - rank=rank, - tmp_shrink=tmp_shrink, - tmp_expand=tmp_expand, - lora_a_ptr=lora_a_ptr_indices, - lora_b_ptr=lora_b_ptr[indices], - segment_starts=meta.adapter_segments[indices], - segment_ends=meta.adapter_segments[[i+1 for i in indices]], - ) - - return AdapterWeightData( - lora_a=lora_a, - lora_b=lora_b, - adapter_index_configs=adapter_index_configs, - rank_data=rank_data, - ) diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py index 938a7b3a9..e2ccfccfa 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union import torch -from peft import LoraConfig from lorax_server.pb.generate_pb2 import ( AdapterParameters, @@ -14,6 +13,7 @@ from lorax_server.utils.merges.utils import calculate_majority_sign_mask, disjoint_merge, prune if TYPE_CHECKING: + from lorax_server.adapters.lora import LoraConfig from lorax_server.utils.adapter import ModuleMap @@ -102,9 +102,9 @@ def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torc def merge_adapters( - adapters: List[Tuple["ModuleMap", LoraConfig]], + adapters: List[Tuple["ModuleMap", "LoraConfig"]], merge_params: AdapterParameters, -) -> Tuple["ModuleMap", LoraConfig]: +) -> Tuple["ModuleMap", "LoraConfig"]: strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower() weights = merge_params.weights @@ -154,7 +154,7 @@ def merge_adapters( return merged_module_map, merged_lora_config -def _validate_lora_configs(lora_configs: List[LoraConfig]): +def _validate_lora_configs(lora_configs: List["LoraConfig"]): # check that all configs have the same rank ranks = set(lora_config.r for lora_config in lora_configs) if len(ranks) > 1: @@ -164,7 +164,7 @@ def _validate_lora_configs(lora_configs: List[LoraConfig]): raise ValueError("unable to merge adapters, lora configs have no target modules") -def _merge_lora_configs(lora_configs: List[LoraConfig]) -> LoraConfig: +def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> 
"LoraConfig": merged_lora_config = copy.copy(lora_configs[0]) # merge target modules as a union operation diff --git a/server/lorax_server/utils/sources/hub.py b/server/lorax_server/utils/sources/hub.py index 8c909b8a6..b32155c63 100644 --- a/server/lorax_server/utils/sources/hub.py +++ b/server/lorax_server/utils/sources/hub.py @@ -162,7 +162,11 @@ def __init__(self, model_id: str, revision: Optional[str] = None, extension: str self.model_id = model_id self.revision = revision self.extension = extension - self.api_token = api_token + self._api_token = api_token + + @property + def api_token(self) -> Optional[str]: + return self._api_token def remote_weight_files(self, extension: str = None): extension = extension or self.extension @@ -178,3 +182,11 @@ def download_weights(self, filenames): def download_model_assets(self): filenames = self.remote_weight_files() return self.download_weights(filenames) + + def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: + try: + return Path(hf_hub_download(self.model_id, revision=None, filename=filename)) + except Exception as e: + if ignore_errors: + return None + raise e diff --git a/server/lorax_server/utils/sources/local.py b/server/lorax_server/utils/sources/local.py index e89ffe3f5..cd55d2e5b 100644 --- a/server/lorax_server/utils/sources/local.py +++ b/server/lorax_server/utils/sources/local.py @@ -19,7 +19,7 @@ from .source import BaseModelSource, try_to_load_from_cache -def get_model_local_dir(model_id: str): +def get_model_local_dir(model_id: str) -> Path: if os.path.isabs(model_id): return Path(model_id) @@ -37,6 +37,10 @@ def __init__(self, model_id: str, revision: Optional[str] = "", extension: str = self.revision = revision self.extension = extension + @property + def api_token(self) -> Optional[str]: + return None + def remote_weight_files(self, extension: str = None): return [] @@ -63,5 +67,15 @@ def download_weights(self, filenames: List[str]): def download_model_assets(self): return [] - def get_local_path(self, model_id: str): - return get_model_local_dir(model_id) \ No newline at end of file + def get_local_path(self, model_id: str) -> Path: + return get_model_local_dir(model_id) + + def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: + path = get_model_local_dir(self.model_id) / filename + if not path.exists(): + if ignore_errors: + return None + raise FileNotFoundError( + f"File {filename} of model {self.model_id} not found in {path}" + ) + return path diff --git a/server/lorax_server/utils/sources/s3.py b/server/lorax_server/utils/sources/s3.py index 9c0fee30f..9ef33f590 100644 --- a/server/lorax_server/utils/sources/s3.py +++ b/server/lorax_server/utils/sources/s3.py @@ -223,6 +223,10 @@ def __init__(self, model_id: str, revision: Optional[str] = "", extension: str = self.revision = revision self.extension = extension self.bucket = _get_bucket_resource(bucket) + + @property + def api_token(self) -> Optional[str]: + return None def remote_weight_files(self, extension: str = None): extension = extension or self.extension @@ -241,3 +245,13 @@ def download_model_assets(self): def get_local_path(self, model_id: str): _, model_id = _get_bucket_and_model_id(model_id) return get_s3_model_local_dir(model_id) + + def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: + filenames = [filename] + try: + paths = download_files_from_s3(self.bucket, filenames, self.model_id, self.revision) + return paths[0] + except FileNotFoundError as e: + if 
+                return None
+            raise e
diff --git a/server/lorax_server/utils/sources/source.py b/server/lorax_server/utils/sources/source.py
index f847309c0..14867d97f 100644
--- a/server/lorax_server/utils/sources/source.py
+++ b/server/lorax_server/utils/sources/source.py
@@ -1,8 +1,12 @@
+from abc import abstractmethod
 import json
 import os
 from typing import Optional, List
 from pathlib import Path
 
+from lorax_server.adapters import load_adapter_config
+from lorax_server.adapters.config import AdapterConfig
+
 
 def try_to_load_from_cache(
     repo_cache: Path, revision: Optional[str], filename: str
@@ -39,15 +43,24 @@ def try_to_load_from_cache(
 
 
 class BaseModelSource:
+    @property
+    @abstractmethod
+    def api_token(self) -> Optional[str]:
+        pass
+
+    @abstractmethod
     def remote_weight_files(self, extension: str = None):
-        raise NotImplementedError
+        pass
 
+    @abstractmethod
     def weight_files(self, extension: str = None) -> List[Path]:
-        raise NotImplementedError
+        pass
 
+    @abstractmethod
     def download_weights(self, filenames: List[str]):
-        raise NotImplementedError
+        pass
 
+    @abstractmethod
     def download_model_assets(self):
         """
        The reason we need this function is that for s3
        we need to download all the model files whereas for
@@ -55,7 +68,11 @@ def download_model_assets(self):
        for other future sources we might need something different. So this function will take the
        necessary steps to download the needed files for any source
        """
-        raise NotImplementedError
+        pass
+
+    @abstractmethod
+    def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]:
+        pass
 
     def get_weight_bytes(self) -> int:
         total_size = 0
@@ -113,3 +130,9 @@ def get_weight_bytes(self) -> int:
         total_size += total_size_bytes
 
         return total_size
+
+    def load_config(self) -> AdapterConfig:
+        config_path = self.download_file("config.json", ignore_errors=True)
+        adapter_config_path = self.download_file("adapter_config.json", ignore_errors=True)
+        return load_adapter_config(config_path, adapter_config_path, self.api_token)
+
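
Any concrete source now satisfies the same abstract surface (api_token, download_file, and the weight-file methods), which is what makes the shared load_config above possible. A minimal sketch of a custom source; the class name and in-memory file map are hypothetical, not part of the patch:

    from pathlib import Path
    from typing import List, Optional

    from lorax_server.utils.sources.source import BaseModelSource


    class DictModelSource(BaseModelSource):
        """Toy source that serves files from an in-memory map (illustration only)."""

        def __init__(self, files: dict):
            self.files = files  # filename -> Path on local disk

        @property
        def api_token(self) -> Optional[str]:
            return None

        def remote_weight_files(self, extension: str = None):
            return []

        def weight_files(self, extension: str = None) -> List[Path]:
            return list(self.files.values())

        def download_weights(self, filenames: List[str]):
            return [self.files[f] for f in filenames]

        def download_model_assets(self):
            return list(self.files.values())

        def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]:
            if filename not in self.files:
                if ignore_errors:
                    return None
                raise FileNotFoundError(filename)
            return self.files[filename]

    # DictModelSource(...).load_config() would then probe config.json / adapter_config.json
    # via download_file() and return the matching AdapterConfig (currently LoraConfig).
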
diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py
index 793b11d37..c7b2f48b9 100644
--- a/server/tests/utils/test_adapter.py
+++ b/server/tests/utils/test_adapter.py
@@ -1,6 +1,6 @@
 import torch
-from peft import LoraConfig
+from lorax_server.adapters.lora import LoraConfig
 
 from lorax_server.utils.adapter import merge_adapter_weights
 
@@ -33,7 +33,7 @@ def test_merge_adapter_weights():
         [13.5000, 18.0000, 22.5000],
         [21.5000, 28.0000, 34.5000]
     ])
-    adapter_config = LoraConfig(r=2, lora_alpha=1, fan_in_fan_out=False)
+    adapter_config = LoraConfig(base_model_name_or_path="", r=2, target_modules=None, lora_alpha=1, fan_in_fan_out=False, use_rslora=False)
 
     merged_weights, processed_adapter_weight_names = merge_adapter_weights(model_weights, adapter_weights, adapter_config)
     assert len(merged_weights) == 1
diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py
index fc5a595d1..ea33ed8cd 100644
--- a/server/tests/utils/test_lora.py
+++ b/server/tests/utils/test_lora.py
@@ -5,7 +5,9 @@
 import torch
 from peft import LoraConfig
 
-from lorax_server.utils.lora import AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
+from lorax_server.adapters.lora import LoraWeights
+from lorax_server.adapters.types import LORA
+from lorax_server.adapters.weights import AdapterBatchMetadata, LayerAdapterWeights
 from lorax_server.utils.sgmv import MIN_RANK_CUSTOM
 
 
@@ -17,12 +19,12 @@ def test_batched_lora_weights(lora_ranks: List[int]):
     # batch meta is hardcoded with this assumption below
     assert len(lora_ranks) == 2
 
-    batched_weights = BatchedLoraWeights()
+    batched_weights = LayerAdapterWeights()
     assert batched_weights.is_empty()
 
     h = 1024
     for idx, lora_rank in enumerate(lora_ranks):
-        weights = MergedLoraWeights(
+        weights = LoraWeights(
             weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)],
             weights_b=[torch.randn((lora_rank, h), dtype=torch.float16)],
             adapter_config=LoraConfig(r=lora_rank),
@@ -33,7 +35,7 @@
         batched_weights.add_adapter(idx, weights)
 
     assert not batched_weights.is_empty()
-    assert len(batched_weights.lora_weights) == 2
+    assert len(batched_weights.adapter_weights) == 2
 
     meta = AdapterBatchMetadata(
         adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64),
@@ -42,8 +44,8 @@
         segment_indices=[0, 1, 0, 1],
     )
 
-    with mock.patch("lorax_server.utils.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
-        data = batched_weights.get_data(meta)
+    with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
+        data = batched_weights.get_data(meta).get(LORA)
 
     assert len(data.lora_a) == 2
     assert data.lora_a.keys() == meta.adapter_set
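
Taken together, the flow exercised by the updated tests looks roughly like this. A minimal sketch assuming a single rank-8 adapter, CPU tensors, and get_tmp_tensors mocked out as in the test above:

    from unittest import mock

    import torch
    from peft import LoraConfig

    from lorax_server.adapters.lora import LoraWeights
    from lorax_server.adapters.types import LORA
    from lorax_server.adapters.weights import AdapterBatchMetadata, LayerAdapterWeights

    h, rank = 1024, 8
    layer_weights = LayerAdapterWeights()
    layer_weights.add_adapter(0, LoraWeights(
        weights_a=[torch.randn((h, rank), dtype=torch.float16)],
        weights_b=[torch.randn((rank, h), dtype=torch.float16)],
        adapter_config=LoraConfig(r=rank),
    ))

    meta = AdapterBatchMetadata(
        adapter_indices=torch.tensor([0, 0, 0, 0], dtype=torch.int64),
        adapter_set={0},
        adapter_segments=torch.tensor([0, 4], dtype=torch.int64),
        segment_indices=[0],
    )

    # get_data() now returns a dict keyed by adapter type; LORA maps to BatchLoraWeights.
    with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
        lora_batch = layer_weights.get_data(meta).get(LORA)
    assert lora_batch.lora_a.keys() == meta.adapter_set
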