Commit
Refactor adapter interface to support adapters other than LoRA (e.g., speculative decoding) (#359)
Showing 33 changed files with 658 additions and 729 deletions. Only the new adapter-module files are reproduced below.
New file: the adapters package entry point (apparently lorax_server/adapters/__init__.py, judging by its imports and __all__ list):
@@ -0,0 +1,33 @@
import json
from pathlib import Path
from typing import Optional

from lorax_server.adapters.config import AdapterConfig
from lorax_server.adapters.lora import LoraConfig
# from lorax_server.adapters.medusa import MedusaConfig
from lorax_server.adapters.weights import AdapterBatchData, AdapterBatchMetadata


def load_adapter_config(
    config_path: Optional[Path],
    adapter_config_path: Optional[Path],
    api_token: str,
) -> AdapterConfig:
    if adapter_config_path is not None and adapter_config_path.exists():
        return LoraConfig.load(str(adapter_config_path.parent), api_token)

    # TODO(travis): medusa
    # if config_path is not None and config_path.exists():
    #     config = json.load(config_path.open())
    #     if "medusa_num_heads" in config:
    #         return MedusaConfig.load(config)

    raise ValueError(f"No valid adapter config file found: "
                     f"tried {adapter_config_path} and {config_path}")


__all__ = [
    "AdapterBatchData",
    "AdapterBatchMetadata",
    "load_adapter_config",
]
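
A minimal usage sketch for load_adapter_config, assuming the module path inferred above; the directory layout, file names under it, and the token value are illustrative, not taken from the commit:

from pathlib import Path

from lorax_server.adapters import load_adapter_config  # import path assumed, see note above

# Hypothetical local snapshot of a PEFT-style LoRA adapter.
adapter_dir = Path("/data/adapters/my-lora-adapter")

adapter_config = load_adapter_config(
    config_path=adapter_dir / "config.json",                  # would feed the TODO'd Medusa branch
    adapter_config_path=adapter_dir / "adapter_config.json",  # an existing file here triggers the LoRA branch
    api_token="hf_xxx",                                       # placeholder token
)

As written, the helper returns a LoraConfig whenever adapter_config_path exists and otherwise raises a ValueError naming both paths it tried.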
New file: the abstract adapter config interface (apparently lorax_server/adapters/config.py, given that AdapterConfig is imported from that module above):
@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple

import torch

from lorax_server.adapters.weights import AdapterWeights

if TYPE_CHECKING:
    from server.lorax_server.models.model import Model


ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]


@dataclass
class AdapterConfig(ABC):
    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self, adapter_weights: Dict, weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        pass

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
    ) -> Optional[AdapterWeights]:
        pass
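
To show what the abstract interface buys, here is a minimal sketch of a hypothetical non-LoRA adapter type written against it. The class name NullAdapterConfig and its no-op behaviour are invented for illustration; only the two abstract method signatures come from the commit:

from dataclasses import dataclass
from typing import Dict, Optional, Set, Tuple

from lorax_server.adapters.config import AdapterConfig, ModuleMap
from lorax_server.adapters.weights import AdapterWeights


@dataclass
class NullAdapterConfig(AdapterConfig):
    """Hypothetical adapter type that carries no extra weights."""

    def map_weights_for_model(
        self, adapter_weights: Dict, weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        # Claim no weights: empty module map, nothing marked as consumed.
        return {}, set()

    def load_batched_adapter_weights(
        self,
        model,  # a Model at runtime; annotation omitted to keep the sketch self-contained
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
    ) -> Optional[AdapterWeights]:
        # A real adapter (e.g. a Medusa-style speculative decoder) would build and
        # return an AdapterWeights subclass here; None means there is nothing to load
        # for this layer type, mirroring the LoRA implementation below.
        return None

Because load_adapter_config returns the AdapterConfig base type rather than LoraConfig specifically, a subclass like this can be slotted in without touching the callers, which appears to be the point of the refactor named in the commit title.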
New file: the LoRA implementation of the new interface (apparently lorax_server/adapters/lora.py, given that LoraConfig is imported from that module above):
@@ -0,0 +1,278 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup

from lorax_server.adapters.config import AdapterConfig, ModuleMap
from lorax_server.adapters.types import LORA
from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
from lorax_server.utils.sgmv import MAX_RANK_CUSTOM, get_tmp_tensors, orient_for_rank, pad_rank

if TYPE_CHECKING:
    from lorax_server.models.model import Model

EMPTY_TENSOR = torch.tensor([])


@dataclass
class LoraConfig(AdapterConfig):
    r: int
    target_modules: Optional[Union[List[str], str]]
    fan_in_fan_out: bool
    lora_alpha: int
    use_rslora: bool

    def map_weights_for_model(
        self, adapter_weights: Dict, weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        adapter_weight_names = set()
        module_map = {}
        for weight_name in weight_names:
            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
                continue

            module_map[weight_name] = {
                "lora_A": (adapter_weights[lora_a_name], lora_a_name),
                "lora_B": (adapter_weights[lora_b_name], lora_b_name),
            }
            adapter_weight_names.add(lora_a_name)
            adapter_weight_names.add(lora_b_name)
        return module_map, adapter_weight_names

    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
    ) -> Optional[AdapterWeights]:
        return LoraWeights.load(
            self,
            model,
            module_map,
            layer_type,
            unused_weight_names,
        )

    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
        return cls(
            base_model_name_or_path=hf_config.base_model_name_or_path,
            r=hf_config.r,
            target_modules=hf_config.target_modules,
            fan_in_fan_out=hf_config.fan_in_fan_out,
            lora_alpha=hf_config.lora_alpha,
            use_rslora=hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False,
        )


class LoraWeights(AdapterWeights):
    """LoRA weights for a single adapter merged across all layers."""

    def __init__(
        self,
        weights_a: List[torch.Tensor],
        weights_b: List[torch.Tensor],
        adapter_config: LoraConfig,
    ):
        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1

        # [num_layers, hidden_size, r]
        weights_a = [
            orient_for_rank(w, w.size(1)).contiguous()
            for w in weights_a
        ]
        self.weights_a = torch.stack(weights_a)

        # [num_layers, r, hidden_size]
        self.weights_b = torch.stack(weights_b)

        self.adapter_config = adapter_config

    @classmethod
    def get_batch_type(cls) -> BatchAdapterWeights:
        return BatchLoraWeights

    @classmethod
    def load(
        cls,
        config: LoraConfig,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
    ) -> Optional[AdapterWeights]:
        nlayers = model.get_num_layers_for_type(layer_type)
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            weight_name, layer = model.target_to_layer[key]

            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device

            if weight_name not in module_map:
                # There is no LoRA weight for this layer type in the adapter
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
            lora_a = lora_a.to(base_device, model.dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
            lora_b = lora_b.to(base_device, model.dtype)

            scale = get_scaling_factor(
                config.lora_alpha,
                config.r,
                uses_rslora=config.use_rslora,
            )

            unused_weight_names.discard(lora_a_name)
            unused_weight_names.discard(lora_b_name)

            # Merge scaling factor into lora_b due to associativity of matrix multiplication:
            # (A * B) * C = A * (B * C)
            lora_a_list[layer_id] = lora_a.transpose(0, 1)
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        # pad lora ranks to be compatible with sgmv
        lora_a_list = [pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list]
        lora_b_list = [pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list]

        if lora_a_list:
            # update rank if it was padded
            padded_rank = lora_a_list[0].size(1)
            config.r = padded_rank

        return LoraWeights(
            *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type), config,
        )


@dataclass
class RankSegments:
    rank: int
    tmp_shrink: torch.Tensor
    tmp_expand: torch.Tensor
    lora_a_ptr: torch.Tensor
    lora_b_ptr: torch.Tensor
    segment_starts: torch.Tensor
    segment_ends: torch.Tensor


@dataclass
class BatchLoraWeights(BatchAdapterWeights):
    lora_a: Dict[int, torch.Tensor]
    lora_b: Dict[int, torch.Tensor]
    adapter_index_configs: Dict[int, LoraConfig]
    rank_data: Dict[int, RankSegments]

    def has_adapter(self, adapter_index: int) -> bool:
        return adapter_index in self.adapter_index_configs

    def can_vectorize(self, pg: ProcessGroup) -> bool:
        return all(
            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
            for rank_data in self.rank_data.values()
        )

    @classmethod
    def key(cls) -> str:
        return LORA

    @classmethod
    def load(self, adapter_weights: Dict[int, LoraWeights], meta: AdapterBatchMetadata) -> "BatchLoraWeights":
        first_weights = list(adapter_weights.values())[0]
        device = first_weights.weights_a.device
        segment_indices = meta.segment_indices

        lora_a = {
            idx: adapter_weights[idx].weights_a
            for idx in segment_indices
            if idx in adapter_weights
        }
        lora_a_ptr = torch.tensor(
            [
                (
                    adapter_weights[idx].weights_a.data_ptr()
                    if idx in adapter_weights
                    else EMPTY_TENSOR.data_ptr()
                ) for idx in segment_indices
            ],
            dtype=torch.int64,
            device=device,
        )
        lora_b = {
            idx: adapter_weights[idx].weights_b
            for idx in segment_indices
            if idx in adapter_weights
        }
        lora_b_ptr = torch.tensor(
            [
                (
                    adapter_weights[idx].weights_b.data_ptr()
                    if idx in adapter_weights
                    else EMPTY_TENSOR.data_ptr()
                ) for idx in segment_indices
            ],
            dtype=torch.int64,
            device=device,
        )

        adapter_index_configs = {
            idx: adapter_weights[idx].adapter_config
            for idx in segment_indices
            if idx in adapter_weights
        }

        rank_indices = defaultdict(list)
        for segment_idx, adapter_idx in enumerate(segment_indices):
            if adapter_idx not in adapter_weights:
                continue
            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)

        rank_data = {}
        for rank, indices in rank_indices.items():
            lora_a_ptr_indices = lora_a_ptr[indices]
            tmp_shrink, tmp_expand = get_tmp_tensors(
                lora_a_ptr_indices.size(0),
                rank,
                device
            )

            rank_data[rank] = RankSegments(
                rank=rank,
                tmp_shrink=tmp_shrink,
                tmp_expand=tmp_expand,
                lora_a_ptr=lora_a_ptr_indices,
                lora_b_ptr=lora_b_ptr[indices],
                segment_starts=meta.adapter_segments[indices],
                segment_ends=meta.adapter_segments[[i+1 for i in indices]],
            )

        return BatchLoraWeights(
            lora_a=lora_a,
            lora_b=lora_b,
            adapter_index_configs=adapter_index_configs,
            rank_data=rank_data,
        )


def get_scaling_factor(
    lora_alpha: int,
    r: int,
    uses_rslora: bool = False,
) -> float:
    """Computes the scaling factor for the lora weights."""
    if uses_rslora:
        return lora_alpha / (r ** 0.5)
    return lora_alpha / r
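
A quick worked check of get_scaling_factor from the end of this file; the alpha and rank values are arbitrary examples and the import path is inferred:

from lorax_server.adapters.lora import get_scaling_factor  # module path inferred

# Classic LoRA scales the update by alpha / r; rsLoRA uses alpha / sqrt(r) instead.
assert get_scaling_factor(lora_alpha=16, r=8) == 2.0
assert abs(get_scaling_factor(lora_alpha=16, r=8, uses_rslora=True) - 16 / 8 ** 0.5) < 1e-9

Folding this scalar into lora_b once at load time, as LoraWeights.load does, keeps the per-token matmuls free of an extra multiply; by associativity the result is unchanged.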
New file: a commented-out placeholder for the Medusa (speculative decoding) adapter config (apparently lorax_server/adapters/medusa.py, matching the commented import in the package entry point):
@@ -0,0 +1,17 @@
# @dataclass
# class MedusaConfig(AdapterConfig):
#     medusa_num_heads: int
#     medusa_num_layers: int

#     def map_weights_for_model(
#         self, adapter_weights: Dict, weight_names: Tuple[str],
#     ) -> Tuple[ModuleMap, Set[str]]:
#         return adapter_weights, set(weight_names)

#     @classmethod
#     def load(cls, config: dict) -> "MedusaConfig":
#         return cls(
#             base_model_name_or_path=config["base_model_name_or_path"],
#             medusa_num_heads=config["medusa_num_heads"],
#             medusa_num_layers=config["medusa_num_layers"],
#         )
New file: adapter type constants (apparently lorax_server/adapters/types.py, given that LORA is imported from that module in the LoRA file):
@@ -0,0 +1,2 @@
LORA = "lora"
# MEDUSA = "medusa"
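
A small check tying the constant back to the LoRA batch weights class above; module paths are inferred from the imports in this commit:

from lorax_server.adapters.lora import BatchLoraWeights
from lorax_server.adapters.types import LORA

# BatchLoraWeights advertises its adapter type through key().
assert BatchLoraWeights.key() == LORA == "lora"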