diff --git a/clients/python/lorax/__init__.py b/clients/python/lorax/__init__.py index e0c800858..b95c7ffbe 100644 --- a/clients/python/lorax/__init__.py +++ b/clients/python/lorax/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.1" +__version__ = "0.3.0" -from lorax.client import Client, AsyncClient +from lorax.client import Client, AsyncClient, MergedAdapters diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 190f004ba..89fb956a0 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -10,6 +10,7 @@ Response, Request, Parameters, + MergedAdapters, ) from lorax.errors import parse_error @@ -63,6 +64,7 @@ def generate( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, @@ -89,6 +91,8 @@ def generate( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + merged_adapters (`Optional[MergedAdapters]`): + Merged adapters to apply to the base model for the request api_token (`Optional[str]`): API token for accessing private adapters do_sample (`bool`): @@ -130,6 +134,7 @@ def generate( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + merged_adapters=merged_adapters, api_token=api_token, best_of=best_of, details=True, @@ -166,6 +171,7 @@ def generate_stream( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, @@ -190,6 +196,8 @@ def generate_stream( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + merged_adapters (`Optional[MergedAdapters]`): + Merged adapters to apply to the base model for the request api_token (`Optional[str]`): API token for accessing private adapters do_sample (`bool`): @@ -227,6 +235,7 @@ def generate_stream( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + merged_adapters=merged_adapters, api_token=api_token, best_of=None, details=True, @@ -329,6 +338,7 @@ async def generate( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, @@ -355,6 +365,8 @@ async def generate( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + merged_adapters (`Optional[MergedAdapters]`): + Merged adapters to apply to the base model for the request api_token (`Optional[str]`): API token for accessing private adapters do_sample (`bool`): @@ -396,6 +408,7 @@ async def generate( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + merged_adapters=merged_adapters, api_token=api_token, best_of=best_of, details=True, @@ -430,6 +443,7 @@ async def generate_stream( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, @@ -454,6 +468,8 @@ async def generate_stream( 
Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + merged_adapters (`Optional[MergedAdapters]`): + Merged adapters to apply to the base model for the request api_token (`Optional[str]`): API token for accessing private adapters do_sample (`bool`): @@ -491,6 +507,7 @@ async def generate_stream( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + merged_adapters=merged_adapters, api_token=api_token, best_of=None, details=True, diff --git a/clients/python/lorax/types.py index fe880f5f5..a34f0e612 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -6,6 +6,54 @@ ADAPTER_SOURCES = ["hub", "local", "s3", "pbase"] +MERGE_STRATEGIES = ["linear", "ties", "dare_linear", "dare_ties"] +MAJORITY_SIGN_METHODS = ["total", "frequency"] + + +class MergedAdapters(BaseModel): + # IDs of the adapters to merge + ids: List[str] + # Weights of the adapters to merge + weights: List[float] + # Merge strategy + merge_strategy: Optional[str] + # Density + density: float + # Majority sign method + majority_sign_method: Optional[str] + + @validator("ids") + def validate_ids(cls, v): + if not v: + raise ValueError("`ids` cannot be empty") + return v + + @validator("weights") + def validate_weights(cls, v, values): + ids = values.get("ids") + if not v: + raise ValueError("`weights` cannot be empty") + if ids is not None and len(ids) != len(v): + raise ValueError("`ids` and `weights` must have the same length") + return v + + @validator("merge_strategy") + def validate_merge_strategy(cls, v): + if v is not None and v not in MERGE_STRATEGIES: + raise ValueError(f"`merge_strategy` must be one of {MERGE_STRATEGIES}") + return v + + @validator("density") + def validate_density(cls, v): + if v < 0 or v > 1.0: + raise ValueError("`density` must be >= 0.0 and <= 1.0") + return v + + @validator("majority_sign_method") + def validate_majority_sign_method(cls, v): + if v is not None and v not in MAJORITY_SIGN_METHODS: + raise ValueError(f"`majority_sign_method` must be one of {MAJORITY_SIGN_METHODS}") + return v class Parameters(BaseModel): @@ -13,6 +61,8 @@ class Parameters(BaseModel): adapter_id: Optional[str] # The source of the adapter to use adapter_source: Optional[str] + # Adapter merge parameters + merged_adapters: Optional[MergedAdapters] # API token for accessing private adapters api_token: Optional[str] # Activate logits sampling @@ -49,6 +99,13 @@ class Parameters(BaseModel): # Get decoder input token logprobs and ids decoder_input_details: bool = False + # validate on `merged_adapters` so that `adapter_id` (declared earlier) is available in `values` + @validator("merged_adapters") + def valid_merged_adapters(cls, v, values): + adapter_id = values.get("adapter_id") + if v is not None and adapter_id is not None: + raise ValueError("you must specify at most one of `adapter_id` or `merged_adapters`") + return v + @validator("adapter_source") def valid_adapter_source(cls, v): if v is not None and v not in ADAPTER_SOURCES: diff --git a/clients/python/pyproject.toml index a4b284764..8a5dadb4c 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -3,7 +3,7 @@ name = "lorax-client" packages = [ {include = "lorax"} ] -version = "0.2.1" +version = "0.3.0" description = "LoRAX Python Client" license = "Apache-2.0" authors = ["Travis Addair ", "Olivier Dehaene "] diff --git a/docs/guides/merging_adapters.md new file mode 100644 index 000000000..0f41c4c75
--- /dev/null +++ b/docs/guides/merging_adapters.md @@ -0,0 +1,127 @@ +# Merging Adapters + +In LoRAX, multiple LoRA adapters can be merged together per request to create powerful multi-task ensembles +using one of several different [merge strategies](#merge-strategies). + +This is particularly useful when you want your LLM to be capable of handling multiple types of tasks based on the user's prompt without +requiring the user to specify the type of task they wish to perform. + +## Background: Model Merging + +Model merging is a set of techniques popularized by frameworks like [mergekit](https://github.com/cg123/mergekit) that allow taking +multiple specialized fine-tuned models and combining their weights together to output a single model that can perform each of these +tasks with a much smaller total footprint. + +A common use case could be to train specialized LoRA adapters for tasks like SQL generation, customer support email +generation, and information extraction. Without model merging, the user submitting their query will need to know in advance which +of these models to route their query to. With model merging, the user can submit their query without prior knowledge +of which backing adapter is best suited to respond to it. + +In some cases the mixing of adapter specializations could even result in a better final response. For example, by mixing an adapter that understands math with an adapter that can provide detailed and intuitive explanations, the user could in theory get correct answers to math questions with detailed step-by-step reasoning to aid in the user's learning. + +## Merge Strategies + +LoRAX provides a number of model merging methods taken from [mergekit](https://github.com/cg123/mergekit) and [PEFT](https://github.com/huggingface/peft). + +Options: + +- `linear` (default) +- `ties` +- `dare_linear` +- `dare_ties` + +### Linear + +The default and most straightforward way to merge model adapters is to linearly combine each of the parameters as a weighted average. This idea was +explored in the context of merging fine-tuned models in [Model Soups](https://arxiv.org/abs/2203.05482). + +Parameters: + +- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request. + +### TIES + +[TIES](https://arxiv.org/abs/2306.01708) is based on the idea of [Task Arithmetic](https://arxiv.org/abs/2212.04089), whereby the fine-tuned models +are merged after subtracting out the base model weights. LoRA and other adapters are already task-specific tensors, +so this approach is a natural fit when merging LoRAs. + +To resolve interference between adapters, the weights are sparsified and a sign-based consensus algorithm is used to determine the weighted average. + +One of the strengths of this approach is its ability to scale well to large numbers of adapters and retain each of their strengths. + +Parameters: + +- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request. +- `density` (required): fraction of weights in adapters to retain. +- `majority_sign_method` (default: `total`): one of `{total, frequency}` used to obtain the magnitude of the sign for consensus. + +### DARE (Linear) + +[DARE](https://arxiv.org/abs/2311.03099), like TIES, sparsifies adapter weights (task vectors) to reduce interference. Unlike TIES, however, +DARE uses random pruning and rescaling in an attempt to better match performance of the independent adapters.
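+
+For intuition, the prune-and-rescale step at the heart of DARE can be sketched in a few lines of PyTorch. The following
+is a simplified illustration in the spirit of the server's `random_pruning` helper, not the exact implementation:
+
+```python
+import torch
+
+def dare_prune(tensor: torch.Tensor, density: float) -> torch.Tensor:
+    # randomly retain each weight with probability `density`
+    mask = torch.bernoulli(torch.full_like(tensor, density))
+    # rescale the survivors by 1 / density so the expected value
+    # of the original tensor is preserved
+    return (tensor * mask) / density
+```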
+ +Parameters: + +- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request. +- `density` (required): fraction of weights in adapters to retain. + +### DARE (TIES) + +The DARE method from above, with the sign consensus algorithm from TIES applied on top. + +Parameters: + +- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request. +- `density` (required): fraction of weights in adapters to retain. +- `majority_sign_method` (default: `total`): one of `{total, frequency}` used to obtain the magnitude of the sign for consensus. + +## Example + +This example is derived from the [PEFT example](https://github.com/huggingface/peft/blob/smangrul/add-new-merging-methods/examples/multi_adapter_examples/Lora_Merging.ipynb) for model merging. + +First deploy LoRAX using the base model `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`, then run the following using +the [LoRAX Python Client](../reference/python_client.md): + +```python +from lorax import Client, MergedAdapters + +client = Client(endpoint_url) + +# tinyllama merge +merged_adapters = MergedAdapters( + ids=[ + "smangrul/tinyllama_lora_norobots", + "smangrul/tinyllama_lora_sql", + "smangrul/tinyllama_lora_adcopy", + ], + weights=[2.0, 0.3, 0.7], + merge_strategy="ties", + density=0.2, + majority_sign_method="total", +) + +# norobots +prompt = """<|im_start|>user +Write an essay about Generative AI.<|im_end|> +<|im_start|>assistant \n""" +response = client.generate(prompt, merged_adapters=merged_adapters) +print(response.generated_text) + +# adcopy +prompt = """<|im_start|>system +Create a text ad given the following product and description.<|im_end|> +<|im_start|>user +Product: Sony PS5 PlayStation Console +Description: The PS5™ console unleashes new gaming possibilities that you never anticipated.<|im_end|> +<|im_start|>assistant \n""" +response = client.generate(prompt, merged_adapters=merged_adapters) +print(response.generated_text) + +# sql +prompt = """ Table: 2-11365528-2 +Columns: ['Team', 'Head Coach', 'President', 'Home Ground', 'Location'] +Natural Query: Who is the Head Coach of the team whose President is Mario Volarevic? +SQL Query:""" +response = client.generate(prompt, merged_adapters=merged_adapters) +print(response.generated_text) +``` diff --git a/docs/models/adapters.md index 3597723b7..0a123aa86 100644 --- a/docs/models/adapters.md +++ b/docs/models/adapters.md @@ -141,6 +141,15 @@ Usage: } ``` +## Merging Adapters + +Multiple adapters can be mixed / merged together per request to create powerful ensembles of different specialized adapters. + +This is particularly useful when you want your LLM to be capable of handling multiple types of tasks based on the user's prompt without +requiring them to specify the type of task they wish to perform. + +See [Merging Adapters](../guides/merging_adapters.md) for details. + ## Private Adapter Repositories For hosted adapter repositories like HuggingFace Hub and [Predibase](https://predibase.com/), you can perform inference using private adapters per request.
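+
+Merged adapters can be combined with private adapter repositories by passing `api_token` alongside `merged_adapters`.
+A sketch using the Python client (the endpoint, adapter IDs, and token below are placeholders):
+
+```python
+from lorax import Client, MergedAdapters
+
+client = Client("http://127.0.0.1:8080")
+
+# api_token authenticates against the repositories hosting each private adapter
+response = client.generate(
+    "What is the capital of France?",
+    merged_adapters=MergedAdapters(
+        ids=["my-org/private-adapter-a", "my-org/private-adapter-b"],
+        weights=[0.5, 0.5],
+        merge_strategy="linear",
+        density=0.0,  # unused by the linear strategy, but required by the client model
+    ),
+    api_token="<api-token>",
+)
+print(response.generated_text)
+```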
diff --git a/docs/reference/openapi.json index b0ec58adb..3417b409b 100644 --- a/docs/reference/openapi.json +++ b/docs/reference/openapi.json @@ -663,6 +663,61 @@ "stop_sequence" ] }, + "AdapterParameters": { + "type": "object", + "properties": { + "ids": { + "type": "array", + "items": { + "type": "string" + }, + "example": [ + "adapter1", + "adapter2" + ] + }, + "weights": { + "type": "array", + "items": { + "type": "number", + "format": "float" + }, + "example": [ + 0.5, + 0.5 + ] + }, + "merge_strategy": { + "type": "string", + "enum": [ + "linear", + "ties", + "dare_linear", + "dare_ties" + ], + "default": "linear", + "example": "ties" + }, + "density": { + "type": "number", + "format": "float", + "default": 0.0, + "example": 0.5, + "nullable": false, + "minimum": 0.0, + "maximum": 1.0 + }, + "majority_sign_method": { + "type": "string", + "enum": [ + "total", + "frequency" + ], + "default": "total", + "example": "total" + } + } + }, "GenerateParameters": { "type": "object", "properties": { @@ -782,6 +837,9 @@ "type": "string", "nullable": true }, + "merged_adapters": { + "$ref": "#/components/schemas/AdapterParameters" + }, "api_token": { "type": "string", "nullable": true diff --git a/mkdocs.yml index 25bbb95f7..e04dd3034 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -46,6 +46,7 @@ nav: - OpenAI Compatible API: guides/openai_api.md - Quantization: guides/quantization.md - CUDA Graph Compilation: guides/cuda_graphs.md + - Merging Adapters: guides/merging_adapters.md # - GPUs: guides/gpus.md # - Fine-Tuning: guides/fine_tuning.md # - Quantization: guides/quantization.md diff --git a/proto/generate.proto index e7afe4f26..1743b0b3a 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -232,9 +232,48 @@ enum AdapterSource { PBASE = 3; } +enum MergeStrategy { + /// Linear combination of adapters + LINEAR = 0; + + /// TIES method for combining adapters + TIES = 1; + + /// DARE method for combining adapters + DARE_LINEAR = 2; + + /// DARE + TIES method for combining adapters + DARE_TIES = 3; +} + +enum MajoritySignMethod { + /// Total method + TOTAL = 0; + + /// Frequency method + FREQUENCY = 1; +} + +message AdapterParameters { + /// Adapter IDs + repeated string adapter_ids = 1; + + /// Adapter weights for merging + repeated float weights = 2; + + /// Merge strategy (default: linear) + MergeStrategy merge_strategy = 3; + + /// [0, 1], 0: full pruning, 1: no pruning + float density = 4; + + /// Majority sign method (default: total) + MajoritySignMethod majority_sign_method = 5; +} + message DownloadAdapterRequest { - /// Adapter ID - string adapter_id = 1; + /// Adapter Parameters + AdapterParameters adapter_parameters = 1; /// Adapter source AdapterSource adapter_source = 2; /// Token for external API (predibase / HuggingFace) @@ -242,15 +281,13 @@ } message DownloadAdapterResponse { - /// Adapter ID - string adapter_id = 1; - /// Adapter source - AdapterSource adapter_source = 2; + /// True if download occurred, false if skipped + bool downloaded = 1; } message LoadAdapterRequest { - /// Adapter ID - string adapter_id = 1; + /// Adapter Parameters + AdapterParameters adapter_parameters = 1; /// Adapter source AdapterSource adapter_source = 2; /// Adapter index @@ -260,17 +297,13 @@ } message LoadAdapterResponse { - /// Adapter ID - string adapter_id = 1; - /// Adapter source - AdapterSource adapter_source = 2; - /// Adapter index - uint32 adapter_index
= 3; + /// True if load occurred, false if skipped + bool loaded = 1; } message OffloadAdapterRequest { - /// Adapter ID - string adapter_id = 1; + /// Adapter Parameters + AdapterParameters adapter_parameters = 1; /// Adapter source AdapterSource adapter_source = 2; /// Adapter index @@ -278,10 +311,6 @@ message OffloadAdapterResponse { - /// Adapter ID - string adapter_id = 1; - /// Adapter source - AdapterSource adapter_source = 2; - /// Adapter index - uint32 adapter_index = 3; + /// True if offload occurred, false if skipped + bool offloaded = 1; } diff --git a/router/client/src/client.rs index 68b000f09..d1922d90e 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -182,25 +182,26 @@ impl Client { /// Downloads the weights for an adapter. pub async fn download_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, api_token: Option<String>, - ) -> Result<String> { + ) -> Result<bool> { if let Some(adapter_source_enum) = AdapterSource::from_str_name(adapter_source.to_uppercase().as_str()) { let request = tonic::Request::new(DownloadAdapterRequest { - adapter_id, + adapter_parameters: Some(adapter_parameters), adapter_source: adapter_source_enum.into(), api_token: api_token, }) .inject_context(); let response = self.stub.download_adapter(request).await?.into_inner(); - Ok(response.adapter_id) + Ok(response.downloaded) } else { let err_string = format!( "Invalid source '{}' when downloading adapter '{}'", - adapter_source, adapter_id + adapter_source, + adapter_parameters.adapter_ids.join(",") ); tracing::error!(err_string); Err(ClientError::Generation(err_string).into()) @@ -210,27 +211,28 @@ impl Client { /// Physically loads the weights into the model for an adapter pub async fn load_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, adapter_index: u32, api_token: Option<String>, - ) -> Result<String> { + ) -> Result<bool> { if let Some(adapter_source_enum) = AdapterSource::from_str_name(adapter_source.to_uppercase().as_str()) { let request = tonic::Request::new(LoadAdapterRequest { - adapter_id, + adapter_parameters: Some(adapter_parameters), adapter_source: adapter_source_enum.into(), adapter_index, api_token: api_token, }) .inject_context(); let response = self.stub.load_adapter(request).await?.into_inner(); - Ok(response.adapter_id) + Ok(response.loaded) } else { let err_string = format!( "Invalid source '{}' when loading adapter '{}'", - adapter_source, adapter_id + adapter_source, + adapter_parameters.adapter_ids.join(",") ); tracing::error!(err_string); Err(ClientError::Generation(err_string).into()) @@ -240,25 +242,26 @@ impl Client { /// Offloads adapter the weights from GPU to CPU or disk pub async fn offload_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, adapter_index: u32, - ) -> Result<String> { + ) -> Result<bool> { if let Some(adapter_source_enum) = AdapterSource::from_str_name(adapter_source.to_uppercase().as_str()) { let request = tonic::Request::new(OffloadAdapterRequest { - adapter_id, + adapter_parameters: Some(adapter_parameters), adapter_source: adapter_source_enum.into(), adapter_index, }) .inject_context(); let response = self.stub.offload_adapter(request).await?.into_inner(); - Ok(response.adapter_id) + Ok(response.offloaded) } else { let err_string = format!( "Invalid source '{}' when loading adapter '{}'", - adapter_source, adapter_id + adapter_source, +
adapter_parameters.adapter_ids.join(",") ); tracing::error!(err_string); Err(ClientError::Generation(err_string).into()) diff --git a/router/client/src/lib.rs index 57e9b5e5f..3aa4e10ae 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -9,8 +9,9 @@ pub use client::Client; pub use pb::generate::v1::HealthResponse; pub use pb::generate::v1::InfoResponse as ShardInfo; pub use pb::generate::v1::{ - Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, - PrefillTokens, Request, StoppingCriteriaParameters, + AdapterParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation, + MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, PrefillTokens, Request, + StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/client/src/sharded_client.rs index 75bfd5c3f..8a4bb9f63 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,5 +1,5 @@ /// Multi shard Client -use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; +use crate::{AdapterParameters, Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; use crate::{ClientError, Result}; use futures::future::join_all; use tonic::transport::Uri; @@ -149,30 +149,30 @@ impl ShardedClient { pub async fn download_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, api_token: Option<String>, - ) -> Result<String> { + ) -> Result<bool> { // Only download the adapter with one client, since they share a single disk self.clients[0] - .download_adapter(adapter_id, adapter_source, api_token) + .download_adapter(adapter_parameters, adapter_source, api_token) .await } pub async fn load_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, adapter_index: u32, api_token: Option<String>, - ) -> Result<String> { + ) -> Result<bool> { // Load the adapter in all clients since there is sharding done between them let futures: Vec<_> = self .clients .iter_mut() .map(|client| { Box::pin(client.load_adapter( - adapter_id.clone(), + adapter_parameters.clone(), adapter_source.clone(), adapter_index, api_token.clone(), @@ -183,7 +183,7 @@ match join_all(futures) .await .into_iter() - .collect::<Result<Vec<String>>>() + .collect::<Result<Vec<bool>>>() { Ok(mut results) => { // Return the first adapter id @@ -195,17 +195,17 @@ pub async fn offload_adapter( &mut self, - adapter_id: String, + adapter_parameters: AdapterParameters, adapter_source: String, adapter_index: u32, - ) -> Result<String> { + ) -> Result<bool> { // Load the adapter in all clients since there is sharding done between them let futures: Vec<_> = self .clients .iter_mut() .map(|client| { Box::pin(client.offload_adapter( - adapter_id.clone(), + adapter_parameters.clone(), adapter_source.clone(), adapter_index, )) @@ -215,7 +215,7 @@ match join_all(futures) .await .into_iter() - .collect::<Result<Vec<String>>>() + .collect::<Result<Vec<bool>>>() { Ok(mut results) => { // Return the first adapter id diff --git a/router/src/adapter.rs index a7d81156f..4e34fe687 100644 --- a/router/src/adapter.rs +++ b/router/src/adapter.rs @@ -1,4 +1,6 @@ -/// Adapter utils +use std::hash; + +use crate::AdapterParameters; /// "adapter ID" for the base model. The base model does not have an adapter ID, /// but we reason about it in the same way.
This must match the base model ID @@ -9,10 +11,10 @@ pub const BASE_MODEL_ADAPTER_ID: &str = "__base_model__"; /// from within the proto definition, or lib.rs pub const DEFAULT_ADAPTER_SOURCE: &str = "hub"; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub(crate) struct Adapter { - /// name of adapter - id: String, + /// adapter parameters + params: AdapterParameters, /// source (enforced at proto level) source: String, /// index of the adapter @@ -22,17 +24,22 @@ } impl Adapter { - pub(crate) fn new(id: String, source: String, index: u32, api_token: Option<String>) -> Self { + pub(crate) fn new( + params: AdapterParameters, + source: String, + index: u32, + api_token: Option<String>, + ) -> Self { Self { - id, + params, source, index, api_token, } } - pub(crate) fn id(&self) -> &str { - &self.id + pub(crate) fn params(&self) -> &AdapterParameters { + &self.params } pub(crate) fn source(&self) -> &str { @@ -49,6 +56,20 @@ pub(crate) fn as_string(&self) -> String { // format "<source>:<id>" - format!("{}:{}", self.source, self.id) + format!("{}:{}", self.source, self.params.adapter_ids.join(",")) + } +} + +impl hash::Hash for Adapter { + fn hash<H: hash::Hasher>(&self, state: &mut H) { + self.index.hash(state); + } +} + +impl Eq for Adapter {} + +impl PartialEq for Adapter { + fn eq(&self, other: &Self) -> bool { + self.index == other.index } } diff --git a/router/src/infer.rs index b8b4b656c..48a9d503f 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -3,7 +3,7 @@ use crate::adapter::{Adapter, BASE_MODEL_ADAPTER_ID, DEFAULT_ADAPTER_SOURCE}; use crate::queue::AdapterEvent; use crate::scheduler::AdapterScheduler; use crate::validation::{Validation, ValidationError}; -use crate::{Entry, Token}; +use crate::{AdapterParameters, Entry, Token}; use crate::{GenerateRequest, PrefillToken}; use flume::r#async::RecvStream; use flume::SendTimeoutError; @@ -32,7 +32,7 @@ pub struct Infer { /// Manages the queues of the various adapters adapter_scheduler: AdapterScheduler, /// Maps adapter ID to a unique index - adapter_to_index: Arc<Mutex<HashMap<String, u32>>>, + adapter_to_index: Arc<Mutex<HashMap<AdapterParameters, u32>>>, /// Inference limit limit_concurrent_requests: Arc<Semaphore>, } @@ -69,7 +69,13 @@ impl Infer { ); // Initialize with base model adapter (empty) mapping to index 0 - let adapter_to_index = Arc::new(Mutex::new(HashMap::from([("".to_string(), 0)]))); + let adapter_to_index = Arc::new(Mutex::new(HashMap::from([( + AdapterParameters { + adapter_ids: vec!["".to_string()], + ..Default::default() + }, + 0, + )]))); // Spawn batching background task that contains all the inference logic tokio::spawn(batching_task( @@ -126,21 +132,32 @@ adapter_source = Some(DEFAULT_ADAPTER_SOURCE.to_string()); } + let adapter_parameters = + request + .parameters + .adapter_parameters + .clone() + .unwrap_or(AdapterParameters { + adapter_ids: vec![adapter_id.clone().unwrap()], + ..Default::default() + }); + let adapter_idx; { // TODO(travis): can optimize concurrency here using RWLock let mut adapter_to_index = self.adapter_to_index.lock().await; - if adapter_to_index.contains_key(&adapter_id.clone().unwrap()) { - adapter_idx = *adapter_to_index.get(&adapter_id.clone().unwrap()).unwrap(); + let adapter_key = adapter_parameters.clone(); + if adapter_to_index.contains_key(&adapter_key) { + adapter_idx = *adapter_to_index.get(&adapter_key).unwrap(); } else { adapter_idx = adapter_to_index.len() as u32; - adapter_to_index.insert(adapter_id.clone().unwrap(), adapter_idx); + adapter_to_index.insert(adapter_key,
adapter_idx); } } let api_token = request.parameters.api_token.clone(); let adapter = Adapter::new( - adapter_id.unwrap(), + adapter_parameters, adapter_source.unwrap(), adapter_idx, api_token, diff --git a/router/src/lib.rs index 7aefd84e0..19ab7a191 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -7,6 +7,8 @@ mod queue; mod scheduler; pub mod server; mod validation; +use lorax_client::AdapterParameters as AdapterParametersMessage; +use lorax_client::{MajoritySignMethod, MergeStrategy}; use infer::Infer; use loader::AdapterLoader; @@ -65,6 +67,91 @@ pub struct Info { pub docker_label: Option<&'static str>, } +#[derive(Clone, Debug, Deserialize, ToSchema, Default)] +pub(crate) struct AdapterParameters { + #[serde(rename(deserialize = "ids"))] + #[schema(inline, example = json ! (["arnavgrg/codealpaca-qlora"]))] + pub adapter_ids: Vec<String>, + #[serde(default)] + #[schema(inline, example = json ! ([0.25, 0.75]))] + pub weights: Vec<f32>, + #[serde(default)] + #[schema(nullable = true, default = "null", example = "linear")] + pub merge_strategy: Option<String>, + #[serde(default)] + #[schema(nullable = false, default = 0.0, example = 0.5)] + pub density: f32, + #[serde(default)] + #[schema(nullable = true, default = "null", example = "total")] + pub majority_sign_method: Option<String>, +} + +impl Into<AdapterParametersMessage> for AdapterParameters { + fn into(self) -> AdapterParametersMessage { + AdapterParametersMessage { + adapter_ids: self.adapter_ids, + weights: self.weights, + merge_strategy: MergeStrategy::from_str_name( + self.merge_strategy + .unwrap_or("linear".to_string()) + .to_uppercase() + .as_str(), + ) + .unwrap() + .into(), + density: self.density, + majority_sign_method: MajoritySignMethod::from_str_name( + self.majority_sign_method + .unwrap_or("total".to_string()) + .to_uppercase() + .as_str(), + ) + .unwrap() + .into(), + } + } +} + +impl std::hash::Hash for AdapterParameters { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + if self.adapter_ids.len() == 1 { + self.adapter_ids[0].hash(state); + return; + } + + self.adapter_ids.hash(state); + + // Convert weights vec into vec of u32 bits + let weights: Vec<u32> = self.weights.iter().map(|x| x.to_bits()).collect(); + weights.hash(state); + + self.merge_strategy.hash(state); + + // Hash the raw bits of the float, acknowledging that this + // can cause issues with different representations of the same value.
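+ // (for example, -0.0 and 0.0, or NaNs with different bit patterns, hash to different values)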
+ self.density.to_bits().hash(state); + + self.majority_sign_method.hash(state); + } +} + +impl PartialEq for AdapterParameters { + fn eq(&self, other: &Self) -> bool { + if self.adapter_ids.len() == 1 && other.adapter_ids.len() == 1 { + return self.adapter_ids[0] == other.adapter_ids[0]; + } + + // In this implementation, we assume that adapter order matters + self.adapter_ids == other.adapter_ids + && self.weights == other.weights + && self.merge_strategy == other.merge_strategy + && self.density == other.density // direct comparison of f32 + && self.majority_sign_method == other.majority_sign_method + } +} + +impl Eq for AdapterParameters {} + #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct GenerateParameters { #[serde(default)] @@ -77,6 +164,9 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(nullable = true, default = "null", example = "hub")] pub adapter_source: Option<String>, + #[serde(rename(deserialize = "merged_adapters"))] + #[schema(nullable = true, default = "null")] + pub adapter_parameters: Option<AdapterParameters>, #[serde(default)] #[schema( nullable = true, @@ -169,6 +259,7 @@ fn default_parameters() -> GenerateParameters { GenerateParameters { adapter_id: None, adapter_source: None, + adapter_parameters: None, api_token: None, best_of: None, temperature: None, @@ -470,6 +561,7 @@ impl From for CompatGenerateRequest { parameters: GenerateParameters { adapter_id: req.model.parse().ok(), adapter_source: None, + adapter_parameters: None, api_token: None, best_of: req.best_of.map(|x| x as usize), temperature: req.temperature, @@ -503,6 +595,7 @@ impl From for CompatGenerateRequest { parameters: GenerateParameters { adapter_id: req.model.parse().ok(), adapter_source: None, + adapter_parameters: None, api_token: None, best_of: req.n.map(|x| x as usize), temperature: req.temperature, diff --git a/router/src/loader.rs index 74c39b690..7f7ac0b1b 100644 --- a/router/src/loader.rs +++ b/router/src/loader.rs @@ -140,14 +140,14 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("adapter {} downloaded", adapter.id()); + tracing::info!("adapter {} downloaded", adapter.as_string()); let mut locked_state = queues_state.lock().unwrap(); if locked_state.has_adapter(&adapter) { // Above check guards against the case where the adapter was terminated between the initial @@ -157,7 +157,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("FAILED downloading adapter {}", adapter.id()); + tracing::info!("FAILED downloading adapter {}", adapter.as_string()); metrics::increment_counter!("lorax_request_failure", "err" => "download_adapter"); let mut locked_state = queues_state.lock().unwrap(); if locked_state.has_adapter(&adapter) { @@ -186,7 +186,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("adapter {} loaded", adapter.id()); + tracing::info!("adapter {} loaded", adapter.as_string()); queues_state .lock() .unwrap() @@ -203,7 +203,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("FAILED loading adapter {}", adapter.id()); + tracing::info!("FAILED loading adapter {}", adapter.as_string()); metrics::increment_counter!("lorax_request_failure", "err" => "load_adapter"); queues_state .lock() @@ -231,14 +231,14 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("adapter {} offloaded", adapter.id()); + tracing::info!("adapter {} offloaded", adapter.as_string());
queues_state .lock() .unwrap() @@ -247,7 +247,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("FAILED offloading adapter {}", adapter.id()); + tracing::info!("FAILED offloading adapter {}", adapter.as_string()); metrics::increment_counter!("lorax_request_failure", "err" => "offload_adapter"); queues_state .lock() @@ -273,7 +273,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { - tracing::info!("terminating adapter {} loader", adapter.id()); + tracing::info!("terminating adapter {} loader", adapter.as_string()); let mut locked_state = queues_state.lock().unwrap(); if !locked_state.has_adapter(&adapter) { diff --git a/router/src/queue.rs index 26c9d6c4a..4d1cf5d78 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -134,7 +134,7 @@ impl QueueState { self.event.batching_task.notify_one(); tracing::info!( "set adapter {} status to {}", - self.adapter.id(), + self.adapter.as_string(), self.status ); } @@ -259,7 +259,7 @@ impl AdapterQueuesState { let q = self.queue_map.get_mut(adapter); if q.is_none() { // TODO(travis): remove this - tracing::error!("adapter {} not found in queue_map", adapter.id()); + tracing::error!("adapter {} not found in queue_map", adapter.as_string()); println!("{:?}", Backtrace::force_capture()); } let queue = q.unwrap(); @@ -357,6 +357,9 @@ impl AdapterQueuesState { adapters_to_remove.insert(adapter.clone()); // Start async offload process + // TODO(travis): we're being too aggressive about offloading here; we should only + // add adapters to this set if the active adapter set is at capacity and there are new adapters + // waiting to be loaded offload_adapters.push(adapter.clone()); } } diff --git a/router/src/validation.rs index 8f985cdcc..f10949c39 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -144,11 +144,30 @@ impl Validation { seed, watermark, adapter_id, + adapter_parameters, decoder_input_details, apply_chat_template, ..
} = request.parameters; + // adapter validation + // cannot specify both adapter_id and adapter_parameters + if adapter_parameters.is_some() && adapter_id.is_some() { + return Err(ValidationError::AdapterIdConflict); + } + + if adapter_parameters.is_some() { + let nadapters = adapter_parameters.as_ref().unwrap().adapter_ids.len(); + let nweights = adapter_parameters.as_ref().unwrap().weights.len(); + if nadapters < 1 { + return Err(ValidationError::AdapterIdMissing); + } + + // weights are optional (they default to 1 per adapter), but must match when provided + if nweights > 0 && nadapters != nweights { + return Err(ValidationError::AdapterWeightMismatch); + } + } + // sampling must be true when best_of > 1 let best_of = best_of.unwrap_or(1); let sampling = do_sample @@ -389,13 +408,19 @@ pub enum ValidationError { StopSequence(usize, usize), #[error("tokenizer error {0}")] Tokenizer(String), + #[error("at most one of `adapter_id` or `merged_adapters` may be provided")] + AdapterIdConflict, + #[error("at least one adapter ID must be provided when setting `merged_adapters`")] + AdapterIdMissing, + #[error("number of adapter IDs must match number of adapter weights")] + AdapterWeightMismatch, } #[cfg(test)] mod tests { use super::*; - use crate::default_parameters; use crate::tests::get_tokenizer; + use crate::{default_parameters, AdapterParameters}; #[tokio::test] async fn test_validation_max_new_tokens() { @@ -477,7 +502,15 @@ ..default_parameters() }, }, - Adapter::new("".to_string(), "hf".to_string(), 0, None), + Adapter::new( + AdapterParameters { + adapter_ids: vec!["".to_string()], + ..Default::default() + }, + "hf".to_string(), + 0, + None, + ), ) .await { @@ -511,7 +544,15 @@ ..default_parameters() }, }, - Adapter::new("".to_string(), "hf".to_string(), 0, None), + Adapter::new( + AdapterParameters { + adapter_ids: vec!["".to_string()], + ..Default::default() + }, + "hf".to_string(), + 0, + None, + ), ) .await { @@ -529,7 +570,15 @@ ..default_parameters() }, }, - Adapter::new("".to_string(), "hf".to_string(), 0, None), + Adapter::new( + AdapterParameters { + adapter_ids: vec!["".to_string()], + ..Default::default() + }, + "hf".to_string(), + 0, + None, + ), ) .await { @@ -547,7 +596,15 @@ ..default_parameters() }, }, - Adapter::new("".to_string(), "hf".to_string(), 0, None), + Adapter::new( + AdapterParameters { + adapter_ids: vec!["".to_string()], + ..Default::default() + }, + "hf".to_string(), + 0, + None, + ), ) .await .unwrap(); diff --git a/server/lorax_server/models/flash_llama.py index 2cb572289..07632613b 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -24,7 +24,8 @@ tracer = trace.get_tracer(__name__) -ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD] +# TODO(travis): re-enable LM_HEAD after resolving issues with outputs +ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ] # LM_HEAD ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} diff --git a/server/lorax_server/models/model.py index 9af1be86d..c76d57558 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -9,8 +9,8 @@ from transformers import PreTrainedTokenizerBase from lorax_server.models.types import Batch, GeneratedText -from lorax_server.pb.generate_pb2 import InfoResponse -from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map +from lorax_server.pb.generate_pb2 import AdapterParameters, AdapterSource, InfoResponse +from
lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_and_merge_adapters from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.lora import BatchedLoraWeights, MergedLoraWeights from lorax_server.utils.weights import shard_on_dim @@ -46,10 +46,11 @@ def __init__( self.sliding_window = sliding_window # This may be set to False in the subclass constructor - self.adapter_id = adapter_id self.dynamic_adapter_loading_enabled = dynamic_adapter_loading_enabled self.batched_lora_weights: Dict[str, BatchedLoraWeights] = defaultdict(BatchedLoraWeights) self.target_to_layer = self.adapter_target_to_layer() + self.loaded_adapters = set() + self.static_adapter_id = adapter_id self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) @@ -137,48 +138,51 @@ def get_num_layers_for_type(self, layer_type: str) -> int: def is_row_parallel(self, layer_type: str) -> bool: return False - def load_adapter(self, adapter_id, adapter_source, adapter_index, api_token): + def load_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + api_token: str, + ): """Physically loads the adapter weights into the model. adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded into model. Otherwise, the adapter weights are merged into the model weights on the fly. """ - if adapter_id == BASE_MODEL_ADAPTER_ID: + if adapter_index in self.loaded_adapters: + # Adapter already loaded return if not self.supports_adapter_loading: raise ValueError("This model does not support adapter loading.") if not self.dynamic_adapter_loading_enabled: - raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") + raise ValueError(f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. 
" + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature.") + + logger.info(f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}") + weight_names = tuple([v[0] for v in self.target_to_layer.values()]) + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_and_merge_adapters( + self.model_id, adapter_parameters, adapter_source, adapter_index, weight_names, api_token + ) - # If we are doing dynamic adapter loading, then we need to reset the weights - if adapter_id == self.adapter_id: - return - elif adapter_id != BASE_MODEL_ADAPTER_ID: - logger.info(f"Loading adapter weights into model: {adapter_id}") - weight_names = tuple([v[0] for v in self.target_to_layer.values()]) - module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map( - self.model_id, adapter_id, adapter_source, weight_names, api_token + unused_weight_names = adapter_weight_names.copy() + for layer_name in self.adapter_layers: + self.load_batched_adapter_weights( + module_map, adapter_config, adapter_index, layer_name, unused_weight_names ) + + if len(unused_weight_names) > 0: + logger.warning(f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}") + + if adapter_tokenizer is not None: + self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) - unused_weight_names = adapter_weight_names.copy() - for layer_name in self.adapter_layers: - self.load_batched_adapter_weights( - module_map, adapter_config, adapter_index, layer_name, unused_weight_names - ) - - if len(unused_weight_names) > 0: - logger.warning(f"{adapter_id} unused adapter weights: {unused_weight_names}") - - if adapter_tokenizer is not None: - self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) - - self.adapter_id = adapter_id + self.loaded_adapters.add(adapter_index) def shard_lora_weights( self, @@ -246,25 +250,28 @@ def load_batched_adapter_weights( q_lora_weights = self.batched_lora_weights[layer_type] q_lora_weights.add_adapter(adapter_index, q_lora_merged) - def offload_adapter(self, adapter_id, adapter_source, adapter_index): + def offload_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + ): """Offloads the adapter weights from GPU to CPU or disk.""" + if adapter_index not in self.loaded_adapters: + # Adapter already offloaded + return + if not self.supports_adapter_loading: raise ValueError("This model does not support adapter loading.") if not self.dynamic_adapter_loading_enabled: - if adapter_id == BASE_MODEL_ADAPTER_ID: - return - else: - raise ValueError(f"This model was initialized with the adapter {self.adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature.") + raise ValueError(f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. 
" + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature.") - if adapter_id == BASE_MODEL_ADAPTER_ID: - return - else: - for layer_name in self.adapter_layers: - if layer_name in self.batched_lora_weights: - self.batched_lora_weights[layer_name].remove_adapter(adapter_index) + for layer_name in self.adapter_layers: + if layer_name in self.batched_lora_weights: + self.batched_lora_weights[layer_name].remove_adapter(adapter_index) - self.adapter_id = BASE_MODEL_ADAPTER_ID + self.loaded_adapters.remove(adapter_index) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 5225b9ce1..066120f3d 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -19,7 +19,7 @@ from lorax_server.pb import generate_pb2_grpc, generate_pb2 from lorax_server.tracing import UDSOpenTelemetryAioServerInterceptor from lorax_server.utils import HUB, LOCAL, S3, PBASE, get_config_path, get_local_dir, map_pbase_model_id_to_s3 -from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID +from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, is_base_model class LoraxService(generate_pb2_grpc.LoraxServiceServicer): @@ -126,82 +126,88 @@ async def Decode(self, request, context): batch=next_batch.to_pb() if next_batch else None, ) - async def DownloadAdapter(self, request, context): - adapter_id = request.adapter_id - if adapter_id == BASE_MODEL_ADAPTER_ID: + async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, context): + adapter_parameters = request.adapter_parameters + if is_base_model(adapter_parameters): logger.info("No adapter to download for base model. Skipping.") - return generate_pb2.DownloadAdapterResponse( - adapter_id=adapter_id, - adapter_source=request.adapter_source, - ) + return generate_pb2.DownloadAdapterResponse(downloaded=False) api_token = request.api_token adapter_source = _adapter_source_enum_to_string(request.adapter_source) - if adapter_source == PBASE: - adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) - adapter_source = S3 - try: - if adapter_source == HUB: - # Quick auth check on the repo against the token - HfApi(token=api_token).model_info(adapter_id, revision=None) - - # fail fast if ID is not an adapter (i.e. it is a full model) - # TODO(geoffrey): do this for S3– can't do it this way because the - # files are not yet downloaded locally at this point. - config_path = get_config_path(adapter_id, adapter_source) - PeftConfig.from_pretrained(config_path, token=api_token) - - download_weights(adapter_id, source=adapter_source, api_token=api_token) - return generate_pb2.DownloadAdapterResponse( - adapter_id=adapter_id, - adapter_source=request.adapter_source, - ) - except Exception: - logger.exception("Error when downloading adapter") - - if adapter_source != LOCAL: - # delete safetensors files if there is an issue downloading or converting - # the weights to prevent cache hits by subsequent calls - try: - local_path = get_local_dir(adapter_id, adapter_source) - shutil.rmtree(local_path) - except Exception as e: - logger.warning(f"Error cleaning up safetensors files after " - f"download error: {e}\nIgnoring.") - raise + for adapter_id in adapter_parameters.adapter_ids: + if adapter_id == BASE_MODEL_ADAPTER_ID: + logger.info("No adapter to download for base model. 
Skipping.") + continue + + if adapter_source == PBASE: + adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) + adapter_source = S3 + try: + if adapter_source == HUB: + # Quick auth check on the repo against the token + HfApi(token=api_token).model_info(adapter_id, revision=None) + + # fail fast if ID is not an adapter (i.e. it is a full model) + # TODO(geoffrey): do this for S3– can't do it this way because the + # files are not yet downloaded locally at this point. + config_path = get_config_path(adapter_id, adapter_source) + PeftConfig.from_pretrained(config_path, token=api_token) + + download_weights(adapter_id, source=adapter_source, api_token=api_token) + except Exception: + logger.exception("Error when downloading adapter") + + if adapter_source != LOCAL: + # delete safetensors files if there is an issue downloading or converting + # the weights to prevent cache hits by subsequent calls + try: + local_path = get_local_dir(adapter_id, adapter_source) + shutil.rmtree(local_path) + except Exception as e: + logger.warning(f"Error cleaning up safetensors files after " + f"download error: {e}\nIgnoring.") + raise + + return generate_pb2.DownloadAdapterResponse(downloaded=True) - async def LoadAdapter(self, request, context): + async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context): + adapter_parameters = request.adapter_parameters + if is_base_model(adapter_parameters): + logger.info("No adapter to load for base model. Skipping.") + return generate_pb2.LoadAdapterResponse(loaded=False) + try: - adapter_id = request.adapter_id adapter_source = _adapter_source_enum_to_string(request.adapter_source) adapter_index = request.adapter_index api_token = request.api_token + if adapter_source == PBASE: - adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) + for i in range(len(adapter_parameters.adapter_ids)): + adapter_id = adapter_parameters.adapter_ids[i] + adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) + adapter_parameters.adapter_ids[i] = adapter_id adapter_source = S3 - self.model.load_adapter(adapter_id, adapter_source, adapter_index, api_token) - return generate_pb2.LoadAdapterResponse( - adapter_id=adapter_id, - adapter_source=request.adapter_source, - adapter_index=adapter_index, - ) + self.model.load_adapter(adapter_parameters, adapter_source, adapter_index, api_token) + + return generate_pb2.LoadAdapterResponse(loaded=True) except Exception: logger.exception("Error when loading adapter") raise - async def OffloadAdapter(self, request, context): + async def OffloadAdapter(self, request: generate_pb2.OffloadAdapterRequest, context): + adapter_parameters = request.adapter_parameters + if is_base_model(adapter_parameters): + logger.info("No adapter to offload for base model. 
Skipping.") + return generate_pb2.OffloadAdapterResponse(offloaded=False) + try: - adapter_id = request.adapter_id adapter_source = _adapter_source_enum_to_string(request.adapter_source) adapter_index = request.adapter_index - self.model.offload_adapter(adapter_id, adapter_source, adapter_index) + self.model.offload_adapter(adapter_parameters, adapter_source, adapter_index) - return generate_pb2.OffloadAdapterResponse( - adapter_id=adapter_id, - adapter_source=request.adapter_source, - adapter_index=adapter_index, - ) + return generate_pb2.OffloadAdapterResponse(offloaded=True) except Exception: logger.exception("Error when offloading adapter") raise diff --git a/server/lorax_server/utils/adapter.py index 51dfa9021..94af07185 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import os from collections import defaultdict from functools import lru_cache @@ -11,18 +12,94 @@ from peft import LoraConfig from peft.utils import transpose from safetensors.torch import load_file, save_file -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer from tqdm import tqdm from filelock import FileLock -from lorax_server.utils.sources import get_model_source, get_config_path, weight_files +from lorax_server.pb import generate_pb2 +from lorax_server.utils.sources import get_model_source, get_config_path, weight_files +from lorax_server.utils.merges.strategies import merge_adapters BASE_MODEL_ADAPTER_ID = "__base_model__" +ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]] + + +@dataclass +class AdapterParametersContainer: + adapter_parameters: generate_pb2.AdapterParameters + adapter_source: str + adapter_index: int + + def __hash__(self) -> int: + return self.adapter_index + + +def is_base_model(adapter_parameters: generate_pb2.AdapterParameters) -> bool: + if len(adapter_parameters.adapter_ids) != 1: + return False + return adapter_parameters.adapter_ids[0] == BASE_MODEL_ADAPTER_ID + + +def load_and_merge_adapters( + model_id: str, + adapter_parameters: generate_pb2.AdapterParameters, + adapter_source: str, + adapter_index: int, + weight_names: Tuple[str], + api_token: str, +) -> Tuple[ModuleMap, LoraConfig, Set[str], PreTrainedTokenizer]: + if len(adapter_parameters.adapter_ids) == 1: + return load_module_map( + model_id, adapter_parameters.adapter_ids[0], adapter_source, weight_names, api_token + ) + + adapter_params = AdapterParametersContainer(adapter_parameters, adapter_source, adapter_index) + return _load_and_merge(model_id, adapter_params, weight_names, api_token) + + +@lru_cache(maxsize=32) +def _load_and_merge( + model_id: str, + adapter_params: AdapterParametersContainer, + weight_names: Tuple[str], + api_token: str, +) -> Tuple[ModuleMap, LoraConfig, Set[str], PreTrainedTokenizer]: + params = adapter_params.adapter_parameters + + adapters_to_merge = [] + merged_weight_names = set() + tokenizer = None + for adapter_id in params.adapter_ids: + if adapter_id == BASE_MODEL_ADAPTER_ID: + raise ValueError("Base model adapter cannot be merged.") + + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map( + model_id, adapter_id, adapter_params.adapter_source, weight_names, api_token, + ) + + adapters_to_merge.append((module_map, adapter_config)) + merged_weight_names = merged_weight_names.union(adapter_weight_names) + if tokenizer is
None:
+        tokenizer = adapter_tokenizer
+
+    if len(adapters_to_merge) == 0:
+        raise ValueError("No adapters to merge.")
+
+    module_map, adapter_config = merge_adapters(adapters_to_merge, params)
+    return module_map, adapter_config, merged_weight_names, tokenizer
+
+
 @lru_cache(maxsize=128)
-def load_module_map(model_id, adapter_id, adapter_source, weight_names, api_token):
+def load_module_map(
+    model_id: str,
+    adapter_id: str,
+    adapter_source: str,
+    weight_names: Tuple[str, ...],
+    api_token: str,
+) -> Tuple[ModuleMap, LoraConfig, Set[str], PreTrainedTokenizer]:
     # TODO(geoffrey): refactor this and merge parts of this function with
     # lorax_server/utils/adapter.py::create_merged_weight_files
     source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
diff --git a/server/lorax_server/utils/merges/__init__.py b/server/lorax_server/utils/merges/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py
new file mode 100644
index 000000000..99dd5d81a
--- /dev/null
+++ b/server/lorax_server/utils/merges/strategies.py
@@ -0,0 +1,171 @@
+from abc import ABC
+from collections import defaultdict
+import copy
+from typing import TYPE_CHECKING, Dict, List, Tuple, Type
+
+import torch
+from peft import LoraConfig
+
+from lorax_server.pb.generate_pb2 import (
+    AdapterParameters,
+    MajoritySignMethod as MajoritySignMethodEnum,
+    MergeStrategy as MergeStrategyEnum,
+)
+from lorax_server.utils.merges.utils import calculate_majority_sign_mask, disjoint_merge, prune
+
+if TYPE_CHECKING:
+    from lorax_server.utils.adapter import ModuleMap
+
+
+def _apply_weights(tensors: List[torch.Tensor], w: torch.Tensor) -> torch.Tensor:
+    t = torch.stack(tensors, dim=0)
+
+    # element-wise weighting of each task tensor
+    # need to unsqueeze weights to match task tensor dimensions
+    # for multiplication to apply element-wise
+    while len(t.shape) > len(w.shape):
+        w = w.unsqueeze(-1)
+    return t * w
+
+
+class MergeStrategy(ABC):
+    def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError()
+
+
+class LinearMerge(MergeStrategy):
+    def __init__(self, **kwargs):
+        pass
+
+    def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
+        weighted_task_tensors = _apply_weights(task_tensors, weights)
+        return weighted_task_tensors.sum(dim=0)
+
+
+class TiesMerge(MergeStrategy):
+    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
+        self.density = density
+        self.majority_sign_method = majority_sign_method
+
+    def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
+        # sparsify
+        task_tensors = [prune(tensor, self.density, method="magnitude") for tensor in task_tensors]
+        weighted_task_tensors = _apply_weights(task_tensors, weights)
+
+        # elect sign
+        majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method)
+
+        # disjoint merge
+        return disjoint_merge(weighted_task_tensors, majority_sign_mask)
+
+
+class DareLinearMerge(MergeStrategy):
+    def __init__(self, density: float, **kwargs):
+        self.density = density
+
+    def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
+        # sparsify
+        task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors]
+        weighted_task_tensors = _apply_weights(task_tensors, weights)
+        return weighted_task_tensors.sum(dim=0)
+
+
+class DareTiesMerge(MergeStrategy):
+    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
+        self.density = density
+        self.majority_sign_method = majority_sign_method
+
+    def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
+        # sparsify
+        task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors]
+        weighted_task_tensors = _apply_weights(task_tensors, weights)
+
+        # elect sign
+        majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method)
+
+        # disjoint merge
+        mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)
+        return mixed_task_tensors
+
+
+strategy_registry: Dict[str, Type[MergeStrategy]] = {
+    "linear": LinearMerge,
+    "ties": TiesMerge,
+    "dare_linear": DareLinearMerge,
+    "dare_ties": DareTiesMerge,
+}
+
+
+def merge_adapters(
+    adapters: List[Tuple["ModuleMap", LoraConfig]],
+    merge_params: AdapterParameters,
+) -> Tuple["ModuleMap", LoraConfig]:
+    strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
+
+    weights = merge_params.weights
+    if not weights:
+        weights = torch.ones(len(adapters))
+    else:
+        weights = torch.tensor(weights)
+
+    merge_config = {
+        "density": merge_params.density,
+        "majority_sign_method": MajoritySignMethodEnum.Name(merge_params.majority_sign_method).lower(),
+    }
+    merge_strategy = strategy_registry[strategy_name](**merge_config)
+
+    module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
+        lambda: defaultdict(lambda: defaultdict(list))
+    )
+    lora_configs = []
+    weight_name_to_adapter_idx = defaultdict(list)
+
+    # input is list of (module_map, lora_config) tuples
+    # convert into dict[k][param_name] -> list of tensors
+    for idx, (module_map, lora_config) in enumerate(adapters):
+        for weight_name, data in module_map.items():
+            weight_name_to_adapter_idx[weight_name].append(idx)
+            for k, (param_data, param_name) in data.items():
+                module_maps[weight_name][k][param_name].append(param_data)
+        lora_configs.append(lora_config)
+
+    # validate lora configs are compatible
+    _validate_lora_configs(lora_configs)
+
+    # merge tensors for each module such that we have a single ModuleMap:
+    # dict[k] -> merged tensor
+    merged_module_map: "ModuleMap" = defaultdict(dict)
+    for weight_name, data in module_maps.items():
+        indices = weight_name_to_adapter_idx[weight_name]
+        param_weights = weights[indices]
+        for k, param_data in data.items():
+            for param_name, tensors in param_data.items():
+                merged_tensor = merge_strategy.merge(tensors, param_weights)
+                merged_module_map[weight_name][k] = (merged_tensor, param_name)
+
+    # merge lora configs
+    merged_lora_config = _merge_lora_configs(lora_configs)
+
+    return merged_module_map, merged_lora_config
+
+
+def _validate_lora_configs(lora_configs: List[LoraConfig]):
+    # check that all configs have the same rank
+    ranks = set(lora_config.r for lora_config in lora_configs)
+    if len(ranks) > 1:
+        raise ValueError(f"unable to merge adapters, lora configs have different ranks: {ranks}")
+
+    if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
+        raise ValueError("unable to merge adapters, lora configs have no target modules")
+
+
+def _merge_lora_configs(lora_configs: List[LoraConfig]) -> LoraConfig:
+    merged_lora_config = copy.copy(lora_configs[0])
+
+    # merge target modules as a union operation
+    merged_target_modules = sorted(set(
+        module for lora_config in lora_configs for module in lora_config.target_modules
+    ))
+    merged_lora_config.target_modules = merged_target_modules
+
+    return merged_lora_config
diff --git a/server/lorax_server/utils/merges/utils.py b/server/lorax_server/utils/merges/utils.py
new file mode 100644
index 000000000..88e2a2989
--- /dev/null
+++ b/server/lorax_server/utils/merges/utils.py
@@ -0,0 +1,103 @@
+# coding=utf-8
+# From: https://github.com/huggingface/peft/pull/1364
+# Copyright 2024-present the HuggingFace Inc. team.
+# Modifications by Predibase, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal
+
+import torch
+
+
+def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
+    """
+    Prune the smallest values of the task tensor, retaining the top-k values based on the specified fraction
+    `density`.
+
+    Args:
+        tensor (`torch.Tensor`): The tensor to prune.
+        density (`float`): The fraction of values to preserve. Should be in [0,1].
+    """
+    mask = torch.zeros_like(tensor).reshape(-1)
+    k = int(density * tensor.reshape(-1).shape[0])
+    top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)
+    mask[top_k[1]] = 1
+    return tensor * mask.reshape(tensor.shape)
+
+
+def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
+    """
+    Prune random values of the task tensor, retaining each value with probability equal to the specified
+    fraction `density`.
+
+    Args:
+        tensor (`torch.Tensor`): The tensor to prune.
+        density (`float`): The fraction of values to preserve. Should be in [0,1].
+        rescale (`bool`): Whether to rescale the result to preserve the expected value of the original tensor.
+    """
+    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
+    pruned_tensor = tensor * mask
+    if rescale:
+        pruned_tensor = torch.div(input=pruned_tensor, other=density)
+    return pruned_tensor
+
+
+def prune(
+    tensor: torch.Tensor, density: float, method: Literal["magnitude", "random"], rescale: bool = False
+) -> torch.Tensor:
+    """
+    Prune the values of task tensors based on the `method`.
+
+    Args:
+        tensor (`torch.Tensor`): The tensor to prune.
+        density (`float`): The fraction of values to preserve. Should be in [0,1].
+        method (`str`): The method to use to prune. Should be one of ["magnitude", "random"].
+        rescale (`bool`): Whether to rescale the result to preserve the expected value of the original tensor.
+    """
+    if density >= 1:
+        return tensor
+    elif density < 0:
+        raise ValueError(f"Density should be >= 0, got {density}")
+    if method == "magnitude":
+        return magnitude_based_pruning(tensor, density)
+    elif method == "random":
+        return random_pruning(tensor, density, rescale=rescale)
+    else:
+        raise ValueError(f"Unknown method {method}")
+
+
+def calculate_majority_sign_mask(tensor: torch.Tensor, method: Literal["total", "frequency"] = "total") -> torch.Tensor:
+    """
+    Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.
+
+    Args:
+        tensor (`torch.Tensor`): The tensor to get the mask from.
+        method (`str`): The method to use to get the mask. Should be one of ["total", "frequency"].
+    """
+
+    sign = tensor.sign()
+    if method == "total":
+        sign_magnitude = (sign * tensor.abs()).sum(dim=0)
+    elif method == "frequency":
+        sign_magnitude = sign.sum(dim=0)
+    else:
+        raise RuntimeError(f'Unimplemented mask method "{method}"')
+    majority_sign = torch.where(sign_magnitude >= 0, 1, -1)
+    return sign == majority_sign
+
+
+def disjoint_merge(task_tensors: torch.Tensor, majority_sign_mask: torch.Tensor) -> torch.Tensor:
+    mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)
+    num_params_preserved = majority_sign_mask.sum(dim=0)
+    return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)
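
To make the data flow concrete, here is a minimal sketch (not part of the diff) of how the helpers in utils.py compose into the TIES merge that TiesMerge.merge implements: sparsify by magnitude, apply per-adapter weights, elect a per-parameter sign, then disjoint-merge. The toy tensors, weights, and density are illustrative, and the import path assumes the module layout added in this diff.

import torch

from lorax_server.utils.merges.utils import (
    calculate_majority_sign_mask,
    disjoint_merge,
    prune,
)

# two toy "task tensors" standing in for the same LoRA parameter
# extracted from two different adapters
task_tensors = [
    torch.tensor([[0.9, -0.1], [0.3, -0.8]]),
    torch.tensor([[-0.7, 0.2], [0.4, 0.6]]),
]
weights = torch.tensor([0.5, 0.5])

# 1. sparsify: keep only the top `density` fraction of values by magnitude
pruned = [prune(t, density=0.5, method="magnitude") for t in task_tensors]

# 2. stack and apply per-adapter weights (strategies.py does this via
#    _apply_weights, which unsqueezes the weights until they broadcast)
stacked = torch.stack(pruned, dim=0) * weights.unsqueeze(-1).unsqueeze(-1)

# 3. elect a sign per parameter, then 4. merge only the agreeing entries
mask = calculate_majority_sign_mask(stacked, method="total")
merged = disjoint_merge(stacked, mask)
print(merged)  # a single (2, 2) tensor, ready for the merged ModuleMap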
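
The rescale branch in random_pruning divides the surviving values by `density`: each value is kept with probability `density`, so the division preserves the expected value of the original tensor, which is what DARE relies on. A quick standalone check of that identity, using plain torch rather than the helper:

import torch

# each element survives with probability `density`; dividing the survivors
# by `density` keeps the tensor's expected value unchanged
density = 0.25
t = torch.ones(100_000)
mask = torch.bernoulli(torch.full_like(t, density))
rescaled = (t * mask) / density

print(mask.mean().item())      # ~0.25: fraction of values kept
print(rescaled.mean().item())  # ~1.0: mean matches the original tensor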
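
Finally, a sketch of resolving a strategy by name at runtime, mirroring what merge_adapters does with the lower-cased MergeStrategy enum name. It assumes the server package and its generated protobuf modules are importable; the shapes and weights are illustrative.

import torch

from lorax_server.utils.merges.strategies import strategy_registry

# "dare_ties" is one of the registered strategy names; density and
# majority_sign_method are forwarded as keyword arguments, as in merge_adapters
strategy = strategy_registry["dare_ties"](density=0.5, majority_sign_method="total")

# one tensor per adapter for the same LoRA parameter, plus per-adapter weights
task_tensors = [torch.randn(8, 8), torch.randn(8, 8)]
merged = strategy.merge(task_tensors, torch.tensor([0.7, 0.3]))
assert merged.shape == (8, 8)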