Apply black formatting
tgaddair committed Apr 3, 2024
1 parent 3ca53b9 commit be72d33
Showing 76 changed files with 1,978 additions and 2,080 deletions.
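This commit runs the black code formatter over the Python server code. As a minimal sketch of the rewrites seen in the hunks below, the snippet here reproduces one of them with black's Python API. The black version and the repository's actual configuration are not shown in this diff; the 100-character line length is only an inference from the width of the rewrapped lines (a CLI equivalent would be something like `black --line-length 100 server/`).

# Sketch: reproducing one of this commit's rewrites with black's Python API.
# The line length of 100 is an assumption inferred from the reformatted lines;
# the exact black version and project configuration are not shown in the diff.
import black

mode = black.Mode(line_length=100)

src = "scaling = lora_alpha / (r ** 0.5)\n"
print(black.format_str(src, mode=mode), end="")
# Output: scaling = lora_alpha / (r**0.5)
# Recent black releases hug the ** operator when both operands are simple names
# or literals, which is exactly the change to get_scaling_factor() in lora.py below.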
server/lorax_server/adapters/__init__.py (9 changes: 5 additions & 4 deletions)

@@ -15,14 +15,15 @@ def load_adapter_config(
 ) -> AdapterConfig:
     if adapter_config_path is not None and adapter_config_path.exists():
         return LoraConfig.load(str(adapter_config_path.parent), api_token)

     if config_path is not None and config_path.exists():
         config = json.load(config_path.open())
         if "medusa_num_heads" in config:
             return MedusaConfig.load(config)

-    raise ValueError(f"No valid adapter config file found: "
-                     f"tried {adapter_config_path} and {config_path}")
+    raise ValueError(
+        f"No valid adapter config file found: " f"tried {adapter_config_path} and {config_path}"
+    )


 __all__ = [
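One detail worth noting in the rewrapped raise above: black only re-wraps the call, it does not merge the two adjacent f-string literals, and the runtime message is unchanged because adjacent string literals are concatenated by the parser. A small illustration with placeholder paths (not values used by LoRAX):

# Adjacent (f-)string literals are concatenated at parse time, so the rewrapped
# ValueError message above is identical to the old two-line version.
adapter_config_path = "my-adapter/adapter_config.json"  # placeholder
config_path = "my-adapter/config.json"  # placeholder

msg = f"No valid adapter config file found: " f"tried {adapter_config_path} and {config_path}"
print(msg)
# No valid adapter config file found: tried my-adapter/adapter_config.json and my-adapter/config.json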
server/lorax_server/adapters/config.py (8 changes: 5 additions & 3 deletions)

@@ -19,15 +19,17 @@ class AdapterConfig(ABC):

     @abstractmethod
     def map_weights_for_model(
-        self, adapter_weights: Dict, weight_names: Tuple[str],
+        self,
+        adapter_weights: Dict,
+        weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         pass

     @abstractmethod
     def load_batched_adapter_weights(
-        self,
+        self,
         model: "Model",
-        module_map: Dict[str, Dict],
+        module_map: Dict[str, Dict],
         layer_type: str,
         unused_weight_names: Set[str],
         dynamic: bool,
server/lorax_server/adapters/lora.py (79 changes: 35 additions & 44 deletions)

@@ -26,7 +26,9 @@ class LoraConfig(AdapterConfig):
     use_rslora: bool

     def map_weights_for_model(
-        self, adapter_weights: Dict, weight_names: Tuple[str],
+        self,
+        adapter_weights: Dict,
+        weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         adapter_weight_names = set()
         module_map = {}
@@ -35,19 +37,19 @@ def map_weights_for_model(
             lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
             if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
                 continue

             module_map[weight_name] = {
                 "lora_A": (adapter_weights[lora_a_name], lora_a_name),
                 "lora_B": (adapter_weights[lora_b_name], lora_b_name),
             }
             adapter_weight_names.add(lora_a_name)
             adapter_weight_names.add(lora_b_name)
         return module_map, adapter_weight_names

     def load_batched_adapter_weights(
-        self,
+        self,
         model: "Model",
-        module_map: Dict[str, Dict],
+        module_map: Dict[str, Dict],
         layer_type: str,
         unused_weight_names: Set[str],
         dynamic: bool,
@@ -59,7 +61,7 @@ def load_batched_adapter_weights(
             layer_type,
             unused_weight_names,
         )

     @classmethod
     def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
         hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
@@ -86,27 +88,24 @@ def __init__(
         self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1

         # [num_layers, hidden_size, r]
-        weights_a = [
-            orient_for_rank(w, w.size(1)).contiguous()
-            for w in weights_a
-        ]
+        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
         self.weights_a = torch.stack(weights_a)

         # [num_layers, r, hidden_size]
         self.weights_b = torch.stack(weights_b)

         self.adapter_config = adapter_config

     @classmethod
     def get_batch_type(cls) -> BatchAdapterWeights:
         return BatchLoraWeights

     @classmethod
     def load(
-        cls,
+        cls,
         config: LoraConfig,
         model: "Model",
-        module_map: Dict[str, Dict],
+        module_map: Dict[str, Dict],
         layer_type: str,
         unused_weight_names: Set[str],
     ) -> Optional[AdapterWeights]:
@@ -155,7 +154,8 @@ def load(
             config.r = padded_rank

         return LoraWeights(
-            *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type), config,
+            *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
+            config,
         )


@@ -176,60 +176,55 @@ class BatchLoraWeights(BatchAdapterWeights):
     lora_b: Dict[int, torch.Tensor]
     adapter_index_configs: Dict[int, LoraConfig]
     rank_data: Dict[int, RankSegments]

     def has_adapter(self, adapter_index: int) -> bool:
         return adapter_index in self.adapter_index_configs

     def can_vectorize(self, pg: ProcessGroup) -> bool:
         return all(
-            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
-            for rank_data in self.rank_data.values()
+            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM for rank_data in self.rank_data.values()
         )

     @classmethod
     def key(cls) -> str:
         return LORA

     @classmethod
-    def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata) -> "BatchLoraWeights":
-        adapter_weights = {
-            k: v
-            for k, v in adapter_weights.items()
-            if isinstance(v, LoraWeights)
-        }
+    def load(
+        self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata
+    ) -> "BatchLoraWeights":
+        adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}

         first_weights = list(adapter_weights.values())[0]
         device = first_weights.weights_a.device
         segment_indices = meta.segment_indices

         lora_a = {
-            idx: adapter_weights[idx].weights_a
-            for idx in segment_indices
-            if idx in adapter_weights
+            idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights
         }
         lora_a_ptr = torch.tensor(
             [
                 (
                     adapter_weights[idx].weights_a.data_ptr()
-                    if idx in adapter_weights
+                    if idx in adapter_weights
                     else EMPTY_TENSOR.data_ptr()
-                ) for idx in segment_indices
+                )
+                for idx in segment_indices
             ],
             dtype=torch.int64,
             device=device,
         )
         lora_b = {
-            idx: adapter_weights[idx].weights_b
-            for idx in segment_indices
-            if idx in adapter_weights
+            idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights
         }
         lora_b_ptr = torch.tensor(
             [
                 (
-                    adapter_weights[idx].weights_b.data_ptr()
-                    if idx in adapter_weights
+                    adapter_weights[idx].weights_b.data_ptr()
+                    if idx in adapter_weights
                     else EMPTY_TENSOR.data_ptr()
-                ) for idx in segment_indices
+                )
+                for idx in segment_indices
             ],
             dtype=torch.int64,
             device=device,
@@ -250,11 +245,7 @@ def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMet
         rank_data = {}
         for rank, indices in rank_indices.items():
             lora_a_ptr_indices = lora_a_ptr[indices]
-            tmp_shrink, tmp_expand = get_tmp_tensors(
-                lora_a_ptr_indices.size(0),
-                rank,
-                device
-            )
+            tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device)

             rank_data[rank] = RankSegments(
                 rank=rank,
@@ -263,11 +254,11 @@ def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMet
                 lora_a_ptr=lora_a_ptr_indices,
                 lora_b_ptr=lora_b_ptr[indices],
                 segment_starts=meta.adapter_segments[indices],
-                segment_ends=meta.adapter_segments[[i+1 for i in indices]],
+                segment_ends=meta.adapter_segments[[i + 1 for i in indices]],
             )

         return BatchLoraWeights(
-            lora_a=lora_a,
+            lora_a=lora_a,
             lora_b=lora_b,
             adapter_index_configs=adapter_index_configs,
             rank_data=rank_data,
@@ -281,5 +272,5 @@ def get_scaling_factor(
 ) -> float:
     """Computes the scaling factor for the lora weights."""
     if uses_rslora:
-        return lora_alpha / (r ** 0.5)
+        return lora_alpha / (r**0.5)
     return lora_alpha / r
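The only touch to get_scaling_factor above is cosmetic (spacing around **); the computed values are unchanged. For reference, a worked example of the two scaling conventions it implements, with illustrative numbers rather than values taken from any real adapter:

# Worked example of get_scaling_factor() above. lora_alpha and r are illustrative;
# in practice they would come from the adapter's LoRA configuration.
lora_alpha, r = 16, 64

standard_lora = lora_alpha / r     # classic LoRA scaling: 16 / 64 = 0.25
rslora = lora_alpha / (r**0.5)     # rsLoRA scaling, alpha / sqrt(r): 16 / 8.0 = 2.0
print(standard_lora, rslora)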
server/lorax_server/adapters/medusa.py (38 changes: 15 additions & 23 deletions)

@@ -19,15 +19,17 @@ class MedusaConfig(AdapterConfig):
     medusa_num_layers: int

     def map_weights_for_model(
-        self, adapter_weights: Dict, weight_names: Tuple[str],
+        self,
+        adapter_weights: Dict,
+        weight_names: Tuple[str],
     ) -> Tuple[ModuleMap, Set[str]]:
         # TODO(travis): this isn't technically the ModuleMap structure, make this more generic
         return adapter_weights, set(weight_names)

     def load_batched_adapter_weights(
-        self,
+        self,
         model: "Model",
-        module_map: Dict[str, Dict],
+        module_map: Dict[str, Dict],
         layer_type: str,
         unused_weight_names: Set[str],
         dynamic: bool,
@@ -37,7 +39,7 @@ def load_batched_adapter_weights(
                 "Dynamic adapter loading is not supported for Medusa at this time. "
                 "Instead, initialize the LoRAX server with the Medusa adapter and it will be applied to every request."
             )

         return MedusaWeights.load(
             self,
             model,
@@ -58,9 +60,7 @@ def load(cls, config: dict) -> "MedusaConfig":
 class ResBlock(torch.nn.Module):
     def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights):
         super().__init__()
-        self.linear = FastLinear.load(
-            config, prefix=f"{prefix}.linear", weights=weights, bias=True
-        )
+        self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
         self.act = torch.nn.SiLU()

     def forward(self, x):
@@ -77,9 +77,7 @@ def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights):
             ]
         )
         n = len(self.blocks)
-        self.out = FastLinear.load(
-            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
-        )
+        self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)

     def forward(self, x):
         for block in self.blocks:
@@ -107,21 +105,21 @@ class MedusaWeights(AdapterWeights):
     def __init__(self, config: MedusaConfig, module_map: ModuleMap, model: "Model"):
         self.config = config
         self.model = MedusaModel(config, InMemoryWeights(module_map, model.device, model.dtype))

     @classmethod
     def get_batch_type(cls) -> BatchAdapterWeights:
         return BatchMedusaWeights

     @property
     def speculative_tokens(self) -> int:
         return self.config.medusa_num_heads

     @classmethod
     def load(
-        cls,
+        cls,
         config: MedusaConfig,
         model: "Model",
-        module_map: Dict[str, Dict],
+        module_map: Dict[str, Dict],
         layer_type: str,
         unused_weight_names: Set[str],
     ) -> Optional[AdapterWeights]:
@@ -146,16 +144,10 @@ def key(cls) -> str:
     def load(
         cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata"
     ) -> "BatchMedusaWeights":
-        adapter_weights = {
-            k: v
-            for k, v in adapter_weights.items()
-            if isinstance(v, MedusaWeights)
-        }
+        adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, MedusaWeights)}

         adapter_to_medusa = {
-            idx: adapter_weights[idx]
-            for idx in meta.segment_indices
-            if idx in adapter_weights
+            idx: adapter_weights[idx] for idx in meta.segment_indices if idx in adapter_weights
         }

         return BatchMedusaWeights(
server/lorax_server/adapters/utils.py (12 changes: 6 additions & 6 deletions)

@@ -14,18 +14,18 @@ def download_adapter(
     if adapter_source == PBASE:
         adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token)
         adapter_source = S3

     if adapter_source == HUB:
         # Quick auth check on the repo against the token
         HfApi(token=api_token).model_info(adapter_id, revision=None)

     # fail fast if ID is not an adapter (i.e. it is a full model)
-    source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
+    source = get_model_source(
+        adapter_source, adapter_id, extension=".safetensors", api_token=api_token
+    )
     source.load_config()

-    download_weights(
-        adapter_id, source=adapter_source, api_token=api_token
-    )
+    download_weights(adapter_id, source=adapter_source, api_token=api_token)

     # Calculate size of adapter to be loaded
     return source.get_weight_bytes()