From 60ad527763ba0c6c13aabb811dd8214c982503d6 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 18 Jan 2024 12:56:22 -0800 Subject: [PATCH 01/23] WIP: lora merge --- server/lorax_server/utils/adapter.py | 3 + server/lorax_server/utils/merges/__init__.py | 0 .../lorax_server/utils/merges/strategies.py | 107 ++++++++++++++++++ server/lorax_server/utils/merges/utils.py | 103 +++++++++++++++++ 4 files changed, 213 insertions(+) create mode 100644 server/lorax_server/utils/merges/__init__.py create mode 100644 server/lorax_server/utils/merges/strategies.py create mode 100644 server/lorax_server/utils/merges/utils.py diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index 51dfa9021..0b36e3f10 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -21,6 +21,9 @@ BASE_MODEL_ADAPTER_ID = "__base_model__" +ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]] + + @lru_cache(maxsize=128) def load_module_map(model_id, adapter_id, adapter_source, weight_names, api_token): # TODO(geoffrey): refactor this and merge parts of this function with diff --git a/server/lorax_server/utils/merges/__init__.py b/server/lorax_server/utils/merges/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py new file mode 100644 index 000000000..d3021f8a5 --- /dev/null +++ b/server/lorax_server/utils/merges/strategies.py @@ -0,0 +1,107 @@ +from abc import ABC +from collections import defaultdict +from typing import Dict, List, Tuple + +import torch +from peft import LoraConfig + +from lorax_server.utils.merges.utils import calculate_majority_sign_mask, disjoint_merge, prune +from lorax_server.utils.adapter import ModuleMap + + +def _apply_weights(tensors: List[torch.Tensor], w: torch.Tensor) -> torch.Tensor: + t = torch.stack(tensors, dim=0) + + # element-wise weighting of each task tensor + # need to unsqueeze weights to match task tensor dimensions + # for multiplication to apply element-wise + while len(t.shape) > len(w.shape): + w = w.unsqueeze(-1) + return t * w + + +class MergeStrategy(ABC): + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + +class LinearMerge(MergeStrategy): + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class TiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total"): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + # sparsify + task_tensors = [prune(tensor, self.density, method="magnitude") for tensor in task_tensors] + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # elect sign + majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method) + + # disjoint merge + return disjoint_merge(weighted_task_tensors, majority_sign_mask) + + +class DareLinearMerge(MergeStrategy): + def __init__(self, density: float): + self.density = density + + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + # sparsify + task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors] + 
weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class DareTiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total"): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + # sparsify + task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors] + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # elect sign + majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method) + + # disjoint merge + mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) + return mixed_task_tensors + + +strategy_registry = { + "linear": LinearMerge, + "ties": TiesMerge, + "dare_linear": DareLinearMerge, + "dare_ties": DareTiesMerge, +} + + +def merge_adapters( + adapters: List[Tuple[ModuleMap, LoraConfig]], + merge_config: Dict, +) -> Tuple[ModuleMap, LoraConfig]: + merge_config = merge_config.copy() + strategy_name = merge_config.pop("strategy") + merge_strategy = strategy_registry[strategy_name](**merge_config) + + module_maps = defaultdict(dict) + lora_configs = [] + + for module_map, lora_config in adapters: + + for weight_name, weights in module_map.items(): + for k, (param, param_name) in weights.items(): + module_maps[weight_name][k] = (param, param_name) + + lora_configs.append(lora_config) + + diff --git a/server/lorax_server/utils/merges/utils.py b/server/lorax_server/utils/merges/utils.py new file mode 100644 index 000000000..88e2a2989 --- /dev/null +++ b/server/lorax_server/utils/merges/utils.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# From: https://github.com/huggingface/peft/pull/1364 +# Copyright 2024-present the HuggingFace Inc. team. +# Modifications by Predibase, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +import torch + + +def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + """ + mask = torch.zeros_like(tensor).reshape(-1) + k = int(density * tensor.reshape(-1).shape[0]) + top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True) + mask[top_k[1]] = 1 + return tensor * mask.reshape(tensor.shape) + + +def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. 
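As a quick illustration (not part of the patch), the magnitude-based pruning above keeps only the top `density` fraction of entries by absolute value; the import path is the one introduced by this patch:

import torch
from lorax_server.utils.merges.utils import magnitude_based_pruning

t = torch.tensor([[0.10, -2.00],
                  [3.00,  0.05]])
# With density=0.5 only -2.0 and 3.0 survive; 0.10 and 0.05 are zeroed out.
print(magnitude_based_pruning(t, density=0.5))
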
+ rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density)) + pruned_tensor = tensor * mask + if rescale: + torch.div(input=pruned_tensor, other=density) + return pruned_tensor + + +def prune( + tensor: torch.Tensor, density: float, method: Literal["magnitude", "random"], rescale: bool = False +) -> torch.Tensor: + """ + Prune the values of task tensors based on the `method`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + method (`str`):The method to use to prune. Should be one of ["magnitude", "random"]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + if density >= 1: + return tensor + elif density < 0: + raise ValueError("Density should be >= 0, got {density}") + if method == "magnitude": + return magnitude_based_pruning(tensor, density) + elif method == "random": + return random_pruning(tensor, density, rescale=rescale) + else: + raise ValueError(f"Unknown method {method}") + + +def calculate_majority_sign_mask(tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"): + """ + Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0. + + Args: + tensor (`torch.Tensor`):The tensor to get the mask from. + method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"]. + """ + + sign = tensor.sign() + if method == "total": + sign_magnitude = (sign * tensor.abs()).sum(dim=0) + elif method == "frequency": + sign_magnitude = sign.sum(dim=0) + else: + raise RuntimeError(f'Unimplemented mask method "{method}"') + majority_sign = torch.where(sign_magnitude >= 0, 1, -1) + return sign == majority_sign + + +def disjoint_merge(task_tensors, majority_sign_mask): + mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0) + num_params_preserved = majority_sign_mask.sum(dim=0) + return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0) From aff0903dad432182cb924518a5fef6c52f27343f Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 25 Jan 2024 17:00:52 -0800 Subject: [PATCH 02/23] Merge --- .../lorax_server/utils/merges/strategies.py | 59 ++++++++++++++++--- 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py index d3021f8a5..e354eb131 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -1,6 +1,7 @@ from abc import ABC from collections import defaultdict -from typing import Dict, List, Tuple +import copy +from typing import Dict, List, Tuple, Type import torch from peft import LoraConfig @@ -77,7 +78,7 @@ def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torc return mixed_task_tensors -strategy_registry = { +strategy_registry: Dict[str, Type[MergeStrategy]] = { "linear": LinearMerge, "ties": TiesMerge, "dare_linear": DareLinearMerge, @@ -91,17 +92,59 @@ def merge_adapters( ) -> Tuple[ModuleMap, LoraConfig]: merge_config = merge_config.copy() strategy_name = merge_config.pop("strategy") + + weights = merge_config.pop("weights", None) + if weights is None: + weights = torch.ones(len(adapters)) + else: + weights = torch.tensor(weights) + merge_strategy = strategy_registry[strategy_name](**merge_config) - module_maps = 
defaultdict(dict) + module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict( + lambda: defaultdict(lambda: defaultdict(list)) + ) lora_configs = [] + # input is list of (module_map, lora_config) tuples + # convert into dict[k][param_name] -> list of tensors for module_map, lora_config in adapters: - - for weight_name, weights in module_map.items(): - for k, (param, param_name) in weights.items(): - module_maps[weight_name][k] = (param, param_name) - + for weight_name, data in module_map.items(): + for k, (param_data, param_name) in data.items(): + module_maps[weight_name][k][param_name].append(param_data) lora_configs.append(lora_config) + # validate lora configs are compatible + _validate_lora_configs(lora_configs) + + # merge tensors for each module such that we have a single ModuleMap: + # dict[k] -> merged tensor + merged_module_map: ModuleMap = defaultdict(dict) + for weight_name, data in module_maps.items(): + for k, param_data in data.items(): + for param_name, tensors in param_data.items(): + merged_tensor = merge_strategy.merge(tensors, weights=weights) + merged_module_map[weight_name][k] = (merged_tensor, param_name) + + # merge lora configs + merged_lora_config = _merge_lora_configs(lora_configs) + + return merged_module_map, merged_lora_config + + +def _validate_lora_configs(lora_configs: List[LoraConfig]): + # check that all configs have the same rank + ranks = set(lora_config.rank for lora_config in lora_configs) + if len(ranks) > 1: + raise ValueError(f"unable to merge adapters, lora configs have different ranks: {ranks}") + + # check that all configs have the same target modules + target_modules = set(" | ".join(lora_config.target_modules for lora_config in lora_configs)) + if len(target_modules) > 1: + raise ValueError(f"unable to merge adapters, lora configs have different target modules: {target_modules}") + +def _merge_lora_configs(lora_configs: List[LoraConfig]) -> LoraConfig: + # for now, due to rank and target constraints, we can just return one config + # may revisit this in the future if we loosen these constraints + return lora_configs[0] From ecc38dcc9a255236390c23e39e12b32c99013b78 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 25 Jan 2024 22:08:23 -0800 Subject: [PATCH 03/23] Proto and refactor --- proto/generate.proto | 39 +++++++++++++++++++ .../lorax_server/utils/merges/strategies.py | 35 ++++++++++------- 2 files changed, 59 insertions(+), 15 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index e7afe4f26..bae950b82 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -232,6 +232,45 @@ enum AdapterSource { PBASE = 3; } +enum MergeStrategy { + /// Linear combination of adapters + LINEAR = 0; + + /// TIES method for combining adapters + TIES = 1; + + /// DARE method for combining adapters + DARE_LINEAR = 2; + + /// DARE + TIES method for combining adapters + DARE_TIES = 3; +} + +enum MajoritySignMethod { + /// Total method + TOTAL = 0; + + /// Frequency method + FREQUENCY = 1; +} + +message AdapterParameters { + /// Adapter IDs + repeated string adapter_ids = 1; + + /// Adapter weights for merging + repeated float weights = 2; + + /// Merge strategy (default: linear) + MergeStrategy merge_strategy = 3; + + /// [0, 1], 0: full pruning, 1: no pruning + float density = 4; + + /// Majority sign method (default: total) + MajoritySignMethod majority_sign_method = 5; +} + message DownloadAdapterRequest { /// Adapter ID string adapter_id = 1; diff --git a/server/lorax_server/utils/merges/strategies.py 
b/server/lorax_server/utils/merges/strategies.py index e354eb131..43d1d00ad 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -1,6 +1,5 @@ from abc import ABC from collections import defaultdict -import copy from typing import Dict, List, Tuple, Type import torch @@ -22,25 +21,29 @@ def _apply_weights(tensors: List[torch.Tensor], w: torch.Tensor) -> torch.Tensor class MergeStrategy(ABC): - def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: raise NotImplementedError() class LinearMerge(MergeStrategy): - def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: - weighted_task_tensors = _apply_weights(task_tensors, weights) + def __init__(self, weights: torch.Tensor, **kwargs): + self.weights = weights + + def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: + weighted_task_tensors = _apply_weights(task_tensors, self.weights) return weighted_task_tensors.sum(dim=0) class TiesMerge(MergeStrategy): - def __init__(self, density: float, majority_sign_method: str = "total"): + def __init__(self, weights: torch.Tensor, density: float, majority_sign_method: str = "total", **kwargs): + self.weights = weights self.density = density self.majority_sign_method = majority_sign_method - def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: # sparsify task_tensors = [prune(tensor, self.density, method="magnitude") for tensor in task_tensors] - weighted_task_tensors = _apply_weights(task_tensors, weights) + weighted_task_tensors = _apply_weights(task_tensors, self.weights) # elect sign majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method) @@ -50,25 +53,27 @@ def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torc class DareLinearMerge(MergeStrategy): - def __init__(self, density: float): + def __init__(self, weights: torch.Tensor, density: float, **kwargs): + self.weights = weights self.density = density - def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: # sparsify task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors] - weighted_task_tensors = _apply_weights(task_tensors, weights) + weighted_task_tensors = _apply_weights(task_tensors, self.weights) return weighted_task_tensors.sum(dim=0) class DareTiesMerge(MergeStrategy): - def __init__(self, density: float, majority_sign_method: str = "total"): + def __init__(self, weights: torch.Tensor, density: float, majority_sign_method: str = "total", **kwargs): + self.weights = weights self.density = density self.majority_sign_method = majority_sign_method - def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: # sparsify task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors] - weighted_task_tensors = _apply_weights(task_tensors, weights) + weighted_task_tensors = _apply_weights(task_tensors, self.weights) # elect sign majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method) @@ -99,7 +104,7 @@ def merge_adapters( 
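A minimal usage sketch for the refactored strategies (illustrative only; it assumes the registry and module path added earlier in this series). Weights are now fixed at construction time, so merge() takes just the per-adapter tensors:

import torch
from lorax_server.utils.merges.strategies import strategy_registry

# Two toy "task tensors" standing in for the same LoRA weight from two adapters.
task_tensors = [torch.randn(8, 8), torch.randn(8, 8)]
weights = torch.tensor([0.25, 0.75])

strategy = strategy_registry["dare_ties"](weights=weights, density=0.5)
merged = strategy.merge(task_tensors)  # a single (8, 8) tensor
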
else: weights = torch.tensor(weights) - merge_strategy = strategy_registry[strategy_name](**merge_config) + merge_strategy = strategy_registry[strategy_name](weights=weights, **merge_config) module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict( lambda: defaultdict(lambda: defaultdict(list)) @@ -123,7 +128,7 @@ def merge_adapters( for weight_name, data in module_maps.items(): for k, param_data in data.items(): for param_name, tensors in param_data.items(): - merged_tensor = merge_strategy.merge(tensors, weights=weights) + merged_tensor = merge_strategy.merge(tensors) merged_module_map[weight_name][k] = (merged_tensor, param_name) # merge lora configs From 88d9a3f9826747b2031b52be715ded5d9424e6de Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 26 Jan 2024 16:27:38 -0800 Subject: [PATCH 04/23] Plumb adapter parameters through server --- proto/generate.proto | 34 ++---- server/lorax_server/models/model.py | 97 ++++++++------- server/lorax_server/server.py | 110 +++++++++--------- server/lorax_server/utils/adapter.py | 80 ++++++++++++- .../lorax_server/utils/merges/strategies.py | 14 ++- 5 files changed, 206 insertions(+), 129 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index bae950b82..1743b0b3a 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -272,8 +272,8 @@ message AdapterParameters { } message DownloadAdapterRequest { - /// Adapter ID - string adapter_id = 1; + /// Adapter Parameters + AdapterParameters adapter_parameters = 1; /// Adapter source AdapterSource adapter_source = 2; /// Token for external API (predibase / HuggingFace) @@ -281,15 +281,13 @@ message DownloadAdapterRequest { } message DownloadAdapterResponse { - /// Adapter ID - string adapter_id = 1; - /// Adapter source - AdapterSource adapter_source = 2; + /// True if download occurred, false if skipped + bool downloaded = 1; } message LoadAdapterRequest { - /// Adapter ID - string adapter_id = 1; + /// Adapter Parameters + AdapterParameters adapter_parameters = 1; /// Adapter source AdapterSource adapter_source = 2; /// Adapter index @@ -299,17 +297,13 @@ message LoadAdapterRequest { } message LoadAdapterResponse { - /// Adapter ID - string adapter_id = 1; - /// Adapter source - AdapterSource adapter_source = 2; - /// Adapter index - uint32 adapter_index = 3; + /// True if load occurred, false if skipped + bool loaded = 1; } message OffloadAdapterRequest { - /// Adapter ID - string adapter_id = 1; + /// Adapter Parameters + AdapterParameters adapter_parameters = 1; /// Adapter source AdapterSource adapter_source = 2; /// Adapter index @@ -317,10 +311,6 @@ message OffloadAdapterRequest { } message OffloadAdapterResponse { - /// Adapter ID - string adapter_id = 1; - /// Adapter source - AdapterSource adapter_source = 2; - /// Adapter index - uint32 adapter_index = 3; + /// True if offload occurred, false if skipped + bool offloaded = 1; } diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 9af1be86d..c76d57558 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -9,8 +9,8 @@ from transformers import PreTrainedTokenizerBase from lorax_server.models.types import Batch, GeneratedText -from lorax_server.pb.generate_pb2 import InfoResponse -from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map +from lorax_server.pb.generate_pb2 import AdapterParameters, AdapterSource, InfoResponse +from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, 
load_and_merge_adapters from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.lora import BatchedLoraWeights, MergedLoraWeights from lorax_server.utils.weights import shard_on_dim @@ -46,10 +46,11 @@ def __init__( self.sliding_window = sliding_window # This may be set to False in the subclass constructor - self.adapter_id = adapter_id self.dynamic_adapter_loading_enabled = dynamic_adapter_loading_enabled self.batched_lora_weights: Dict[str, BatchedLoraWeights] = defaultdict(BatchedLoraWeights) self.target_to_layer = self.adapter_target_to_layer() + self.loaded_adapters = set() + self.static_adapter_id = adapter_id self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) @@ -137,48 +138,51 @@ def get_num_layers_for_type(self, layer_type: str) -> int: def is_row_parallel(self, layer_type: str) -> bool: return False - def load_adapter(self, adapter_id, adapter_source, adapter_index, api_token): + def load_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + api_token: str, + ): """Physically loads the adapter weights into the model. adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded into model. Otherwise, the adapter weights are merged into the model weights on the fly. """ - if adapter_id == BASE_MODEL_ADAPTER_ID: + if adapter_index in self.loaded_adapters: + # Adapter already loaded return if not self.supports_adapter_loading: raise ValueError("This model does not support adapter loading.") if not self.dynamic_adapter_loading_enabled: - raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") + raise ValueError(f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. 
" + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature.") + + logger.info(f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}") + weight_names = tuple([v[0] for v in self.target_to_layer.values()]) + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_and_merge_adapters( + self.model_id, adapter_parameters, adapter_source, adapter_index, weight_names, api_token + ) - # If we are doing dynamic adapter loading, then we need to reset the weights - if adapter_id == self.adapter_id: - return - elif adapter_id != BASE_MODEL_ADAPTER_ID: - logger.info(f"Loading adapter weights into model: {adapter_id}") - weight_names = tuple([v[0] for v in self.target_to_layer.values()]) - module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map( - self.model_id, adapter_id, adapter_source, weight_names, api_token + unused_weight_names = adapter_weight_names.copy() + for layer_name in self.adapter_layers: + self.load_batched_adapter_weights( + module_map, adapter_config, adapter_index, layer_name, unused_weight_names ) + + if len(unused_weight_names) > 0: + logger.warning(f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}") + + if adapter_tokenizer is not None: + self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) - unused_weight_names = adapter_weight_names.copy() - for layer_name in self.adapter_layers: - self.load_batched_adapter_weights( - module_map, adapter_config, adapter_index, layer_name, unused_weight_names - ) - - if len(unused_weight_names) > 0: - logger.warning(f"{adapter_id} unused adapter weights: {unused_weight_names}") - - if adapter_tokenizer is not None: - self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) - - self.adapter_id = adapter_id + self.loaded_adapters.add(adapter_index) def shard_lora_weights( self, @@ -246,25 +250,28 @@ def load_batched_adapter_weights( q_lora_weights = self.batched_lora_weights[layer_type] q_lora_weights.add_adapter(adapter_index, q_lora_merged) - def offload_adapter(self, adapter_id, adapter_source, adapter_index): + def offload_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + ): """Offloads the adapter weights from GPU to CPU or disk.""" + if adapter_index not in self.loaded_adapters: + # Adapter already offloaded + return + if not self.supports_adapter_loading: raise ValueError("This model does not support adapter loading.") if not self.dynamic_adapter_loading_enabled: - if adapter_id == BASE_MODEL_ADAPTER_ID: - return - else: - raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") + raise ValueError(f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. 
" + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature.") - if adapter_id == BASE_MODEL_ADAPTER_ID: - return - else: - for layer_name in self.adapter_layers: - if layer_name in self.batched_lora_weights: - self.batched_lora_weights[layer_name].remove_adapter(adapter_index) + for layer_name in self.adapter_layers: + if layer_name in self.batched_lora_weights: + self.batched_lora_weights[layer_name].remove_adapter(adapter_index) - self.adapter_id = BASE_MODEL_ADAPTER_ID + self.loaded_adapters.remove(adapter_index) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 5225b9ce1..4d381e720 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -19,7 +19,7 @@ from lorax_server.pb import generate_pb2_grpc, generate_pb2 from lorax_server.tracing import UDSOpenTelemetryAioServerInterceptor from lorax_server.utils import HUB, LOCAL, S3, PBASE, get_config_path, get_local_dir, map_pbase_model_id_to_s3 -from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID +from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, is_base_model class LoraxService(generate_pb2_grpc.LoraxServiceServicer): @@ -126,51 +126,56 @@ async def Decode(self, request, context): batch=next_batch.to_pb() if next_batch else None, ) - async def DownloadAdapter(self, request, context): - adapter_id = request.adapter_id - if adapter_id == BASE_MODEL_ADAPTER_ID: + async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, context): + adapter_parameters = request.adapter_parameters + if is_base_model(adapter_parameters): logger.info("No adapter to download for base model. Skipping.") - return generate_pb2.DownloadAdapterResponse( - adapter_id=adapter_id, - adapter_source=request.adapter_source, - ) + return generate_pb2.DownloadAdapterResponse(downloaded=False) api_token = request.api_token adapter_source = _adapter_source_enum_to_string(request.adapter_source) - if adapter_source == PBASE: - adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) - adapter_source = S3 - try: - if adapter_source == HUB: - # Quick auth check on the repo against the token - HfApi(token=api_token).model_info(adapter_id, revision=None) - - # fail fast if ID is not an adapter (i.e. it is a full model) - # TODO(geoffrey): do this for S3– can't do it this way because the - # files are not yet downloaded locally at this point. - config_path = get_config_path(adapter_id, adapter_source) - PeftConfig.from_pretrained(config_path, token=api_token) - - download_weights(adapter_id, source=adapter_source, api_token=api_token) - return generate_pb2.DownloadAdapterResponse( - adapter_id=adapter_id, - adapter_source=request.adapter_source, - ) - except Exception: - logger.exception("Error when downloading adapter") - - if adapter_source != LOCAL: - # delete safetensors files if there is an issue downloading or converting - # the weights to prevent cache hits by subsequent calls - try: - local_path = get_local_dir(adapter_id, adapter_source) - shutil.rmtree(local_path) - except Exception as e: - logger.warning(f"Error cleaning up safetensors files after " - f"download error: {e}\nIgnoring.") - raise + for adapter_id in adapter_parameters.adapter_ids: + if adapter_id == BASE_MODEL_ADAPTER_ID: + logger.info("No adapter to download for base model. 
Skipping.") + continue + + if adapter_source == PBASE: + adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) + adapter_source = S3 + try: + if adapter_source == HUB: + # Quick auth check on the repo against the token + HfApi(token=api_token).model_info(adapter_id, revision=None) + + # fail fast if ID is not an adapter (i.e. it is a full model) + # TODO(geoffrey): do this for S3– can't do it this way because the + # files are not yet downloaded locally at this point. + config_path = get_config_path(adapter_id, adapter_source) + PeftConfig.from_pretrained(config_path, token=api_token) + + download_weights(adapter_id, source=adapter_source, api_token=api_token) + except Exception: + logger.exception("Error when downloading adapter") + + if adapter_source != LOCAL: + # delete safetensors files if there is an issue downloading or converting + # the weights to prevent cache hits by subsequent calls + try: + local_path = get_local_dir(adapter_id, adapter_source) + shutil.rmtree(local_path) + except Exception as e: + logger.warning(f"Error cleaning up safetensors files after " + f"download error: {e}\nIgnoring.") + raise + + return generate_pb2.DownloadAdapterResponse(downloaded=True) - async def LoadAdapter(self, request, context): + async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context): + adapter_parameters = request.adapter_parameters + if is_base_model(adapter_parameters): + logger.info("No adapter to load for base model. Skipping.") + return generate_pb2.LoadAdapterResponse(loaded=False) + try: adapter_id = request.adapter_id adapter_source = _adapter_source_enum_to_string(request.adapter_source) @@ -181,27 +186,24 @@ async def LoadAdapter(self, request, context): adapter_source = S3 self.model.load_adapter(adapter_id, adapter_source, adapter_index, api_token) - return generate_pb2.LoadAdapterResponse( - adapter_id=adapter_id, - adapter_source=request.adapter_source, - adapter_index=adapter_index, - ) + return generate_pb2.LoadAdapterResponse(loaded=True) except Exception: logger.exception("Error when loading adapter") raise - async def OffloadAdapter(self, request, context): + async def OffloadAdapter(self, request: generate_pb2.OffloadAdapterRequest, context): + adapter_parameters = request.adapter_parameters + if is_base_model(adapter_parameters): + logger.info("No adapter to offload for base model. 
Skipping.") + return generate_pb2.OffloadAdapterResponse(offloaded=False) + try: - adapter_id = request.adapter_id + adapter_idx = request.adapter_index adapter_source = _adapter_source_enum_to_string(request.adapter_source) adapter_index = request.adapter_index - self.model.offload_adapter(adapter_id, adapter_source, adapter_index) + self.model.offload_adapter(adapter_idx, adapter_source, adapter_index) - return generate_pb2.OffloadAdapterResponse( - adapter_id=adapter_id, - adapter_source=request.adapter_source, - adapter_index=adapter_index, - ) + return generate_pb2.OffloadAdapterResponse(offloaded=True) except Exception: logger.exception("Error when offloading adapter") raise diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index 0b36e3f10..993fd691e 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import os from collections import defaultdict from functools import lru_cache @@ -11,11 +12,13 @@ from peft import LoraConfig from peft.utils import transpose from safetensors.torch import load_file, save_file -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer from tqdm import tqdm from filelock import FileLock -from lorax_server.utils.sources import get_model_source, get_config_path, weight_files +from lorax_server.pb import generate_pb2 +from lorax_server.utils.sources import get_model_source, get_config_path, weight_files +from lorax_server.utils.merges.strategies import merge_adapters BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -24,8 +27,79 @@ ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]] +@dataclass +class AdapterParametersContainer: + adapter_parameters: generate_pb2.AdapterParameters + adapter_source: str + adapter_index: int + + def __hash__(self) -> int: + return self.adapter_index + + +def is_base_model(adapter_parameters: generate_pb2.AdapterParameters) -> bool: + if len(adapter_parameters.adapter_ids) != 1: + return False + return adapter_parameters.adapter_ids[0] == BASE_MODEL_ADAPTER_ID + + +def load_and_merge_adapters( + model_id: str, + adapter_parameters: generate_pb2.AdapterParameters, + adapter_source: str, + adapter_index: int, + weight_names: Tuple[str], + api_token: str, +) -> Tuple[ModuleMap, LoraConfig, Set[str], PreTrainedTokenizer]: + if len(adapter_parameters.adapter_ids) == 1: + return load_module_map( + model_id, adapter_parameters.adapter_ids[0], adapter_source, weight_names, api_token + ) + + adapter_params = AdapterParametersContainer(adapter_parameters, adapter_source, adapter_index) + return _load_and_merge(model_id, adapter_params, weight_names, api_token) + + +@lru_cache(maxsize=32) +def _load_and_merge( + model_id: str, + adapter_params: AdapterParametersContainer, + weight_names: Tuple[str], + api_token: str, +) -> Tuple[ModuleMap, LoraConfig, Set[str], PreTrainedTokenizer]: + params = adapter_params.adapter_parameters + + adapters_to_merge = [] + weight_names = set() + tokenizer = None + for adapter_id in params.adapter_ids: + if adapter_id == BASE_MODEL_ADAPTER_ID: + raise ValueError("Base model adapter cannot be merged.") + + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map( + model_id, adapter_id, adapter_params.adapter_source, weight_names, api_token, + ) + + adapters_to_merge.append((module_map, adapter_config)) + weight_names = weight_names.union(adapter_weight_names) + if tokenizer is None: 
+ tokenizer = adapter_tokenizer + + if len(adapters_to_merge) == 0: + raise ValueError("No adapters to merge.") + + module_map, adapter_config = merge_adapters(adapters_to_merge, params) + return module_map, adapter_config, weight_names, tokenizer + + @lru_cache(maxsize=128) -def load_module_map(model_id, adapter_id, adapter_source, weight_names, api_token): +def load_module_map( + model_id: str, + adapter_id: str, + adapter_source: str, + weight_names: Tuple[str], + api_token: str, +) -> Tuple[ModuleMap, LoraConfig, Set[str], PreTrainedTokenizer]: # TODO(geoffrey): refactor this and merge parts of this function with # lorax_server/utils/adapter.py::create_merged_weight_files source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py index 43d1d00ad..28a824542 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -5,6 +5,7 @@ import torch from peft import LoraConfig +from lorax_server.pb.generate_pb2 import AdapterParameters from lorax_server.utils.merges.utils import calculate_majority_sign_mask, disjoint_merge, prune from lorax_server.utils.adapter import ModuleMap @@ -93,17 +94,20 @@ def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: def merge_adapters( adapters: List[Tuple[ModuleMap, LoraConfig]], - merge_config: Dict, + merge_params: AdapterParameters, ) -> Tuple[ModuleMap, LoraConfig]: - merge_config = merge_config.copy() - strategy_name = merge_config.pop("strategy") + strategy_name = merge_params.merge_strategy - weights = merge_config.pop("weights", None) - if weights is None: + weights = merge_params.weights + if not weights: weights = torch.ones(len(adapters)) else: weights = torch.tensor(weights) + merge_config = { + "density": merge_params.density, + "majority_sign_method": merge_params.majority_sign_method, + } merge_strategy = strategy_registry[strategy_name](weights=weights, **merge_config) module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict( From 3d8a46b02925f05ae3b2547271ced8326dc61199 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sat, 27 Jan 2024 13:23:57 -0800 Subject: [PATCH 05/23] WIP: plumb adapters in router --- router/client/src/client.rs | 30 ++++++++++----------- router/client/src/lib.rs | 4 +-- router/client/src/sharded_client.rs | 23 ++++++++-------- router/src/adapter.rs | 34 ++++++++++++++++------- router/src/lib.rs | 25 +++++++++++++++++ router/src/loader.rs | 2 +- router/src/queue.rs | 4 +-- router/src/validation.rs | 42 +++++++++++++++++++++++++---- 8 files changed, 119 insertions(+), 45 deletions(-) diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 68b000f09..36608f01e 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -182,25 +182,25 @@ impl Client { /// Downloads the weights for an adapter. 
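For reference, the AdapterParameters message that these download/load/offload RPCs now carry can be built from the Python side roughly as follows (a sketch, not part of the patch; the adapter IDs are placeholders and the HUB value is assumed from the existing AdapterSource enum):

from lorax_server.pb import generate_pb2

params = generate_pb2.AdapterParameters(
    adapter_ids=["org/adapter-a", "org/adapter-b"],  # placeholder adapter IDs
    weights=[0.5, 0.5],
    merge_strategy=generate_pb2.MergeStrategy.TIES,
    density=0.5,
    majority_sign_method=generate_pb2.MajoritySignMethod.TOTAL,
)
request = generate_pb2.DownloadAdapterRequest(
    adapter_parameters=params,
    adapter_source=generate_pb2.AdapterSource.HUB,
)
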
pub async fn download_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, api_token: Option, - ) -> Result { + ) -> Result { if let Some(adapter_source_enum) = AdapterSource::from_str_name(adapter_source.to_uppercase().as_str()) { let request = tonic::Request::new(DownloadAdapterRequest { - adapter_id, + adapter_parameters: Some(adapter_parameters), adapter_source: adapter_source_enum.into(), api_token: api_token, }) .inject_context(); let response = self.stub.download_adapter(request).await?.into_inner(); - Ok(response.adapter_id) + Ok(response.downloaded) } else { let err_string = format!( "Invalid source '{}' when downloading adapter '{}'", - adapter_source, adapter_id + adapter_source, adapter_parameters.adapter_ids.join(" , ") ); tracing::error!(err_string); Err(ClientError::Generation(err_string).into()) @@ -210,27 +210,27 @@ impl Client { /// Physically loads the weights into the model for an adapter pub async fn load_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, adapter_index: u32, api_token: Option, - ) -> Result { + ) -> Result { if let Some(adapter_source_enum) = AdapterSource::from_str_name(adapter_source.to_uppercase().as_str()) { let request = tonic::Request::new(LoadAdapterRequest { - adapter_id, + adapter_parameters: Some(adapter_parameters), adapter_source: adapter_source_enum.into(), adapter_index, api_token: api_token, }) .inject_context(); let response = self.stub.load_adapter(request).await?.into_inner(); - Ok(response.adapter_id) + Ok(response.loaded) } else { let err_string = format!( "Invalid source '{}' when loading adapter '{}'", - adapter_source, adapter_id + adapter_source, adapter_parameters.adapter_ids.join(" , ") ); tracing::error!(err_string); Err(ClientError::Generation(err_string).into()) @@ -240,25 +240,25 @@ impl Client { /// Offloads adapter the weights from GPU to CPU or disk pub async fn offload_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, adapter_index: u32, - ) -> Result { + ) -> Result { if let Some(adapter_source_enum) = AdapterSource::from_str_name(adapter_source.to_uppercase().as_str()) { let request = tonic::Request::new(OffloadAdapterRequest { - adapter_id, + adapter_parameters: Some(adapter_parameters), adapter_source: adapter_source_enum.into(), adapter_index, }) .inject_context(); let response = self.stub.offload_adapter(request).await?.into_inner(); - Ok(response.adapter_id) + Ok(response.offloaded) } else { let err_string = format!( "Invalid source '{}' when loading adapter '{}'", - adapter_source, adapter_id + adapter_source, adapter_parameters.adapter_ids.join(" , ") ); tracing::error!(err_string); Err(ClientError::Generation(err_string).into()) diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 57e9b5e5f..fdb43e67c 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -9,8 +9,8 @@ pub use client::Client; pub use pb::generate::v1::HealthResponse; pub use pb::generate::v1::InfoResponse as ShardInfo; pub use pb::generate::v1::{ - Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, - PrefillTokens, Request, StoppingCriteriaParameters, + AdapterParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation, + NextTokenChooserParameters, PrefillTokens, Request, StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git 
a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 75bfd5c3f..f0068a1ab 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,3 +1,4 @@ +use crate::pb::generate::v1::AdapterParameters; /// Multi shard Client use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; use crate::{ClientError, Result}; @@ -149,30 +150,30 @@ impl ShardedClient { pub async fn download_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, api_token: Option, - ) -> Result { + ) -> Result { // Only download the adapter with one client, since they share a single disk self.clients[0] - .download_adapter(adapter_id, adapter_source, api_token) + .download_adapter(adapter_parameters, adapter_source, api_token) .await } pub async fn load_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, adapter_index: u32, api_token: Option, - ) -> Result { + ) -> Result { // Load the adapter in all clients since there is sharding done between them let futures: Vec<_> = self .clients .iter_mut() .map(|client| { Box::pin(client.load_adapter( - adapter_id.clone(), + adapter_parameters.clone(), adapter_source.clone(), adapter_index, api_token.clone(), @@ -183,7 +184,7 @@ impl ShardedClient { match join_all(futures) .await .into_iter() - .collect::>>() + .collect::>>() { Ok(mut results) => { // Return the first adapter id @@ -195,17 +196,17 @@ impl ShardedClient { pub async fn offload_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, adapter_index: u32, - ) -> Result { + ) -> Result { // Load the adapter in all clients since there is sharding done between them let futures: Vec<_> = self .clients .iter_mut() .map(|client| { Box::pin(client.offload_adapter( - adapter_id.clone(), + adapter_parameters.clone(), adapter_source.clone(), adapter_index, )) @@ -215,7 +216,7 @@ impl ShardedClient { match join_all(futures) .await .into_iter() - .collect::>>() + .collect::>>() { Ok(mut results) => { // Return the first adapter id diff --git a/router/src/adapter.rs b/router/src/adapter.rs index a7d81156f..2ae3286e5 100644 --- a/router/src/adapter.rs +++ b/router/src/adapter.rs @@ -1,4 +1,6 @@ -/// Adapter utils +use std::hash; + +use crate::AdapterParameters; /// "adapter ID" for the base model. The base model does not have an adapter ID, /// but we reason about it in the same way. 
This must match the base model ID @@ -9,10 +11,10 @@ pub const BASE_MODEL_ADAPTER_ID: &str = "__base_model__"; /// from within the proto definition, or lib.rs pub const DEFAULT_ADAPTER_SOURCE: &str = "hub"; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub(crate) struct Adapter { - /// name of adapter - id: String, + /// adapter parameters + params: AdapterParameters, /// source (enforced at proto level) source: String, /// index of the adapter @@ -22,17 +24,17 @@ pub(crate) struct Adapter { } impl Adapter { - pub(crate) fn new(id: String, source: String, index: u32, api_token: Option) -> Self { + pub(crate) fn new(params: AdapterParameters, source: String, index: u32, api_token: Option) -> Self { Self { - id, + params, source, index, api_token, } } - pub(crate) fn id(&self) -> &str { - &self.id + pub(crate) fn params(&self) -> &AdapterParameters { + &self.params } pub(crate) fn source(&self) -> &str { @@ -49,6 +51,20 @@ impl Adapter { pub(crate) fn as_string(&self) -> String { // format ":" - format!("{}:{}", self.source, self.id) + format!("{}:{}", self.source, self.params.adapter_ids.join(",")) + } +} + +impl hash::Hash for Adapter { + fn hash(&self, state: &mut H) { + self.index.hash(state); + } +} + +impl Eq for Adapter {} + +impl PartialEq for Adapter { + fn eq(&self, other: &Self) -> bool { + self.index == other.index } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 7aefd84e0..698212f56 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -65,6 +65,25 @@ pub struct Info { pub docker_label: Option<&'static str>, } +#[derive(Clone, Debug, Deserialize, ToSchema, Default)] +pub(crate) struct AdapterParameters { + #[serde(default)] + #[schema(inline, example = json ! (["arnavgrg/codealpaca-qlora"]))] + pub adapter_ids: Vec, + #[serde(default)] + #[schema(inline, example = json ! 
([0.25, 0.75]))] + pub weights: Vec, + #[serde(default)] + #[schema(nullable = true, default = "null", example = "linear")] + pub merge_strategy: Option, + #[serde(default)] + #[schema(nullable = false, default = 0.0, example = 0.5)] + pub density: f32, + #[serde(default)] + #[schema(nullable = true, default = "null", example = "total")] + pub majority_sign_method: Option, +} + #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct GenerateParameters { #[serde(default)] @@ -78,6 +97,9 @@ pub(crate) struct GenerateParameters { #[schema(nullable = true, default = "null", example = "hub")] pub adapter_source: Option, #[serde(default)] + #[schema(nullable = true, default = "null")] + pub adapter_parameters: Option, + #[serde(default)] #[schema( nullable = true, default = "null", @@ -169,6 +191,7 @@ fn default_parameters() -> GenerateParameters { GenerateParameters { adapter_id: None, adapter_source: None, + adapter_parameters: None, api_token: None, best_of: None, temperature: None, @@ -470,6 +493,7 @@ impl From for CompatGenerateRequest { parameters: GenerateParameters { adapter_id: req.model.parse().ok(), adapter_source: None, + adapter_parameters: None, api_token: None, best_of: req.best_of.map(|x| x as usize), temperature: req.temperature, @@ -503,6 +527,7 @@ impl From for CompatGenerateRequest { parameters: GenerateParameters { adapter_id: req.model.parse().ok(), adapter_source: None, + adapter_parameters: None, api_token: None, best_of: req.n.map(|x| x as usize), temperature: req.temperature, diff --git a/router/src/loader.rs b/router/src/loader.rs index 74c39b690..7aba7d3cf 100644 --- a/router/src/loader.rs +++ b/router/src/loader.rs @@ -140,7 +140,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver Date: Sat, 27 Jan 2024 13:38:06 -0800 Subject: [PATCH 06/23] Convert json to proto --- router/client/src/lib.rs | 3 ++- router/client/src/sharded_client.rs | 3 +-- router/src/lib.rs | 18 ++++++++++++++++++ router/src/loader.rs | 20 ++++++++++---------- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index fdb43e67c..03e414c8e 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -10,7 +10,8 @@ pub use pb::generate::v1::HealthResponse; pub use pb::generate::v1::InfoResponse as ShardInfo; pub use pb::generate::v1::{ AdapterParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation, - NextTokenChooserParameters, PrefillTokens, Request, StoppingCriteriaParameters, + MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, PrefillTokens, + Request, StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index f0068a1ab..c02967e2d 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,6 +1,5 @@ -use crate::pb::generate::v1::AdapterParameters; /// Multi shard Client -use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; +use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo, AdapterParameters}; use crate::{ClientError, Result}; use futures::future::join_all; use tonic::transport::Uri; diff --git a/router/src/lib.rs b/router/src/lib.rs index 698212f56..f07b9e627 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -7,6 +7,8 @@ mod queue; mod scheduler; pub mod server; mod validation; +use lorax_client::AdapterParameters as 
AdapterParametersMessage; +use lorax_client::{MajoritySignMethod, MergeStrategy}; use infer::Infer; use loader::AdapterLoader; @@ -84,6 +86,22 @@ pub(crate) struct AdapterParameters { pub majority_sign_method: Option, } +impl Into for AdapterParameters { + fn into(self) -> AdapterParametersMessage { + AdapterParametersMessage { + adapter_ids: self.adapter_ids, + weights: self.weights, + merge_strategy: MergeStrategy::from_str_name( + self.merge_strategy.unwrap_or("linear".to_string()).as_str(), + ).unwrap().into(), + density: self.density, + majority_sign_method: MajoritySignMethod::from_str_name( + self.majority_sign_method.unwrap_or("total".to_string()).as_str(), + ).unwrap().into(), + } + } +} + #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct GenerateParameters { #[serde(default)] diff --git a/router/src/loader.rs b/router/src/loader.rs index 7aba7d3cf..7f7ac0b1b 100644 --- a/router/src/loader.rs +++ b/router/src/loader.rs @@ -140,14 +140,14 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("adapter {} downloaded", adapter.id()); + tracing::info!("adapter {} downloaded", adapter.as_string()); let mut locked_state = queues_state.lock().unwrap(); if locked_state.has_adapter(&adapter) { // Above check guards against the case where the adapter was terminated between the initial @@ -157,7 +157,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("FAILED downloading adapter {}", adapter.id()); + tracing::info!("FAILED downloading adapter {}", adapter.as_string()); metrics::increment_counter!("lorax_request_failure", "err" => "download_adapter"); let mut locked_state = queues_state.lock().unwrap(); if locked_state.has_adapter(&adapter) { @@ -186,7 +186,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("adapter {} loaded", adapter.id()); + tracing::info!("adapter {} loaded", adapter.as_string()); queues_state .lock() .unwrap() @@ -203,7 +203,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("FAILED loading adapter {}", adapter.id()); + tracing::info!("FAILED loading adapter {}", adapter.as_string()); metrics::increment_counter!("lorax_request_failure", "err" => "load_adapter"); queues_state .lock() @@ -231,14 +231,14 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("adapter {} offloaded", adapter.id()); + tracing::info!("adapter {} offloaded", adapter.as_string()); queues_state .lock() .unwrap() @@ -247,7 +247,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("FAILED offloading adapter {}", adapter.id()); + tracing::info!("FAILED offloading adapter {}", adapter.as_string()); metrics::increment_counter!("lorax_request_failure", "err" => "offload_adapter"); queues_state .lock() @@ -273,7 +273,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("terminating adapter {} loader", adapter.id()); + tracing::info!("terminating adapter {} loader", adapter.as_string()); let mut locked_state = queues_state.lock().unwrap(); if !locked_state.has_adapter(&adapter) { From 45ecf88c8b2718ddfc7e912f7985ef13f8829414 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sat, 27 Jan 2024 13:50:41 -0800 Subject: [PATCH 07/23] Validation --- router/src/lib.rs | 4 ++-- router/src/validation.rs | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git 
a/router/src/lib.rs b/router/src/lib.rs index f07b9e627..9a997f7b6 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -69,7 +69,7 @@ pub struct Info { #[derive(Clone, Debug, Deserialize, ToSchema, Default)] pub(crate) struct AdapterParameters { - #[serde(default)] + #[serde(rename(deserialize = "ids"))] #[schema(inline, example = json ! (["arnavgrg/codealpaca-qlora"]))] pub adapter_ids: Vec, #[serde(default)] @@ -114,7 +114,7 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(nullable = true, default = "null", example = "hub")] pub adapter_source: Option, - #[serde(default)] + #[serde(rename(deserialize = "adapters"))] #[schema(nullable = true, default = "null")] pub adapter_parameters: Option, #[serde(default)] diff --git a/router/src/validation.rs b/router/src/validation.rs index d1fc8513d..2faaf8a7f 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -144,11 +144,30 @@ impl Validation { seed, watermark, adapter_id, + adapter_parameters, decoder_input_details, apply_chat_template, .. } = request.parameters; + // adapter validation + // cannot specify both adapter_id and adapter_parameters + if adapter_parameters.is_some() && adapter_id.is_some() { + return Err(ValidationError::AdapterIdConflict); + } + + if adapter_parameters.is_some() { + let nadapters = adapter_parameters.as_ref().unwrap().adapter_ids.len(); + let nweights = adapter_parameters.as_ref().unwrap().weights.len(); + if nadapters < 1 { + return Err(ValidationError::AdapterIdMissing); + } + + if nadapters != nweights { + return Err(ValidationError::AdapterWeightMismatch); + } + } + // sampling must be true when best_of > 1 let best_of = best_of.unwrap_or(1); let sampling = do_sample @@ -389,6 +408,12 @@ pub enum ValidationError { StopSequence(usize, usize), #[error("tokenizer error {0}")] Tokenizer(String), + #[error("at most one of `adapter_id` or `adapters` may be provided")] + AdapterIdConflict, + #[error("at least one adapter ID must be provided when setting `adapters`")] + AdapterIdMissing, + #[error("number of adapter IDs must match number of adapter weights")] + AdapterWeightMismatch, } #[cfg(test)] From d4f3dc83621da62d3c3ca5c136137789b82e3808 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 29 Jan 2024 22:49:28 -0800 Subject: [PATCH 08/23] Plumb through adapter parameters --- router/src/infer.rs | 23 ++++++++++++++++------- router/src/lib.rs | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index b8b4b656c..732d0c408 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -3,7 +3,7 @@ use crate::adapter::{Adapter, BASE_MODEL_ADAPTER_ID, DEFAULT_ADAPTER_SOURCE}; use crate::queue::AdapterEvent; use crate::scheduler::AdapterScheduler; use crate::validation::{Validation, ValidationError}; -use crate::{Entry, Token}; +use crate::{Entry, Token, AdapterParameters}; use crate::{GenerateRequest, PrefillToken}; use flume::r#async::RecvStream; use flume::SendTimeoutError; @@ -32,7 +32,7 @@ pub struct Infer { /// Manages the queues of the various adapters adapter_scheduler: AdapterScheduler, /// Maps adapter ID to a unique index - adapter_to_index: Arc>>, + adapter_to_index: Arc>>, /// Inference limit limit_concurrent_requests: Arc, } @@ -69,7 +69,10 @@ impl Infer { ); // Initialize with base model adapter (empty) mapping to index 0 - let adapter_to_index = Arc::new(Mutex::new(HashMap::from([("".to_string(), 0)]))); + let adapter_to_index = 
Arc::new(Mutex::new(HashMap::from([(AdapterParameters { + adapter_ids: vec!["".to_string()], + ..Default::default() + }, 0)]))); // Spawn batching background task that contains all the inference logic tokio::spawn(batching_task( @@ -126,21 +129,27 @@ impl Infer { adapter_source = Some(DEFAULT_ADAPTER_SOURCE.to_string()); } + let adapter_parameters = request.parameters.adapter_parameters.clone().unwrap_or(AdapterParameters { + adapter_ids: vec![adapter_id.clone().unwrap()], + ..Default::default() + }); + let adapter_idx; { // TODO(travis): can optimize concurrency here using RWLock let mut adapter_to_index = self.adapter_to_index.lock().await; - if adapter_to_index.contains_key(&adapter_id.clone().unwrap()) { - adapter_idx = *adapter_to_index.get(&adapter_id.clone().unwrap()).unwrap(); + let adapter_key = adapter_parameters.clone(); + if adapter_to_index.contains_key(&adapter_key) { + adapter_idx = *adapter_to_index.get(&adapter_key).unwrap(); } else { adapter_idx = adapter_to_index.len() as u32; - adapter_to_index.insert(adapter_id.clone().unwrap(), adapter_idx); + adapter_to_index.insert(adapter_key, adapter_idx); } } let api_token = request.parameters.api_token.clone(); let adapter = Adapter::new( - adapter_id.unwrap(), + adapter_parameters, adapter_source.unwrap(), adapter_idx, api_token, diff --git a/router/src/lib.rs b/router/src/lib.rs index 9a997f7b6..0d474968f 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -102,6 +102,46 @@ impl Into for AdapterParameters { } } +impl std::hash::Hash for AdapterParameters { + fn hash(&self, state: &mut H) { + if self.adapter_ids.len() == 1 { + self.adapter_ids[0].hash(state); + return + } + + self.adapter_ids.hash(state); + + // Convert weights vec into vec of u32 bits + let weights: Vec = self.weights.iter().map(|x| x.to_bits()).collect(); + weights.hash(state); + + self.merge_strategy.hash(state); + + // Hash the raw bits of the float, acknowledging that this + // can cause issues with different representations of the same value. 
+ self.density.to_bits().hash(state); + + self.majority_sign_method.hash(state); + } +} + +impl PartialEq for AdapterParameters { + fn eq(&self, other: &Self) -> bool { + if self.adapter_ids.len() == 1 { + return self.adapter_ids[0] == other.adapter_ids[0] + } + + // In this implementation, we assume that adapter order matters + self.adapter_ids == other.adapter_ids + && self.weights == other.weights + && self.merge_strategy == other.merge_strategy + && self.density == other.density // direct comparison of f32 + && self.majority_sign_method == other.majority_sign_method + } +} + +impl Eq for AdapterParameters {} + #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct GenerateParameters { #[serde(default)] From 2f2b45f6f29120328d9aa5e8d5c444f15003425c Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 29 Jan 2024 22:49:50 -0800 Subject: [PATCH 09/23] cargo fmt --- router/src/adapter.rs | 7 ++++++- router/src/infer.rs | 26 +++++++++++++++++--------- router/src/lib.rs | 18 ++++++++++++------ router/src/validation.rs | 10 +++++----- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/router/src/adapter.rs b/router/src/adapter.rs index 2ae3286e5..4e34fe687 100644 --- a/router/src/adapter.rs +++ b/router/src/adapter.rs @@ -24,7 +24,12 @@ pub(crate) struct Adapter { } impl Adapter { - pub(crate) fn new(params: AdapterParameters, source: String, index: u32, api_token: Option) -> Self { + pub(crate) fn new( + params: AdapterParameters, + source: String, + index: u32, + api_token: Option, + ) -> Self { Self { params, source, diff --git a/router/src/infer.rs b/router/src/infer.rs index 732d0c408..48a9d503f 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -3,7 +3,7 @@ use crate::adapter::{Adapter, BASE_MODEL_ADAPTER_ID, DEFAULT_ADAPTER_SOURCE}; use crate::queue::AdapterEvent; use crate::scheduler::AdapterScheduler; use crate::validation::{Validation, ValidationError}; -use crate::{Entry, Token, AdapterParameters}; +use crate::{AdapterParameters, Entry, Token}; use crate::{GenerateRequest, PrefillToken}; use flume::r#async::RecvStream; use flume::SendTimeoutError; @@ -69,10 +69,13 @@ impl Infer { ); // Initialize with base model adapter (empty) mapping to index 0 - let adapter_to_index = Arc::new(Mutex::new(HashMap::from([(AdapterParameters { - adapter_ids: vec!["".to_string()], - ..Default::default() - }, 0)]))); + let adapter_to_index = Arc::new(Mutex::new(HashMap::from([( + AdapterParameters { + adapter_ids: vec!["".to_string()], + ..Default::default() + }, + 0, + )]))); // Spawn batching background task that contains all the inference logic tokio::spawn(batching_task( @@ -129,10 +132,15 @@ impl Infer { adapter_source = Some(DEFAULT_ADAPTER_SOURCE.to_string()); } - let adapter_parameters = request.parameters.adapter_parameters.clone().unwrap_or(AdapterParameters { - adapter_ids: vec![adapter_id.clone().unwrap()], - ..Default::default() - }); + let adapter_parameters = + request + .parameters + .adapter_parameters + .clone() + .unwrap_or(AdapterParameters { + adapter_ids: vec![adapter_id.clone().unwrap()], + ..Default::default() + }); let adapter_idx; { diff --git a/router/src/lib.rs b/router/src/lib.rs index 0d474968f..0dcd795a5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -93,11 +93,17 @@ impl Into for AdapterParameters { weights: self.weights, merge_strategy: MergeStrategy::from_str_name( self.merge_strategy.unwrap_or("linear".to_string()).as_str(), - ).unwrap().into(), + ) + .unwrap() + .into(), density: self.density, majority_sign_method: 
MajoritySignMethod::from_str_name( - self.majority_sign_method.unwrap_or("total".to_string()).as_str(), - ).unwrap().into(), + self.majority_sign_method + .unwrap_or("total".to_string()) + .as_str(), + ) + .unwrap() + .into(), } } } @@ -106,7 +112,7 @@ impl std::hash::Hash for AdapterParameters { fn hash(&self, state: &mut H) { if self.adapter_ids.len() == 1 { self.adapter_ids[0].hash(state); - return + return; } self.adapter_ids.hash(state); @@ -128,9 +134,9 @@ impl std::hash::Hash for AdapterParameters { impl PartialEq for AdapterParameters { fn eq(&self, other: &Self) -> bool { if self.adapter_ids.len() == 1 { - return self.adapter_ids[0] == other.adapter_ids[0] + return self.adapter_ids[0] == other.adapter_ids[0]; } - + // In this implementation, we assume that adapter order matters self.adapter_ids == other.adapter_ids && self.weights == other.weights diff --git a/router/src/validation.rs b/router/src/validation.rs index 2faaf8a7f..f10949c39 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -419,8 +419,8 @@ pub enum ValidationError { #[cfg(test)] mod tests { use super::*; - use crate::{default_parameters, AdapterParameters}; use crate::tests::get_tokenizer; + use crate::{default_parameters, AdapterParameters}; #[tokio::test] async fn test_validation_max_new_tokens() { @@ -503,7 +503,7 @@ mod tests { }, }, Adapter::new( - AdapterParameters{ + AdapterParameters { adapter_ids: vec!["".to_string()], ..Default::default() }, @@ -545,7 +545,7 @@ mod tests { }, }, Adapter::new( - AdapterParameters{ + AdapterParameters { adapter_ids: vec!["".to_string()], ..Default::default() }, @@ -571,7 +571,7 @@ mod tests { }, }, Adapter::new( - AdapterParameters{ + AdapterParameters { adapter_ids: vec!["".to_string()], ..Default::default() }, @@ -597,7 +597,7 @@ mod tests { }, }, Adapter::new( - AdapterParameters{ + AdapterParameters { adapter_ids: vec!["".to_string()], ..Default::default() }, From 81962f02906233ba4f2cf9b21aea48e17a96aaab Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 29 Jan 2024 23:17:53 -0800 Subject: [PATCH 10/23] Fixed circular import --- server/lorax_server/utils/merges/strategies.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py index 28a824542..4be9ebdcf 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -1,13 +1,15 @@ from abc import ABC from collections import defaultdict -from typing import Dict, List, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Tuple, Type import torch from peft import LoraConfig from lorax_server.pb.generate_pb2 import AdapterParameters from lorax_server.utils.merges.utils import calculate_majority_sign_mask, disjoint_merge, prune -from lorax_server.utils.adapter import ModuleMap + +if TYPE_CHECKING: + from lorax_server.utils.adapter import ModuleMap def _apply_weights(tensors: List[torch.Tensor], w: torch.Tensor) -> torch.Tensor: @@ -93,9 +95,9 @@ def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: def merge_adapters( - adapters: List[Tuple[ModuleMap, LoraConfig]], + adapters: List[Tuple["ModuleMap", LoraConfig]], merge_params: AdapterParameters, -) -> Tuple[ModuleMap, LoraConfig]: +) -> Tuple["ModuleMap", LoraConfig]: strategy_name = merge_params.merge_strategy weights = merge_params.weights @@ -128,7 +130,7 @@ def merge_adapters( # merge tensors for each module such that we have a single ModuleMap: # dict[k] 
-> merged tensor - merged_module_map: ModuleMap = defaultdict(dict) + merged_module_map: "ModuleMap" = defaultdict(dict) for weight_name, data in module_maps.items(): for k, param_data in data.items(): for param_name, tensors in param_data.items(): From 32fc279e28b861f40c77a97907319f7b418155f5 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 30 Jan 2024 08:57:54 -0800 Subject: [PATCH 11/23] adapters -> merged_adapters --- router/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 0dcd795a5..83f00f36e 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -160,7 +160,7 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(nullable = true, default = "null", example = "hub")] pub adapter_source: Option, - #[serde(rename(deserialize = "adapters"))] + #[serde(rename(deserialize = "merged_adapters"))] #[schema(nullable = true, default = "null")] pub adapter_parameters: Option, #[serde(default)] From 80970c21f732da3671f91712d75867e9528c7af0 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 30 Jan 2024 09:43:19 -0800 Subject: [PATCH 12/23] To Uppercase --- router/src/lib.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 83f00f36e..19ab7a191 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -92,7 +92,10 @@ impl Into for AdapterParameters { adapter_ids: self.adapter_ids, weights: self.weights, merge_strategy: MergeStrategy::from_str_name( - self.merge_strategy.unwrap_or("linear".to_string()).as_str(), + self.merge_strategy + .unwrap_or("linear".to_string()) + .to_uppercase() + .as_str(), ) .unwrap() .into(), @@ -100,6 +103,7 @@ impl Into for AdapterParameters { majority_sign_method: MajoritySignMethod::from_str_name( self.majority_sign_method .unwrap_or("total".to_string()) + .to_uppercase() .as_str(), ) .unwrap() From 5d7c0872fff967ea2ec747a468a508881223f747 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 30 Jan 2024 09:58:13 -0800 Subject: [PATCH 13/23] Update python client --- clients/python/lorax/__init__.py | 4 +-- clients/python/lorax/client.py | 17 +++++++++++ clients/python/lorax/types.py | 51 ++++++++++++++++++++++++++++++++ clients/python/pyproject.toml | 2 +- 4 files changed, 71 insertions(+), 3 deletions(-) diff --git a/clients/python/lorax/__init__.py b/clients/python/lorax/__init__.py index e0c800858..b95c7ffbe 100644 --- a/clients/python/lorax/__init__.py +++ b/clients/python/lorax/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.2.1" +__version__ = "0.3.0" -from lorax.client import Client, AsyncClient +from lorax.client import Client, AsyncClient, MergedAdapters diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 190f004ba..89fb956a0 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -10,6 +10,7 @@ Response, Request, Parameters, + MergedAdapters, ) from lorax.errors import parse_error @@ -63,6 +64,7 @@ def generate( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, @@ -89,6 +91,8 @@ def generate( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + merged_adapters (`Optional[MergedAdapters]`): + Merged adapters to apply to the base model for the request api_token (`Optional[str]`): API token for accessing private adapters do_sample (`bool`): @@ -130,6 +134,7 @@ def generate( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + merged_adapters=merged_adapters, api_token=api_token, best_of=best_of, details=True, @@ -166,6 +171,7 @@ def generate_stream( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, @@ -190,6 +196,8 @@ def generate_stream( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + merged_adapters (`Optional[MergedAdapters]`): + Merged adapters to apply to the base model for the request api_token (`Optional[str]`): API token for accessing private adapters do_sample (`bool`): @@ -227,6 +235,7 @@ def generate_stream( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + merged_adapters=merged_adapters, api_token=api_token, best_of=None, details=True, @@ -329,6 +338,7 @@ async def generate( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, @@ -355,6 +365,8 @@ async def generate( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + merged_adapters (`Optional[MergedAdapters]`): + Merged adapters to apply to the base model for the request api_token (`Optional[str]`): API token for accessing private adapters do_sample (`bool`): @@ -396,6 +408,7 @@ async def generate( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + merged_adapters=merged_adapters, api_token=api_token, best_of=best_of, details=True, @@ -430,6 +443,7 @@ async def generate_stream( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, @@ -454,6 +468,8 @@ async def generate_stream( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + merged_adapters (`Optional[MergedAdapters]`): + Merged adapters to apply to the base model for the request api_token (`Optional[str]`): API token for accessing private adapters do_sample 
(`bool`): @@ -491,6 +507,7 @@ async def generate_stream( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + merged_adapters=merged_adapters, api_token=api_token, best_of=None, details=True, diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index fe880f5f5..ad6dc3ab1 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -6,6 +6,49 @@ ADAPTER_SOURCES = ["hub", "local", "s3", "pbase"] +MERGE_STRATEGIES = ["linear", "ties", "dare_linear", "dare_ties"] +MAJORITY_SIGN_METHODS = ["total", "frequency"] + + +class MergedAdapters(BaseModel): + # IDs of the adapters to merge + ids: List[str] + # Weights of the adapters to merge + weights: List[float] + # Merge strategy + merge_strategy: Optional[str] + # Density + density: float + # Majority sign method + majority_sign_method: Optional[str] + + @validator("ids", "weights") + def validate_ids_weights(cls, ids, weights): + if not ids: + raise ValidationError("`ids` cannot be empty") + if not weights: + raise ValidationError("`weights` cannot be empty") + if len(ids) != len(weights): + raise ValidationError("`ids` and `weights` must have the same length") + return ids, weights + + @validator("merge_strategy") + def validate_merge_strategy(cls, v): + if v is not None and v not in MERGE_STRATEGIES: + raise ValidationError(f"`merge_strategy` must be one of {MERGE_STRATEGIES}") + return v + + @validator("density") + def validate_density(cls, v): + if v < 0 or v > 1.0: + raise ValidationError("`density` must be >= 0.0 and <= 1.0") + return v + + @validator("majority_sign_method") + def validate_majority_sign_method(cls, v): + if v is not None and v not in MAJORITY_SIGN_METHODS: + raise ValidationError(f"`majority_sign_method` must be one of {MAJORITY_SIGN_METHODS}") + return v class Parameters(BaseModel): @@ -13,6 +56,8 @@ class Parameters(BaseModel): adapter_id: Optional[str] # The source of the adapter to use adapter_source: Optional[str] + # Adapter merge parameters + merged_adapters: Optional[MergedAdapters] # API token for accessing private adapters api_token: Optional[str] # Activate logits sampling @@ -49,6 +94,12 @@ class Parameters(BaseModel): # Get decoder input token logprobs and ids decoder_input_details: bool = False + @validator("adapter_id", "merged_adapters") + def valid_adapter_id_merged_adapters(cls, adapter_id, merged_adapters): + if adapter_id is not None and merged_adapters is not None: + raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`") + return adapter_id, merged_adapters + @validator("adapter_source") def valid_adapter_source(cls, v): if v is not None and v not in ADAPTER_SOURCES: diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index a4b284764..8a5dadb4c 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -3,7 +3,7 @@ name = "lorax-client" packages = [ {include = "lorax"} ] -version = "0.2.1" +version = "0.3.0" description = "LoRAX Python Client" license = "Apache-2.0" authors = ["Travis Addair ", "Olivier Dehaene "] From 8e9feb1c10322ce0f0ed1c0645fd5072f9c9630b Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 30 Jan 2024 10:08:06 -0800 Subject: [PATCH 14/23] Fixed validation --- clients/python/lorax/types.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index ad6dc3ab1..a34f0e612 100644 --- a/clients/python/lorax/types.py +++ 
b/clients/python/lorax/types.py @@ -22,15 +22,20 @@ class MergedAdapters(BaseModel): # Majority sign method majority_sign_method: Optional[str] - @validator("ids", "weights") - def validate_ids_weights(cls, ids, weights): - if not ids: + @validator("ids") + def validate_ids(cls, v): + if not v: raise ValidationError("`ids` cannot be empty") - if not weights: + return v + + @validator("weights") + def validate_weights(cls, v, values): + ids = values["ids"] + if not v: raise ValidationError("`weights` cannot be empty") - if len(ids) != len(weights): + if len(ids) != len(v): raise ValidationError("`ids` and `weights` must have the same length") - return ids, weights + return v @validator("merge_strategy") def validate_merge_strategy(cls, v): @@ -94,11 +99,12 @@ class Parameters(BaseModel): # Get decoder input token logprobs and ids decoder_input_details: bool = False - @validator("adapter_id", "merged_adapters") - def valid_adapter_id_merged_adapters(cls, adapter_id, merged_adapters): - if adapter_id is not None and merged_adapters is not None: + @validator("adapter_id") + def valid_adapter_id(cls, v, values): + merged_adapters = values.get("merged_adapters") + if v is not None and merged_adapters is not None: raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`") - return adapter_id, merged_adapters + return v @validator("adapter_source") def valid_adapter_source(cls, v): From ae03b76972e3c1037a2b18750511a6367590b11a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 30 Jan 2024 10:31:29 -0800 Subject: [PATCH 15/23] Fixed server --- server/lorax_server/server.py | 10 +++++++--- server/lorax_server/utils/adapter.py | 6 +++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 4d381e720..066120f3d 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -177,14 +177,18 @@ async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context): return generate_pb2.LoadAdapterResponse(loaded=False) try: - adapter_id = request.adapter_id adapter_source = _adapter_source_enum_to_string(request.adapter_source) adapter_index = request.adapter_index api_token = request.api_token + if adapter_source == PBASE: - adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) + for i in range(len(adapter_parameters.adapter_ids)): + adapter_id = adapter_parameters.adapter_ids[i] + adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) + adapter_parameters.adapter_ids[i] = adapter_id adapter_source = S3 - self.model.load_adapter(adapter_id, adapter_source, adapter_index, api_token) + + self.model.load_adapter(adapter_parameters, adapter_source, adapter_index, api_token) return generate_pb2.LoadAdapterResponse(loaded=True) except Exception: diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index 993fd691e..94af07185 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -70,7 +70,7 @@ def _load_and_merge( params = adapter_params.adapter_parameters adapters_to_merge = [] - weight_names = set() + merged_weight_names = set() tokenizer = None for adapter_id in params.adapter_ids: if adapter_id == BASE_MODEL_ADAPTER_ID: @@ -81,7 +81,7 @@ def _load_and_merge( ) adapters_to_merge.append((module_map, adapter_config)) - weight_names = weight_names.union(adapter_weight_names) + merged_weight_names = merged_weight_names.union(adapter_weight_names) if tokenizer is None: tokenizer = 
adapter_tokenizer @@ -89,7 +89,7 @@ def _load_and_merge( raise ValueError("No adapters to merge.") module_map, adapter_config = merge_adapters(adapters_to_merge, params) - return module_map, adapter_config, weight_names, tokenizer + return module_map, adapter_config, merged_weight_names, tokenizer @lru_cache(maxsize=128) From 390a0ace0375ffa6bfcdfa47c5a171eef5474c5e Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 30 Jan 2024 15:11:59 -0800 Subject: [PATCH 16/23] Fixed enum --- server/lorax_server/utils/merges/strategies.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py index 4be9ebdcf..babcd2f25 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -5,7 +5,11 @@ import torch from peft import LoraConfig -from lorax_server.pb.generate_pb2 import AdapterParameters +from lorax_server.pb.generate_pb2 import ( + AdapterParameters, + MajoritySignMethod as MajoritySignMethodEnum, + MergeStrategy as MergeStrategyEnum, +) from lorax_server.utils.merges.utils import calculate_majority_sign_mask, disjoint_merge, prune if TYPE_CHECKING: @@ -98,7 +102,7 @@ def merge_adapters( adapters: List[Tuple["ModuleMap", LoraConfig]], merge_params: AdapterParameters, ) -> Tuple["ModuleMap", LoraConfig]: - strategy_name = merge_params.merge_strategy + strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower() weights = merge_params.weights if not weights: @@ -108,7 +112,7 @@ def merge_adapters( merge_config = { "density": merge_params.density, - "majority_sign_method": merge_params.majority_sign_method, + "majority_sign_method": MajoritySignMethodEnum.Name(merge_params.majority_sign_method).lower(), } merge_strategy = strategy_registry[strategy_name](weights=weights, **merge_config) @@ -145,12 +149,12 @@ def merge_adapters( def _validate_lora_configs(lora_configs: List[LoraConfig]): # check that all configs have the same rank - ranks = set(lora_config.rank for lora_config in lora_configs) + ranks = set(lora_config.r for lora_config in lora_configs) if len(ranks) > 1: raise ValueError(f"unable to merge adapters, lora configs have different ranks: {ranks}") # check that all configs have the same target modules - target_modules = set(" | ".join(lora_config.target_modules for lora_config in lora_configs)) + target_modules = set([" | ".join(sorted(lora_config.target_modules)) for lora_config in lora_configs]) if len(target_modules) > 1: raise ValueError(f"unable to merge adapters, lora configs have different target modules: {target_modules}") From 55bae32ca5473bbfe5604719ce2800471ce0a707 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 30 Jan 2024 16:43:26 -0800 Subject: [PATCH 17/23] Select weights compatible with adapters --- .../lorax_server/utils/merges/strategies.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py index babcd2f25..114450997 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -28,29 +28,28 @@ def _apply_weights(tensors: List[torch.Tensor], w: torch.Tensor) -> torch.Tensor class MergeStrategy(ABC): - def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: raise NotImplementedError() 
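# Note: `weights` is passed to `merge()` per call instead of being stored on the
# strategy. A given target module may be present in only a subset of the adapters
# being merged, so `merge_adapters` below selects the weight entries that correspond
# to that module (via `weight_name_to_adapter_idx`) before invoking the strategy.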
class LinearMerge(MergeStrategy): - def __init__(self, weights: torch.Tensor, **kwargs): - self.weights = weights + def __init__(self, **kwargs): + pass - def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: - weighted_task_tensors = _apply_weights(task_tensors, self.weights) + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: + weighted_task_tensors = _apply_weights(task_tensors, weights) return weighted_task_tensors.sum(dim=0) class TiesMerge(MergeStrategy): - def __init__(self, weights: torch.Tensor, density: float, majority_sign_method: str = "total", **kwargs): - self.weights = weights + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): self.density = density self.majority_sign_method = majority_sign_method - def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: # sparsify task_tensors = [prune(tensor, self.density, method="magnitude") for tensor in task_tensors] - weighted_task_tensors = _apply_weights(task_tensors, self.weights) + weighted_task_tensors = _apply_weights(task_tensors, weights) # elect sign majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method) @@ -60,27 +59,25 @@ def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: class DareLinearMerge(MergeStrategy): - def __init__(self, weights: torch.Tensor, density: float, **kwargs): - self.weights = weights + def __init__(self, density: float, **kwargs): self.density = density - def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: # sparsify task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors] - weighted_task_tensors = _apply_weights(task_tensors, self.weights) + weighted_task_tensors = _apply_weights(task_tensors, weights) return weighted_task_tensors.sum(dim=0) class DareTiesMerge(MergeStrategy): - def __init__(self, weights: torch.Tensor, density: float, majority_sign_method: str = "total", **kwargs): - self.weights = weights + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): self.density = density self.majority_sign_method = majority_sign_method - def merge(self, task_tensors: List[torch.Tensor]) -> torch.Tensor: + def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: # sparsify task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors] - weighted_task_tensors = _apply_weights(task_tensors, self.weights) + weighted_task_tensors = _apply_weights(task_tensors, weights) # elect sign majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method) @@ -114,17 +111,19 @@ def merge_adapters( "density": merge_params.density, "majority_sign_method": MajoritySignMethodEnum.Name(merge_params.majority_sign_method).lower(), } - merge_strategy = strategy_registry[strategy_name](weights=weights, **merge_config) + merge_strategy = strategy_registry[strategy_name](**merge_config) module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict( lambda: defaultdict(lambda: defaultdict(list)) ) lora_configs = [] + weight_name_to_adapter_idx = defaultdict(list) # input is list of (module_map, lora_config) tuples # convert into dict[k][param_name] -> list of tensors - for 
module_map, lora_config in adapters: + for idx, (module_map, lora_config) in enumerate(adapters): for weight_name, data in module_map.items(): + weight_name_to_adapter_idx[weight_name].append(idx) for k, (param_data, param_name) in data.items(): module_maps[weight_name][k][param_name].append(param_data) lora_configs.append(lora_config) @@ -136,9 +135,11 @@ def merge_adapters( # dict[k] -> merged tensor merged_module_map: "ModuleMap" = defaultdict(dict) for weight_name, data in module_maps.items(): + indices = weight_name_to_adapter_idx[weight_name] + param_weights = weights[indices] for k, param_data in data.items(): for param_name, tensors in param_data.items(): - merged_tensor = merge_strategy.merge(tensors) + merged_tensor = merge_strategy.merge(tensors, param_weights) merged_module_map[weight_name][k] = (merged_tensor, param_name) # merge lora configs From 60f1ee32daa43eaa132031d28017b3dae9cb7ab1 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 30 Jan 2024 16:54:45 -0800 Subject: [PATCH 18/23] Fixed target module merging --- .../lorax_server/utils/merges/strategies.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py index 114450997..99dd5d81a 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -1,5 +1,6 @@ from abc import ABC from collections import defaultdict +import copy from typing import TYPE_CHECKING, Dict, List, Tuple, Type import torch @@ -153,14 +154,18 @@ def _validate_lora_configs(lora_configs: List[LoraConfig]): ranks = set(lora_config.r for lora_config in lora_configs) if len(ranks) > 1: raise ValueError(f"unable to merge adapters, lora configs have different ranks: {ranks}") - - # check that all configs have the same target modules - target_modules = set([" | ".join(sorted(lora_config.target_modules)) for lora_config in lora_configs]) - if len(target_modules) > 1: - raise ValueError(f"unable to merge adapters, lora configs have different target modules: {target_modules}") + + if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs): + raise ValueError("unable to merge adapters, lora configs have no target modules") def _merge_lora_configs(lora_configs: List[LoraConfig]) -> LoraConfig: - # for now, due to rank and target constraints, we can just return one config - # may revisit this in the future if we loosen these constraints - return lora_configs[0] + merged_lora_config = copy.copy(lora_configs[0]) + + # merge target modules as a union operation + merged_target_modules = sorted(set( + module for lora_config in lora_configs for module in lora_config.target_modules + )) + merged_lora_config.target_modules = merged_target_modules + + return merged_lora_config From 14960cccbd788682779671792fcaa3d50af0a5cf Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 31 Jan 2024 15:27:21 -0800 Subject: [PATCH 19/23] Added docs --- docs/guides/mixing_adapters.md | 82 +++++++++++++++++++++++ docs/models/adapters.md | 9 +++ mkdocs.yml | 1 + router/src/queue.rs | 3 + server/lorax_server/models/flash_llama.py | 3 +- 5 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 docs/guides/mixing_adapters.md diff --git a/docs/guides/mixing_adapters.md b/docs/guides/mixing_adapters.md new file mode 100644 index 000000000..7a48beaef --- /dev/null +++ b/docs/guides/mixing_adapters.md @@ -0,0 +1,82 @@ +# Mixing Adapters + +In LoRAX, multiple LoRA adapters can be 
mixed together per request to create powerful multi-task ensembles by merging the individual +adapter weights together using a given [merge strategy](#merge-strategies). + +This is particularly useful when you want your LLM to be capable of handling multiple types of tasks based on the user's prompt without +requiring them to specify the type of task they wish to perform. + +## Background: Model Merging + +Model merging is a set of techniques popularized by frameworks like [mergekit](https://github.com/cg123/mergekit) that allow taking +multiple specialized fine-tuned models and combining their weights together to output a single model that can perform each of these +tasks with a much smaller total footprint. + +A common use case could be to train specialized LoRA adapters for tasks like SQL generation, customer support email +generation, and information extraction. Without model merging, the user submitting their query will need to know in advance which +of these models to route their query to. With model merging, the user should be able to submit their query without prior knowledge +of which backing adapter is best suited to respond to the query. + +In some cases the mixing of adapter specializations could even result in a better final response. For example, by mixing an adapter that understands math with an adapter that can provide detailed and intuitive explanations, the user could in theory get correct answers to math questions with detailed step-by-step reasoning to aid in the user's learning. + +## Merge Strategies + +LoRAX provides a number of model merging methods taken from [mergekit](https://github.com/cg123/mergekit) and [PEFT](https://github.com/huggingface/peft). + +Options: + +- `linear` (default) +- `ties` +- `dare_linear` +- `dare_ties` + +### Linear + +The default and most straightforward way to merge model adapters is to linearly combine each of the parameters as a weighted average. This idea was +explored in the context of merging fine-tuned models in [Model Soups](https://arxiv.org/abs/2203.05482). + +Parameters: + +- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request. + +### TIES + +[TIES](https://arxiv.org/abs/2306.01708) is based on the idea of [Task Arithmetic](https://arxiv.org/abs/2212.04089), whereby the fine-tuned models +are merged after subtracting out the base model weights. LoRA and other adapters are already task-specific tensors, +so this approach is a natural fit when merging LoRAs. + +To resolve interference between adapters, the weights are sparsified and a sign-based consensus algorithm is used to determine the weighted average. + +One of the strengths of this approach is its ability to scale well to large numbers of adapters and retain each of their strengths. + +Parameters: + +- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request. +- `density` (required): fraction of weights in adapters to retain. +- `majority_sign_method` (default: `total`): one of `{total, frequency}` used to obtain the magnitude of the sign for consensus. + +### DARE (Linear) + +[DARE](https://arxiv.org/abs/2311.03099), like TIES, sparsifies adapter weights (task vectors) to reduce interference. Unlike TIES, however, +DARE uses random pruning and rescaling in an attempt to better match performance of the independent adapters. + +Parameters: + +- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request. +- `density` (required): fraction of weights in adapters to retain.
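To make the mechanics above concrete, the following is a minimal, self-contained sketch of how density-based sparsification and weighted averaging combine per-module adapter tensors. It is illustrative only (toy tensors and invented helper names, not the LoRAX server implementation), and it omits the sign-consensus step used by the TIES variants:

```python
from typing import List

import torch


def magnitude_prune(t: torch.Tensor, density: float) -> torch.Tensor:
    """Keep only the largest-magnitude fraction `density` of entries (TIES-style sparsification)."""
    k = int(density * t.numel())
    mask = torch.zeros_like(t).view(-1)
    mask[torch.topk(t.abs().view(-1), k=k).indices] = 1.0
    return t * mask.view_as(t)


def random_prune(t: torch.Tensor, density: float) -> torch.Tensor:
    """Randomly keep a fraction `density` of entries and rescale survivors (DARE-style sparsification)."""
    mask = torch.bernoulli(torch.full_like(t, density))
    return t * mask / density


def weighted_sum(task_tensors: List[torch.Tensor], weights: List[float]) -> torch.Tensor:
    """Linearly combine the per-adapter task tensors for a single target module."""
    stacked = torch.stack(task_tensors, dim=0)
    w = torch.tensor(weights).view(-1, *([1] * (stacked.dim() - 1)))
    return (stacked * w).sum(dim=0)


# Two toy per-module "task tensors", standing in for LoRA deltas from two adapters.
a, b = torch.randn(4, 4), torch.randn(4, 4)
weights = [0.7, 0.3]

linear = weighted_sum([a, b], weights)
dare_linear = weighted_sum([random_prune(t, 0.2) for t in (a, b)], weights)
# The ties / dare_ties strategies additionally elect a per-element majority sign and
# keep only the entries that agree with it before combining (not shown here).
ties_sparsified = weighted_sum([magnitude_prune(t, 0.2) for t in (a, b)], weights)
print(linear.shape, dare_linear.shape, ties_sparsified.shape)
```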
+ +### DARE (TIES) + +DARE method from above that also applies the sign consensus algorithm from TIES. + +Parameters: + +- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request. +- `density` (required): fraction of weights in adapters to retain. +- `majority_sign_method` (default: `total`): one of `{total, frequency}` used to obtain the magnitude of the sign for consensus. + +## Usage + +### Python Client + +### REST diff --git a/docs/models/adapters.md b/docs/models/adapters.md index 3597723b7..2f654ee76 100644 --- a/docs/models/adapters.md +++ b/docs/models/adapters.md @@ -141,6 +141,15 @@ Usage: } ``` +## Mixture of Adapters + +Multiple adapters can be mixed / merged together per request to create powerful ensembles of different specialized adapters. + +This is particularly useful when you want your LLM to be capable of handling multiple types of tasks based on the user's prompt without +requiring them to specify the type of task they wish to perform. + +See [Mixing Adapters](../guides/mixing_adapters.md) for details. + ## Private Adapter Repositories For hosted adapter repositories like HuggingFace Hub and [Predibase](https://predibase.com/), you can perform inference using private adapters per request. diff --git a/mkdocs.yml b/mkdocs.yml index 25bbb95f7..2ab1839b5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -46,6 +46,7 @@ nav: - OpenAI Compatible API: guides/openai_api.md - Quantization: guides/quantization.md - CUDA Graph Compilation: guides/cuda_graphs.md + - Mixing Adapters: guides/mixing_adapters.md # - GPUs: guides/gpus.md # - Fine-Tuning: guides/fine_tuning.md # - Quantization: guides/quantization.md diff --git a/router/src/queue.rs b/router/src/queue.rs index c4c6a5cad..4d1cf5d78 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -357,6 +357,9 @@ impl AdapterQueuesState { adapters_to_remove.insert(adapter.clone()); // Start async offload process + // TODO(travis): we're being too aggressive about offloading here, we should only + // add adapters to this set if the number of active adapters is full and there are new adapters + // waiting to be loaded offload_adapters.push(adapter.clone()); } } diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index 2cb572289..07632613b 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -24,7 +24,8 @@ tracer = trace.get_tracer(__name__) -ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD] +# TODO(travis): re-enable LM_HEAD after resolving issues with outputs +ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ] # LM_HEAD ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} From 0fdc32ff1fb7cc6e73c88c0635b7c6efda1d6d6a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 31 Jan 2024 20:56:16 -0800 Subject: [PATCH 20/23] fmt --- router/client/src/client.rs | 9 ++++++--- router/client/src/lib.rs | 4 ++-- router/client/src/sharded_client.rs | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 36608f01e..d1922d90e 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -200,7 +200,8 @@ impl Client { } else { let err_string = format!( "Invalid source '{}' when downloading adapter '{}'", - adapter_source, adapter_parameters.adapter_ids.join(" , ") + adapter_source, + adapter_parameters.adapter_ids.join(",") ); tracing::error!(err_string); 
Err(ClientError::Generation(err_string).into()) @@ -230,7 +231,8 @@ impl Client { } else { let err_string = format!( "Invalid source '{}' when loading adapter '{}'", - adapter_source, adapter_parameters.adapter_ids.join(" , ") + adapter_source, + adapter_parameters.adapter_ids.join(",") ); tracing::error!(err_string); Err(ClientError::Generation(err_string).into()) @@ -258,7 +260,8 @@ impl Client { } else { let err_string = format!( "Invalid source '{}' when loading adapter '{}'", - adapter_source, adapter_parameters.adapter_ids.join(" , ") + adapter_source, + adapter_parameters.adapter_ids.join(",") ); tracing::error!(err_string); Err(ClientError::Generation(err_string).into()) diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 03e414c8e..1c5ab6555 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -10,8 +10,8 @@ pub use pb::generate::v1::HealthResponse; pub use pb::generate::v1::InfoResponse as ShardInfo; pub use pb::generate::v1::{ AdapterParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation, - MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, PrefillTokens, - Request, StoppingCriteriaParameters, + MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, PrefillTokens, Request, + StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index c02967e2d..8a4bb9f63 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,5 +1,5 @@ /// Multi shard Client -use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo, AdapterParameters}; +use crate::{AdapterParameters, Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; use crate::{ClientError, Result}; use futures::future::join_all; use tonic::transport::Uri; From ace22275d43f9ee334c2f8018e64269dec2055af Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 31 Jan 2024 21:06:04 -0800 Subject: [PATCH 21/23] Example --- docs/guides/mixing_adapters.md | 55 ++++++++++++++++++++++++++++++---- router/client/src/lib.rs | 2 +- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/docs/guides/mixing_adapters.md b/docs/guides/mixing_adapters.md index 7a48beaef..bbe717b4d 100644 --- a/docs/guides/mixing_adapters.md +++ b/docs/guides/mixing_adapters.md @@ -75,8 +75,53 @@ Parameters: - `density` (required): fraction of weights in adapters to retain. - `majority_sign_method` (default: `total`): one of `{total, frequency}` used to obtain the magnitude of the sign for consensus. -## Usage - -### Python Client - -### REST +## Example + +This example is derived from the [PEFT example](https://github.com/huggingface/peft/blob/smangrul/add-new-merging-methods/examples/multi_adapter_examples/Lora_Merging.ipynb) for model merging. 
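The merge parameters in the example below are specified with the client's `MergedAdapters` type, which validates them before any request is sent. A small sketch of that validation behavior is shown here; the adapter IDs are placeholders, and the exact exception type surfaced depends on the pydantic version bundled with the client:

```python
from lorax import MergedAdapters

# Valid: one weight per adapter ID, density within [0.0, 1.0].
ok = MergedAdapters(
    ids=["smangrul/tinyllama_lora_norobots", "smangrul/tinyllama_lora_sql"],
    weights=[2.0, 0.3],
    merge_strategy="ties",
    density=0.2,
    majority_sign_method="total",
)

# Invalid: the number of weights does not match the number of IDs,
# so the request is rejected client-side before anything is sent.
try:
    MergedAdapters(ids=["adapter-a", "adapter-b"], weights=[1.0], density=0.2)
except Exception as err:  # exact exception type depends on the client/pydantic version
    print(f"rejected: {err}")
```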
+ +First deploy LoRAX using the base model `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`, then run the following using +the [LoRAX Python Client](../reference/python_client.md): + +```python +from lorax import Client + +client = Client(endpoint_url) + +# tinyllama merge +merged_adapters = MergedAdapters( + ids=[ + "smangrul/tinyllama_lora_norobots", + "smangrul/tinyllama_lora_sql", + "smangrul/tinyllama_lora_adcopy", + ], + weights=[2.0, 0.3, 0.7], + merge_strategy="ties", + density=0.2, + majority_sign_method="total", +) + +# norobots +prompt = """<|im_start|>user +Write an essay about Generative AI.<|im_end|> +<|im_start|>assistant \n""" +response = client.generate(prompt, merged_adapters=merged_adapters) +print(response.generated_text) + +# adcopy +prompt = """<|im_start|>system +Create a text ad given the following product and description.<|im_end|> +<|im_start|>user +Product: Sony PS5 PlayStation Console +Description: The PS5™ console unleashes new gaming possibilities that you never anticipated.<|im_end|> +<|im_start|>assistant \n""" +response = client.generate(prompt, merged_adapters=merged_adapters) +print(response.generated_text) + +# sql +prompt = """ Table: 2-11365528-2 +Columns: ['Team', 'Head Coach', 'President', 'Home Ground', 'Location'] +Natural Query: Who is the Head Coach of the team whose President is Mario Volarevic? +SQL Query:""" +response = client.generate(prompt, merged_adapters=merged_adapters) +print(response.generated_text) +``` diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 1c5ab6555..3aa4e10ae 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -9,7 +9,7 @@ pub use client::Client; pub use pb::generate::v1::HealthResponse; pub use pb::generate::v1::InfoResponse as ShardInfo; pub use pb::generate::v1::{ - AdapterParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation, + AdapterParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation, MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, PrefillTokens, Request, StoppingCriteriaParameters, }; From 40213f7f5e7b91c0e96659e09b043e83e22386b9 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 31 Jan 2024 21:16:07 -0800 Subject: [PATCH 22/23] OpenAPI --- ...mixing_adapters.md => merging_adapters.md} | 6 +- docs/models/adapters.md | 4 +- docs/reference/openapi.json | 58 +++++++++++++++++++ mkdocs.yml | 2 +- 4 files changed, 64 insertions(+), 6 deletions(-) rename docs/guides/{mixing_adapters.md => merging_adapters.md} (95%) diff --git a/docs/guides/mixing_adapters.md b/docs/guides/merging_adapters.md similarity index 95% rename from docs/guides/mixing_adapters.md rename to docs/guides/merging_adapters.md index bbe717b4d..f95207b61 100644 --- a/docs/guides/mixing_adapters.md +++ b/docs/guides/merging_adapters.md @@ -1,7 +1,7 @@ -# Mixing Adapters +# Merging Adapters -In LoRAX, multiple LoRA adapters can be mixed together per request to create powerful multi-task ensembles by merging the individual -adapter weights together using a given [merge strategy](#merge-strategies). +In LoRAX, multiple LoRA adapters can be merged together per request to create powerful multi-task ensembles +using one of several different [merge strategies](#merge-strategies). This is particularly useful when you want your LLM to be capable of handling multiple types of tasks based on the user's prompt without requiring them to specify the type of task they wish to perform. 
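For reference, the same merge parameters can also be passed directly over the HTTP API via the `merged_adapters` field. The sketch below assumes a local LoRAX deployment listening on `http://127.0.0.1:8080` and its standard `/generate` route; the prompt and adapter IDs are only examples:

```python
import requests

url = "http://127.0.0.1:8080/generate"
payload = {
    "inputs": "Write an essay about Generative AI.",
    "parameters": {
        "max_new_tokens": 64,
        "merged_adapters": {
            "ids": [
                "smangrul/tinyllama_lora_norobots",
                "smangrul/tinyllama_lora_adcopy",
            ],
            "weights": [2.0, 0.7],
            "merge_strategy": "ties",
            "density": 0.2,
            "majority_sign_method": "total",
        },
    },
}

resp = requests.post(url, json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())
```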
diff --git a/docs/models/adapters.md b/docs/models/adapters.md index 2f654ee76..0a123aa86 100644 --- a/docs/models/adapters.md +++ b/docs/models/adapters.md @@ -141,14 +141,14 @@ Usage: } ``` -## Mixture of Adapters +## Merging Adapters Multiple adapters can be mixed / merged together per request to create powerful ensembles of different specialized adapters. This is particularly useful when you want your LLM to be capable of handling multiple types of tasks based on the user's prompt without requiring them to specify the type of task they wish to perform. -See [Mixing Adapters](../guides/mixing_adapters.md) for details. +See [Merging Adapters](../guides/merging_adapters.md) for details. ## Private Adapter Repositories diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json index b0ec58adb..3417b409b 100644 --- a/docs/reference/openapi.json +++ b/docs/reference/openapi.json @@ -663,6 +663,61 @@ "stop_sequence" ] }, + "AdapterParameters": { + "type": "object", + "properties": { + "ids": { + "type": "array", + "items": { + "type": "string" + }, + "example": [ + "adapter1", + "adapter2" + ] + }, + "weights": { + "type": "array", + "items": { + "type": "number", + "format": "float" + }, + "example": [ + 0.5, + 0.5 + ] + }, + "merge_strategy": { + "type": "string", + "enum": [ + "linear", + "ties", + "dare_linear", + "dare_ties" + ], + "default": "linear", + "example": "ties" + }, + "density": { + "type": "number", + "format": "float", + "default": 0.0, + "example": 0.5, + "nullable": false, + "minimum": 0.0, + "maximim": 1.0 + }, + "majority_sign_method": { + "type": "string", + "enum": [ + "total", + "frequency" + ], + "default": "total", + "example": "total" + } + } + }, "GenerateParameters": { "type": "object", "properties": { @@ -782,6 +837,9 @@ "type": "string", "nullable": true }, + "merged_adapters": { + "$ref": "#/components/schemas/AdapterParameters" + }, "api_token": { "type": "string", "nullable": true diff --git a/mkdocs.yml b/mkdocs.yml index 2ab1839b5..e04dd3034 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -46,7 +46,7 @@ nav: - OpenAI Compatible API: guides/openai_api.md - Quantization: guides/quantization.md - CUDA Graph Compilation: guides/cuda_graphs.md - - Mixing Adapters: guides/mixing_adapters.md + - Merging Adapters: guides/merging_adapters.md # - GPUs: guides/gpus.md # - Fine-Tuning: guides/fine_tuning.md # - Quantization: guides/quantization.md From 5b3202d5a94061c6c3605e8791f289b1cd407da1 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 31 Jan 2024 21:18:42 -0800 Subject: [PATCH 23/23] Fixed import --- docs/guides/merging_adapters.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/merging_adapters.md b/docs/guides/merging_adapters.md index f95207b61..0f41c4c75 100644 --- a/docs/guides/merging_adapters.md +++ b/docs/guides/merging_adapters.md @@ -83,7 +83,7 @@ First deploy LoRAX using the base model `TinyLlama/TinyLlama-1.1B-intermediate-s the [LoRAX Python Client](../reference/python_client.md): ```python -from lorax import Client +from lorax import Client, MergedAdapters client = Client(endpoint_url)