Skip to content

Commit

Permalink
[Misc] Improve error message when LoRA parsing fails (vllm-project#5194)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored and jimpang committed Jul 24, 2024
1 parent f41b7a4 commit 9c4eed0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
14 changes: 13 additions & 1 deletion tests/lora/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from collections import OrderedDict

import pytest
from torch import nn

from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
from vllm.utils import LRUCache


def test_parse_fine_tuned_lora_name():
def test_parse_fine_tuned_lora_name_valid():
fixture = {
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
Expand Down Expand Up @@ -35,6 +36,17 @@ def test_parse_fine_tuned_lora_name():
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)


def test_parse_fine_tuned_lora_name_invalid():
fixture = {
"weight",
"base_model.weight",
"base_model.model.weight",
}
for name in fixture:
with pytest.raises(ValueError, match="unsupported LoRA weight"):
parse_fine_tuned_lora_name(name)


def test_replace_submodule():
model = nn.Sequential(
OrderedDict([
Expand Down
15 changes: 7 additions & 8 deletions vllm/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,12 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
is_lora_a whether the tensor is lora_a or lora_b.
"""
parts = name.split(".")
assert parts[0] == "base_model"
assert parts[1] == "model"
if parts[-1] == "weight":
assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
return ".".join(parts[2:-2]), parts[-2] == "lora_A"

if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
if parts[-1] == "weight":
if parts[-2] == "lora_A" or parts[-2] == "lora_B":
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"

raise ValueError(f"{name} is unsupported format")
raise ValueError(f"{name} is unsupported LoRA weight")

0 comments on commit 9c4eed0

Please sign in to comment.