Add support for LoRA adapters trained with Rank-Stabilized scaling (#299)
arnavgarg1 authored Mar 5, 2024
1 parent f56238a commit 21631fa
Showing 2 changed files with 46 additions and 14 deletions.
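
Rank-Stabilized LoRA (rsLoRA) changes the LoRA scaling factor from lora_alpha / r to lora_alpha / sqrt(r), so the magnitude of the adapter update no longer shrinks as the rank grows. The snippet below is an illustrative comparison of the two factors, not part of the commit itself; the helper name lora_scaling is made up.

    import math

    def lora_scaling(alpha: float, r: int, rank_stabilized: bool = False) -> float:
        # Classic LoRA scales the update by alpha / r; rsLoRA scales by alpha / sqrt(r).
        return alpha / math.sqrt(r) if rank_stabilized else alpha / r

    for r in (8, 16, 64):
        print(r, lora_scaling(16, r), lora_scaling(16, r, rank_stabilized=True))

With lora_alpha fixed at 16, the classic factor drops from 2.0 to 0.25 as r goes from 8 to 64, while the rank-stabilized factor only drops from about 5.66 to 2.0.
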
21 changes: 15 additions & 6 deletions server/lorax_server/models/model.py
@@ -10,7 +10,12 @@

from lorax_server.models.types import Batch, GeneratedText
from lorax_server.pb.generate_pb2 import AdapterParameters, AdapterSource, InfoResponse
-from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_and_merge_adapters
+from lorax_server.utils.adapter import (
+BASE_MODEL_ADAPTER_ID,
+get_scaling_factor,
+load_and_merge_adapters,
+uses_rslora,
+)
from lorax_server.utils.tokenizer import TokenizerManager
from lorax_server.utils.lora import BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.sgmv import pad_rank
@@ -217,25 +222,29 @@ def load_batched_adapter_weights(
nlayers = self.get_num_layers_for_type(layer_type)
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers

for layer_id in range(nlayers):
key = (layer_id, layer_type)
weight_name, layer = self.target_to_layer[key]

base_weight = layer.base_layer.linear.weight
base_device = base_weight.device

if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
return

lora_a, lora_a_name = module_map[weight_name]["lora_A"]
lora_a = lora_a.to(base_device, self.dtype)

lora_b, lora_b_name = module_map[weight_name]["lora_B"]
lora_b = lora_b.to(base_device, self.dtype)

-scale = adapter_config.lora_alpha / adapter_config.r
+scale = get_scaling_factor(
+adapter_config.lora_alpha,
+adapter_config.r,
+uses_rslora=uses_rslora(adapter_config),
+)

unused_weight_names.discard(lora_a_name)
unused_weight_names.discard(lora_b_name)
@@ -244,7 +253,7 @@ def load_batched_adapter_weights(
# (A * B) * C = A * (B * C)
lora_a_list[layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

# pad lora ranks to be compatible with sgmv
lora_a_list = [pad_rank(w, dim=1, world_size=self.world_size) for w in lora_a_list]
lora_b_list = [pad_rank(w, dim=0, world_size=self.world_size) for w in lora_b_list]
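
Note that the loader folds the scaling factor directly into the transposed lora_B tensors, so downstream batched kernels can apply the adapter as x @ A_t @ B_t without a separate multiply. A minimal standalone sketch of that idea, using plain PyTorch and made-up shapes rather than the LoRAX data structures:

    import torch

    h, r, alpha = 512, 16, 32             # illustrative hidden size, rank, and alpha
    rank_stabilized = True
    scale = alpha / (r ** 0.5) if rank_stabilized else alpha / r

    lora_a = torch.randn(r, h)            # typical PEFT checkpoint layout: (r, in_features)
    lora_b = torch.randn(h, r)            # typical PEFT checkpoint layout: (out_features, r)

    a_t = lora_a.transpose(0, 1)          # (in_features, r)
    b_t = lora_b.transpose(0, 1) * scale  # (r, out_features), scale pre-applied

    x = torch.randn(1, h)
    delta = x @ a_t @ b_t                 # adapter contribution, already scaled
    print(delta.shape)                    # torch.Size([1, 512])
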
39 changes: 31 additions & 8 deletions server/lorax_server/utils/adapter.py
@@ -154,15 +154,16 @@ def compute_delta_weight(
lora_B: torch.Tensor,
fan_in_fan_out: bool,
alpha: float,
-r: float
+r: float,
+uses_rslora: bool = False
) -> torch.Tensor:
"""Computes the delta weight for a Linear layer given A and B LoRA matrices.
TODO: add logic for other module types beyond Linear layers.
Reference: https://github.com/huggingface/peft/blob/v0.4.0/src/peft/tuners/lora.py#L799-L806
"""
-scaling = alpha / r
+scaling = get_scaling_factor(alpha, r, uses_rslora=uses_rslora)
delta_weight = transpose(lora_B @ lora_A, fan_in_fan_out) * scaling
return delta_weight
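
For merged adapters the same factor enters through compute_delta_weight, which realizes delta_W = transpose(lora_B @ lora_A, fan_in_fan_out) * scaling and adds it onto the base weight. A standalone sketch of that arithmetic with made-up tensors (the delta_weight helper below is illustrative, not the function above):

    import torch

    def delta_weight(lora_a, lora_b, alpha: float, r: int, rank_stabilized: bool = False):
        # delta_W = (B @ A) * scale; rsLoRA uses alpha / sqrt(r) instead of alpha / r.
        scale = alpha / (r ** 0.5) if rank_stabilized else alpha / r
        return (lora_b @ lora_a) * scale

    r, in_features, out_features = 8, 64, 64
    a = torch.randn(r, in_features)
    b = torch.randn(out_features, r)
    base = torch.randn(out_features, in_features)

    merged = base + delta_weight(a, b, alpha=16, r=r, rank_stabilized=True)
    print(merged.shape)                   # torch.Size([64, 64])
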

@@ -197,7 +198,7 @@ def merge_adapter_weights(
matrix_type = adapter_weight_name.split(".")[-2]
module_mapping[weight_name][matrix_type] = adapter_weight_name
processed_adapter_weight_names.add(adapter_weight_name)

# merge adapter weights into model weights
merged_weights = {}
for weight_name, adapter_weight_names in tqdm(
@@ -208,8 +209,14 @@
lora_A = adapter_weights[adapter_weight_names["lora_A"]]
lora_B = adapter_weights[adapter_weight_names["lora_B"]]
delta_weight = compute_delta_weight(
-lora_A, lora_B, adapter_config.fan_in_fan_out, adapter_config.lora_alpha, adapter_config.r)
+lora_A,
+lora_B,
+adapter_config.fan_in_fan_out,
+adapter_config.lora_alpha,
+adapter_config.r,
+uses_rslora=uses_rslora(adapter_config),
+)

# transpose delta weight if necessary
# TODO(geoffrey): I believe this is required when using Conv1D layers (gpt2).
# We can likely take this out once we've switched to using Linear layers.
@@ -292,12 +299,28 @@ def create_merged_weight_files(
return merged_weight_filenames


def uses_rslora(adapter_config: LoraConfig) -> bool:
""" Returns True if the adapter uses RSLora for scaling the delta weights. """
return adapter_config.use_rslora if hasattr(adapter_config, "use_rslora") else False


def get_scaling_factor(
lora_alpha: int,
r: int,
uses_rslora: bool = False,
) -> float:
"""Computes the scaling factor for the lora weights."""
if uses_rslora:
return lora_alpha / (r ** 0.5)
return lora_alpha / r


def main():
adapter_id = "arnavgrg/codealpaca-qlora"
adapter_config = LoraConfig.from_pretrained(adapter_id)
model_id = adapter_config.base_model_name_or_path
model_weight_filenames = weight_files(model_id, extension=".safetensors")

merged_adapter_filenames = create_merged_weight_files(adapter_id, model_id, model_weight_filenames)
print(merged_adapter_filenames)
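
The new helpers are easy to sanity-check by hand: with lora_alpha=16 and r=64, standard scaling gives 16 / 64 = 0.25, while rank-stabilized scaling gives 16 / sqrt(64) = 2.0. A hypothetical usage sketch run inside the LoRAX server environment; the config values are made up, and the use_rslora argument requires a recent PEFT release:

    from peft import LoraConfig

    from lorax_server.utils.adapter import get_scaling_factor, uses_rslora

    config = LoraConfig(r=64, lora_alpha=16, use_rslora=True)  # hypothetical adapter config

    print(get_scaling_factor(config.lora_alpha, config.r))                    # 0.25
    print(get_scaling_factor(config.lora_alpha, config.r, uses_rslora=True))  # 2.0
    print(uses_rslora(config))                                                # True
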

