Skip to content

Commit

Permalink
[Misc][LoRA] Add PEFTHelper for LoRA (vllm-project#11003)
Browse files Browse the repository at this point in the history
Signed-off-by: Jee Jee Li <[email protected]>
  • Loading branch information
jeejeelee authored Dec 10, 2024
1 parent beb16b2 commit d05f886
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 28 deletions.
58 changes: 55 additions & 3 deletions tests/lora/test_lora_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from typing import Dict, List

Expand All @@ -13,6 +14,7 @@
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
Expand All @@ -30,18 +32,68 @@
]


def test_peft_helper(sql_lora_files):
lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
with open(lora_config_path) as f:
config = json.load(f)
peft_helper = PEFTHelper.from_dict(config)
assert peft_helper.r == 8
assert peft_helper.lora_alpha == 16
assert peft_helper.target_modules == [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]

expected_error = "vLLM only supports modules_to_save being None."
with pytest.raises(ValueError, match=expected_error):
config = dict(
r=8,
lora_alpha=16,
target_modules=["gate_proj"],
modules_to_save=["lm_head"],
)
PEFTHelper.from_dict(config)
expected_error = "vLLM does not yet support RSLoRA."
with pytest.raises(ValueError, match=expected_error):
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_rslora=True)
PEFTHelper.from_dict(config)

expected_error = "vLLM does not yet support DoRA."
with pytest.raises(ValueError, match=expected_error):
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_dora=True)
PEFTHelper.from_dict(config)


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
os.path.join(sql_lora_files, "new_embeddings.safetensors"))

lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
with open(lora_config_path) as f:
config = json.load(f)

peft_helper = PEFTHelper.from_dict(config)
lora_model = LoRAModel.from_lora_tensors(
1,
8,
16,
tensors,
device,
peft_helper=peft_helper,
device=device,
embeddings=new_embeddings,
embedding_modules=EMBEDDING_MODULES,
embedding_padding_modules=EMBEDDING_PADDING_MODULES)
Expand Down
18 changes: 18 additions & 0 deletions vllm/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.types

from vllm.lora.peft_helper import PEFTHelper
from vllm.utils import is_pin_memory_available


Expand Down Expand Up @@ -59,6 +60,23 @@ def extra_vocab_size(self) -> int:
return self.embeddings_tensor.shape[
0] if self.embeddings_tensor is not None else 0

@classmethod
def from_config(
cls,
module_name: str,
peft_helper: PEFTHelper,
embeddings_tensor: Optional[torch.Tensor] = None,
) -> "LoRALayerWeights":
return cls(
module_name,
peft_helper.r,
peft_helper.lora_alpha,
None,
None,
None,
embeddings_tensor,
)

@classmethod
def create_dummy_lora_weights(
cls,
Expand Down
42 changes: 17 additions & 25 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
LinearScalingRotaryEmbeddingWithLora,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules,
Expand Down Expand Up @@ -104,14 +105,12 @@ def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
def from_lora_tensors(
cls,
lora_model_id: int,
rank: int,
lora_alpha: int,
tensors: Dict[str, torch.Tensor],
peft_helper: PEFTHelper,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
scaling_factor: Optional[float] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
Expand All @@ -135,10 +134,9 @@ def from_lora_tensors(
if pin_memory:
lora_embeddings_tensor = (
lora_embeddings_tensor.pin_memory())
loras[module_name] = LoRALayerWeights(module_name, rank,
lora_alpha, None, None,
None,
lora_embeddings_tensor)
loras[module_name] = LoRALayerWeights.from_config(
module_name, peft_helper, lora_embeddings_tensor)

if is_bias:
loras[module_name].bias = tensor.to(device=device,
dtype=dtype).t()
Expand Down Expand Up @@ -170,7 +168,11 @@ def from_lora_tensors(

for lora in loras.values():
lora.optimize()
return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)

return cls(lora_model_id,
peft_helper.r,
loras,
scaling_factor=peft_helper.vllm_scaling_factor)

@classmethod
def from_local_checkpoint(
Expand Down Expand Up @@ -212,6 +214,9 @@ def from_local_checkpoint(
"new_embeddings.bin")
with open(lora_config_path) as f:
config = json.load(f)

config["vllm_max_position_embeddings"] = max_position_embeddings
peft_helper = PEFTHelper.from_dict(config)
if os.path.isfile(lora_tensor_path):
tensors: Dict[str, torch.Tensor] = {}
# Find unexpected modules.
Expand Down Expand Up @@ -242,7 +247,7 @@ def from_local_checkpoint(
# When a bin file is provided, we rely on config to find unexpected
# modules.
unexpected_modules = []
target_modules = config["target_modules"]
target_modules = peft_helper.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
Expand All @@ -256,7 +261,7 @@ def from_local_checkpoint(
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if unexpected_modules and not is_regex_target_modules(
config["target_modules"], expected_lora_modules):
peft_helper.target_modules, expected_lora_modules):
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
Expand All @@ -274,30 +279,17 @@ def from_local_checkpoint(
embeddings = torch.load(new_embeddings_bin_file_path,
map_location=device)

rank = config["r"]
lora_alpha = config["lora_alpha"]
context_length = config.get("context_length", None)
scaling_factor = None
if context_length:
if max_position_embeddings is None:
max_position_embeddings = context_length
scaling_factor = float(
math.ceil(context_length / max_position_embeddings))

return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
rank=rank,
lora_alpha=lora_alpha,
tensors=tensors,
peft_helper=peft_helper,
device=device,
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
scaling_factor=scaling_factor,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
)
embedding_padding_modules=embedding_padding_modules)


class LoRAModelManager(AdapterModelManager):
Expand Down
70 changes: 70 additions & 0 deletions vllm/lora/peft_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py

import math
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal, Optional, Union


@dataclass
class PEFTHelper:
# Required fields
r: int
lora_alpha: int
target_modules: Union[list[str], str]

bias: Literal["none", "all", "lora_only"] = field(default="none")
modules_to_save: Optional[list[str]] = field(default=None)
use_rslora: bool = field(default=False)
use_dora: bool = field(default=False)
# long lora field
context_length: int = field(default=0)
# Extra vllm field, start with 'vllm_' to avoid conflict
vllm_max_position_embeddings: Optional[int] = field(default=False)
vllm_scaling_factor: Optional[float] = field(default=None)

def _validate_features(self):
error_msg = []

if self.modules_to_save:
error_msg.append("vLLM only supports modules_to_save being None.")
if self.use_rslora:
error_msg.append("vLLM does not yet support RSLoRA.")

if self.use_dora:
error_msg.append("vLLM does not yet support DoRA.")

if error_msg:
raise ValueError(f"{', '.join(error_msg)}")

def __post_init__(self):
self._validate_features()
if self.context_length:
if self.vllm_max_position_embeddings is None:
self.vllm_max_position_embeddings = self.context_length
self.vllm_scaling_factor = float(
math.ceil(self.context_length /
self.vllm_max_position_embeddings))

@classmethod
def from_dict(cls, config_dict: dict) -> "PEFTHelper":
# Get all field information from the class
class_fields = {f.name: f for f in fields(cls)}
# Check for required fields
required_fields = {
name
for name, f in class_fields.items()
if f.default is MISSING and f.default_factory is MISSING
}

# Identify any missing required fields
missing_fields = required_fields - set(config_dict.keys())
if missing_fields:
raise ValueError(
f"Missing required configuration fields: {missing_fields}")

# Filter out fields that aren't defined in the class
filtered_dict = {
k: v
for k, v in config_dict.items() if k in class_fields
}
return cls(**filtered_dict)

0 comments on commit d05f886

Please sign in to comment.