diff --git a/aphrodite/lora/lora.py b/aphrodite/lora/lora.py
index 1ba7082cc..325ae451b 100644
--- a/aphrodite/lora/lora.py
+++ b/aphrodite/lora/lora.py
@@ -18,6 +18,7 @@ def __init__(
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor] = None,
         scaling: Optional[float] = None,
+        runtime_scaling: Optional[float] = 1.0,
     ) -> None:
         self.module_name = module_name
         self.rank = rank
@@ -27,9 +28,9 @@ def __init__(
         self.embeddings_tensor = embeddings_tensor
 
         if scaling is None:
-            self.scaling = self.lora_alpha / self.rank
+            self.scaling = (self.lora_alpha / self.rank) * runtime_scaling
         else:
-            self.scaling = scaling
+            self.scaling = scaling * runtime_scaling
 
     def optimize(self) -> "LoRALayerWeights":
         """Optimize the LoRA by merging the scaling into lora_b."""
diff --git a/aphrodite/lora/models.py b/aphrodite/lora/models.py
index 8d94986bc..e8208d193 100644
--- a/aphrodite/lora/models.py
+++ b/aphrodite/lora/models.py
@@ -60,6 +60,7 @@ def __init__(
         rank: int,
         loras: Dict[str, LoRALayerWeights],
         scaling_factor: Optional[float] = None,
+        runtime_scaling: Optional[float] = 1.0,
     ) -> None:
         """
         Args:
@@ -73,10 +74,23 @@ def __init__(
         # Scaling factor for long context lora model. None if it is not
         # fine tuned for the long context.
         self.scaling_factor = scaling_factor
+        self.runtime_scaling = runtime_scaling
         assert (lora_model_id >
                 0), f"a valid lora id should be greater than 0, got {self.id}"
         self.rank = rank
-        self.loras: Dict[str, LoRALayerWeights] = loras
+        self.loras = {
+            name: LoRALayerWeights(
+                module_name=weight.module_name,
+                rank=weight.rank,
+                lora_alpha=weight.lora_alpha,
+                lora_a=weight.lora_a,
+                lora_b=weight.lora_b,
+                embeddings_tensor=weight.embeddings_tensor,
+                scaling=weight.scaling,
+                runtime_scaling=runtime_scaling
+            )
+            for name, weight in loras.items()
+        }
 
     def clone(self, lora_model_id: int) -> "LoRAModel":
         """Return a copy of the object with different ids.
diff --git a/aphrodite/lora/request.py b/aphrodite/lora/request.py
index 964b1b675..1fb8ceca2 100644
--- a/aphrodite/lora/request.py
+++ b/aphrodite/lora/request.py
@@ -26,6 +26,7 @@ class LoRARequest(
     lora_name: str
     lora_int_id: int
    lora_path: str = ""
+    scaling_factor: Optional[float] = 1.0
     lora_local_path: Optional[str] = msgspec.field(default=None)
     long_lora_max_len: Optional[int] = None
     __hash__ = AdapterRequest.__hash__
@@ -44,6 +45,10 @@ def __post_init__(self):
         # Ensure lora_path is not empty
         assert self.lora_path, "lora_path cannot be empty"
 
+        # Scaling factor must be non-negative
+        assert self.scaling_factor is None or self.scaling_factor >= 0, \
+            "scaling_factor must be non-negative"
+
     @property
     def adapter_id(self):
         return self.lora_int_id
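
For context, a minimal usage sketch of the new parameter (not part of the diff). It assumes the engine forwards `LoRARequest.scaling_factor` into `LoRAModel`'s new `runtime_scaling` argument when the adapter is loaded; that wiring is not shown in the hunks above. Adapter names, IDs, and paths below are placeholders.

```python
from aphrodite.lora.request import LoRARequest

# Default: scaling_factor=1.0 leaves the effective per-layer scaling
# at lora_alpha / rank, exactly as before this change.
full_strength = LoRARequest(
    lora_name="my-adapter",        # placeholder name
    lora_int_id=1,
    lora_path="/path/to/adapter",  # placeholder path
)

# Same adapter at half strength: once runtime_scaling reaches
# LoRALayerWeights, each layer's scaling becomes (lora_alpha / rank) * 0.5.
half_strength = LoRARequest(
    lora_name="my-adapter-half",
    lora_int_id=2,
    lora_path="/path/to/adapter",
    scaling_factor=0.5,
)

# scaling_factor=0.0 zeroes the adapter's contribution entirely;
# negative values are rejected by the assert in __post_init__.
```

Because `runtime_scaling` multiplies the existing scaling rather than replacing it, adapters trained with different `lora_alpha`/`rank` combinations keep their relative strengths and the factor acts as a uniform knob on top.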