lora: add scaling factor support for LoRA at runtime
AlpinDale committed Dec 7, 2024
1 parent ef99a56 commit e490188
Showing 3 changed files with 23 additions and 3 deletions.
5 changes: 3 additions & 2 deletions aphrodite/lora/lora.py
@@ -18,6 +18,7 @@ def __init__(
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor] = None,
         scaling: Optional[float] = None,
+        runtime_scaling: Optional[float] = 1.0,
     ) -> None:
         self.module_name = module_name
         self.rank = rank
@@ -27,9 +28,9 @@ def __init__(
         self.embeddings_tensor = embeddings_tensor
 
         if scaling is None:
-            self.scaling = self.lora_alpha / self.rank
+            self.scaling = (self.lora_alpha / self.rank) * runtime_scaling
         else:
-            self.scaling = scaling
+            self.scaling = scaling * runtime_scaling
 
     def optimize(self) -> "LoRALayerWeights":
         """Optimize the LoRA by merging the scaling into lora_b."""
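
With this change the stored scale becomes (lora_alpha / rank) * runtime_scaling, or scaling * runtime_scaling when an explicit scaling is supplied, so the default runtime_scaling of 1.0 leaves existing adapters unchanged. A minimal sketch of how that factor enters the usual LoRA update, assuming plain torch tensors; the function and shapes below are illustrative, not part of this commit:

import torch

def lora_delta(x: torch.Tensor,
               lora_a: torch.Tensor,      # (in_features, rank)
               lora_b: torch.Tensor,      # (rank, out_features)
               lora_alpha: int,
               rank: int,
               runtime_scaling: float = 1.0) -> torch.Tensor:
    # Same value the patched __init__ stores in self.scaling.
    scaling = (lora_alpha / rank) * runtime_scaling
    # LoRA contribution that gets added to the frozen base layer's output.
    return (x @ lora_a @ lora_b) * scaling

x = torch.randn(2, 16)
delta = lora_delta(x, torch.randn(16, 4), torch.randn(4, 32),
                   lora_alpha=8, rank=4, runtime_scaling=0.5)
# runtime_scaling=0.0 silences the adapter, 1.0 keeps the trained strength,
# and intermediate values blend it in partially.
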
16 changes: 15 additions & 1 deletion aphrodite/lora/models.py
@@ -60,6 +60,7 @@ def __init__(
         rank: int,
         loras: Dict[str, LoRALayerWeights],
         scaling_factor: Optional[float] = None,
+        runtime_scaling: Optional[float] = 1.0,
     ) -> None:
         """
         Args:
@@ -73,10 +74,23 @@ def __init__(
         # Scaling factor for long context lora model. None if it is not
         # fine tuned for the long context.
         self.scaling_factor = scaling_factor
+        self.runtime_scaling = runtime_scaling
         assert (lora_model_id >
                 0), f"a valid lora id should be greater than 0, got {self.id}"
         self.rank = rank
-        self.loras: Dict[str, LoRALayerWeights] = loras
+        self.loras = {
+            name: LoRALayerWeights(
+                module_name=weight.module_name,
+                rank=weight.rank,
+                lora_alpha=weight.lora_alpha,
+                lora_a=weight.lora_a,
+                lora_b=weight.lora_b,
+                embeddings_tensor=weight.embeddings_tensor,
+                scaling=weight.scaling,
+                runtime_scaling=runtime_scaling
+            )
+            for name, weight in loras.items()
+        }
 
     def clone(self, lora_model_id: int) -> "LoRAModel":
         """Return a copy of the object with different ids.
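
LoRAModel now re-wraps each incoming LoRALayerWeights, so the runtime factor is folded into every layer's scaling once, at construction time. A rough usage sketch under the assumption that the adapter weights are already loaded; the module name, tensor shapes, and values are placeholders, not part of this commit:

import torch
from aphrodite.lora.lora import LoRALayerWeights
from aphrodite.lora.models import LoRAModel

layer = LoRALayerWeights(
    module_name="model.layers.0.self_attn.qkv_proj",  # placeholder name
    rank=8,
    lora_alpha=16,
    lora_a=torch.randn(4096, 8),   # illustrative shapes
    lora_b=torch.randn(8, 4096),
)
# With no explicit scaling, layer.scaling == lora_alpha / rank == 2.0.

model = LoRAModel(
    lora_model_id=1,
    rank=8,
    loras={"model.layers.0.self_attn.qkv_proj": layer},
    runtime_scaling=0.5,
)
# The dict comprehension above rebuilds the layer with scaling=weight.scaling
# and runtime_scaling=0.5, so the stored scale becomes 2.0 * 0.5 == 1.0.
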
5 changes: 5 additions & 0 deletions aphrodite/lora/request.py
@@ -26,6 +26,7 @@ class LoRARequest(
     lora_name: str
     lora_int_id: int
     lora_path: str = ""
+    scaling_factor: Optional[float] = 1.0
     lora_local_path: Optional[str] = msgspec.field(default=None)
     long_lora_max_len: Optional[int] = None
     __hash__ = AdapterRequest.__hash__
@@ -44,6 +45,10 @@ def __post_init__(self):
         # Ensure lora_path is not empty
         assert self.lora_path, "lora_path cannot be empty"
 
+        # Scaling factor must be non-negative
+        assert self.scaling_factor is None or self.scaling_factor >= 0, \
+            "scaling_factor must be non-negative"
+
     @property
     def adapter_id(self):
         return self.lora_int_id
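
On the request side, scaling_factor is the per-request knob: it defaults to 1.0 (the trained strength) and must be non-negative, with 0 effectively disabling the adapter for that request. How the value is plumbed from the request into LoRAModel's runtime_scaling is not shown in this diff. A small usage sketch with placeholder names and paths:

from aphrodite.lora.request import LoRARequest

# Apply the adapter at half strength for this request.
req = LoRARequest(
    lora_name="my-adapter",          # placeholder
    lora_int_id=1,
    lora_path="/path/to/adapter",    # placeholder
    scaling_factor=0.5,
)

# Negative values fail the __post_init__ check added above:
# LoRARequest(lora_name="bad", lora_int_id=2, lora_path="/path/to/adapter",
#             scaling_factor=-1.0)   # -> AssertionError
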
