Merge multiple LoRA adapters per request (linear, TIES, DARE) (#212)
tgaddair authored Feb 1, 2024
1 parent f525ca6 commit 3b4c973
Showing 25 changed files with 1,051 additions and 193 deletions.
4 changes: 2 additions & 2 deletions clients/python/lorax/__init__.py
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.2.1"
__version__ = "0.3.0"

from lorax.client import Client, AsyncClient
from lorax.client import Client, AsyncClient, MergedAdapters
17 changes: 17 additions & 0 deletions clients/python/lorax/client.py
@@ -10,6 +10,7 @@
Response,
Request,
Parameters,
MergedAdapters,
)
from lorax.errors import parse_error

@@ -63,6 +64,7 @@ def generate(
prompt: str,
adapter_id: Optional[str] = None,
adapter_source: Optional[str] = None,
merged_adapters: Optional[MergedAdapters] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
@@ -89,6 +91,8 @@ def generate(
Adapter ID to apply to the base model for the request
adapter_source (`Optional[str]`):
Source of the adapter (hub, local, s3)
merged_adapters (`Optional[MergedAdapters]`):
Merged adapters to apply to the base model for the request
api_token (`Optional[str]`):
API token for accessing private adapters
do_sample (`bool`):
@@ -130,6 +134,7 @@ def generate(
parameters = Parameters(
adapter_id=adapter_id,
adapter_source=adapter_source,
merged_adapters=merged_adapters,
api_token=api_token,
best_of=best_of,
details=True,
@@ -166,6 +171,7 @@ def generate_stream(
prompt: str,
adapter_id: Optional[str] = None,
adapter_source: Optional[str] = None,
merged_adapters: Optional[MergedAdapters] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
@@ -190,6 +196,8 @@ def generate_stream(
Adapter ID to apply to the base model for the request
adapter_source (`Optional[str]`):
Source of the adapter (hub, local, s3)
merged_adapters (`Optional[MergedAdapters]`):
Merged adapters to apply to the base model for the request
api_token (`Optional[str]`):
API token for accessing private adapters
do_sample (`bool`):
@@ -227,6 +235,7 @@ def generate_stream(
parameters = Parameters(
adapter_id=adapter_id,
adapter_source=adapter_source,
merged_adapters=merged_adapters,
api_token=api_token,
best_of=None,
details=True,
@@ -329,6 +338,7 @@ async def generate(
prompt: str,
adapter_id: Optional[str] = None,
adapter_source: Optional[str] = None,
merged_adapters: Optional[MergedAdapters] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
@@ -355,6 +365,8 @@ async def generate(
Adapter ID to apply to the base model for the request
adapter_source (`Optional[str]`):
Source of the adapter (hub, local, s3)
merged_adapters (`Optional[MergedAdapters]`):
Merged adapters to apply to the base model for the request
api_token (`Optional[str]`):
API token for accessing private adapters
do_sample (`bool`):
@@ -396,6 +408,7 @@ async def generate(
parameters = Parameters(
adapter_id=adapter_id,
adapter_source=adapter_source,
merged_adapters=merged_adapters,
api_token=api_token,
best_of=best_of,
details=True,
@@ -430,6 +443,7 @@ async def generate_stream(
prompt: str,
adapter_id: Optional[str] = None,
adapter_source: Optional[str] = None,
merged_adapters: Optional[MergedAdapters] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
@@ -454,6 +468,8 @@ async def generate_stream(
Adapter ID to apply to the base model for the request
adapter_source (`Optional[str]`):
Source of the adapter (hub, local, s3)
merged_adapters (`Optional[MergedAdapters]`):
Merged adapters to apply to the base model for the request
api_token (`Optional[str]`):
API token for accessing private adapters
do_sample (`bool`):
@@ -491,6 +507,7 @@ async def generate_stream(
parameters = Parameters(
adapter_id=adapter_id,
adapter_source=adapter_source,
merged_adapters=merged_adapters,
api_token=api_token,
best_of=None,
details=True,
57 changes: 57 additions & 0 deletions clients/python/lorax/types.py
@@ -6,13 +6,63 @@


ADAPTER_SOURCES = ["hub", "local", "s3", "pbase"]
MERGE_STRATEGIES = ["linear", "ties", "dare_linear", "dare_ties"]
MAJORITY_SIGN_METHODS = ["total", "frequency"]


class MergedAdapters(BaseModel):
# IDs of the adapters to merge
ids: List[str]
# Weights of the adapters to merge
weights: List[float]
# Merge strategy
merge_strategy: Optional[str]
# Density
density: float
# Majority sign method
majority_sign_method: Optional[str]

@validator("ids")
def validate_ids(cls, v):
if not v:
raise ValidationError("`ids` cannot be empty")
return v

@validator("weights")
def validate_weights(cls, v, values):
ids = values["ids"]
if not v:
raise ValidationError("`weights` cannot be empty")
if len(ids) != len(v):
raise ValidationError("`ids` and `weights` must have the same length")
return v

@validator("merge_strategy")
def validate_merge_strategy(cls, v):
if v is not None and v not in MERGE_STRATEGIES:
raise ValidationError(f"`merge_strategy` must be one of {MERGE_STRATEGIES}")
return v

@validator("density")
def validate_density(cls, v):
if v < 0 or v > 1.0:
raise ValidationError("`density` must be >= 0.0 and <= 1.0")
return v

@validator("majority_sign_method")
def validate_majority_sign_method(cls, v):
if v is not None and v not in MAJORITY_SIGN_METHODS:
raise ValidationError(f"`majority_sign_method` must be one of {MAJORITY_SIGN_METHODS}")
return v


class Parameters(BaseModel):
# The ID of the adapter to use
adapter_id: Optional[str]
# The source of the adapter to use
adapter_source: Optional[str]
# Adapter merge parameters
merged_adapters: Optional[MergedAdapters]
# API token for accessing private adapters
api_token: Optional[str]
# Activate logits sampling
@@ -49,6 +99,13 @@ class Parameters(BaseModel):
# Get decoder input token logprobs and ids
decoder_input_details: bool = False

@validator("adapter_id")
def valid_adapter_id(cls, v, values):
merged_adapters = values.get("merged_adapters")
if v is not None and merged_adapters is not None:
raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`")
return v

@validator("adapter_source")
def valid_adapter_source(cls, v):
if v is not None and v not in ADAPTER_SOURCES:
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
@@ -3,7 +3,7 @@ name = "lorax-client"
packages = [
{include = "lorax"}
]
version = "0.2.1"
version = "0.3.0"
description = "LoRAX Python Client"
license = "Apache-2.0"
authors = ["Travis Addair <[email protected]>", "Olivier Dehaene <[email protected]>"]
127 changes: 127 additions & 0 deletions docs/guides/merging_adapters.md
@@ -0,0 +1,127 @@
# Merging Adapters

In LoRAX, multiple LoRA adapters can be merged together per request to create powerful multi-task ensembles
using one of several different [merge strategies](#merge-strategies).

This is particularly useful when you want your LLM to be capable of handling multiple types of tasks based on the user's prompt without
requiring them to specify the type of task they wish to perform.

## Background: Model Merging

Model merging is a set of techniques popularized by frameworks like [mergekit](https://github.com/cg123/mergekit) that combine
the weights of multiple specialized fine-tuned models into a single model that can perform each of these
tasks with a much smaller total footprint.

A common use case is to train specialized LoRA adapters for tasks like SQL generation, customer support email
generation, and information extraction. Without model merging, the user submitting a query needs to know in advance which
of these adapters to route it to. With model merging, the user can submit their query without prior knowledge
of which backing adapter is best suited to respond to it.

In some cases, mixing adapter specializations can even result in a better final response. For example, by mixing an adapter that understands math with an adapter that provides detailed and intuitive explanations, the user could in theory get correct answers to math questions with detailed step-by-step reasoning to aid their learning.

## Merge Strategies

LoRAX provides a number of model merging methods taken from [mergekit](https://github.com/cg123/mergekit) and [PEFT](https://github.com/huggingface/peft).

Options:

- `linear` (default)
- `ties`
- `dare_linear`
- `dare_ties`
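
For example, selecting a strategy with the [LoRAX Python Client](../reference/python_client.md) might look like the rough sketch below. The endpoint URL and adapter IDs are placeholders; note that `density` is a required field of the client's `MergedAdapters` model even though the `linear` strategy does not use it.

```python
from lorax import Client, MergedAdapters

client = Client("http://127.0.0.1:8080")  # placeholder endpoint URL

merged_adapters = MergedAdapters(
    ids=["org/adapter-a", "org/adapter-b"],  # placeholder adapter IDs
    weights=[0.5, 0.5],
    merge_strategy="linear",
    density=1.0,  # unused by `linear`, but required by the client model
)
response = client.generate("Write a haiku about model merging.", merged_adapters=merged_adapters)
print(response.generated_text)
```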

### Linear

The default and most straightforward way to merge model adapters is to linearly combine each of the parameters as a weighted average. This idea was
explored in the context of merging fine-tuned models in [Model Soups](https://arxiv.org/abs/2203.05482).

Parameters:

- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request.
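
Conceptually, a linear merge is just a weighted average of the per-adapter delta tensors. The following is a rough sketch using PyTorch tensors for illustration (not the LoRAX implementation), assuming the relative weights are normalized to sum to one:

```python
# Rough sketch only: weighted average of per-adapter delta tensors.
import torch

def linear_merge(deltas: list[torch.Tensor], weights: list[float]) -> torch.Tensor:
    total = sum(weights)  # treat weights as relative and normalize them
    merged = torch.zeros_like(deltas[0])
    for delta, weight in zip(deltas, weights):
        merged += (weight / total) * delta
    return merged
```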

### TIES

[TIES](https://arxiv.org/abs/2306.01708) is based on the idea of [Task Arithmetic](https://arxiv.org/abs/2212.04089), whereby the fine-tuned models
are merged after subtracting out the base model weights. LoRA and other adapters are already task-specific tensors,
so this approach is a natural fit when merging LoRAs.

To resolve interference between adapters, the weights are sparsified and a sign-based consensus algorithm is used to determine the weighted average.

One of the strengths of this approach is its ability to scale well to large numbers of adapters while retaining each of their individual strengths.

Parameters:

- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request.
- `density` (required): fraction of weights in adapters to retain.
- `majority_sign_method` (default: `total`): one of `{total, frequency}` used to obtain the magnitude of the sign for consensus.
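
The core steps — trim low-magnitude values, elect a per-parameter sign, then average only the values that agree with it — can be sketched roughly as follows. This is an illustration only (using the `total` sign method), not the LoRAX implementation:

```python
# Rough sketch of TIES-style merging; illustration only.
import torch

def ties_merge(deltas: list[torch.Tensor], weights: list[float], density: float) -> torch.Tensor:
    # Trim: keep only the top `density` fraction of each delta's entries by magnitude.
    trimmed = []
    for delta in deltas:
        k = max(1, int(density * delta.numel()))
        threshold = delta.abs().flatten().topk(k).values.min()
        trimmed.append(torch.where(delta.abs() >= threshold, delta, torch.zeros_like(delta)))

    # Elect sign ("total" method): sign of the summed weighted values per parameter.
    stacked = torch.stack([w * t for w, t in zip(weights, trimmed)])
    elected_sign = torch.sign(stacked.sum(dim=0))

    # Merge: average only the entries whose sign agrees with the elected sign.
    agree = (torch.sign(stacked) == elected_sign.unsqueeze(0)) & (stacked != 0)
    return (stacked * agree).sum(dim=0) / agree.sum(dim=0).clamp(min=1)
```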

### DARE (Linear)

[DARE](https://arxiv.org/abs/2311.03099), like TIES, sparsifies adapter weights (task vectors) to reduce interference. Unlike TIES, however,
DARE uses random pruning and rescaling in an attempt to better match the performance of the independent adapters.

Parameters:

- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request.
- `density` (required): fraction of weights in adapters to retain.
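
The drop-and-rescale step can be sketched as follows (illustration only, not the LoRAX implementation): each entry is kept with probability `density`, and the survivors are rescaled by `1 / density` so the expected delta is preserved.

```python
# Rough sketch of DARE's random drop-and-rescale; illustration only.
import torch

def dare_prune(delta: torch.Tensor, density: float) -> torch.Tensor:
    mask = torch.bernoulli(torch.full_like(delta, density))  # keep each entry with prob `density`
    return delta * mask / density  # rescale survivors to preserve the expected value

def dare_linear_merge(deltas: list[torch.Tensor], weights: list[float], density: float) -> torch.Tensor:
    total = sum(weights)
    pruned = [dare_prune(d, density) for d in deltas]
    return sum((w / total) * p for w, p in zip(weights, pruned))
```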

### DARE (TIES)

The DARE method from above, additionally applying the sign consensus algorithm from TIES.

Parameters:

- `weights` (default: `[1, ..]`): relative weight of each of the adapters in the request.
- `density` (required): fraction of weights in adapters to retain.
- `majority_sign_method` (default: `total`): one of `{total, frequency}` used to obtain the magnitude of the sign for consensus.

## Example

This example is derived from the [PEFT example](https://github.com/huggingface/peft/blob/smangrul/add-new-merging-methods/examples/multi_adapter_examples/Lora_Merging.ipynb) for model merging.

First deploy LoRAX using the base model `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`, then run the following using
the [LoRAX Python Client](../reference/python_client.md):

```python
from lorax import Client, MergedAdapters

client = Client(endpoint_url)

# tinyllama merge
merged_adapters = MergedAdapters(
ids=[
"smangrul/tinyllama_lora_norobots",
"smangrul/tinyllama_lora_sql",
"smangrul/tinyllama_lora_adcopy",
],
weights=[2.0, 0.3, 0.7],
merge_strategy="ties",
density=0.2,
majority_sign_method="total",
)

# norobots
prompt = """<s><|im_start|>user
Write an essay about Generative AI.<|im_end|>
<|im_start|>assistant \n"""
response = client.generate(prompt, merged_adapters=merged_adapters)
print(response.generated_text)

# adcopy
prompt = """<s><|im_start|>system
Create a text ad given the following product and description.<|im_end|>
<|im_start|>user
Product: Sony PS5 PlayStation Console
Description: The PS5™ console unleashes new gaming possibilities that you never anticipated.<|im_end|>
<|im_start|>assistant \n"""
response = client.generate(prompt, merged_adapters=merged_adapters)
print(response.generated_text)

# sql
prompt = """<s> Table: 2-11365528-2
Columns: ['Team', 'Head Coach', 'President', 'Home Ground', 'Location']
Natural Query: Who is the Head Coach of the team whose President is Mario Volarevic?
SQL Query:"""
response = client.generate(prompt, merged_adapters=merged_adapters)
print(response.generated_text)
```
9 changes: 9 additions & 0 deletions docs/models/adapters.md
@@ -141,6 +141,15 @@ Usage:
}
```

## Merging Adapters

Multiple adapters can be mixed / merged together per request to create powerful ensembles of different specialized adapters.

This is particularly useful when you want your LLM to be capable of handling multiple types of tasks based on the user's prompt without
requiring them to specify the type of task they wish to perform.

See [Merging Adapters](../guides/merging_adapters.md) for details.

## Private Adapter Repositories

For hosted adapter repositories like HuggingFace Hub and [Predibase](https://predibase.com/), you can perform inference using private adapters per request.