diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py
index 892f6081e2aaa..4ff9715b4ca8d 100644
--- a/tests/lora/test_utils.py
+++ b/tests/lora/test_utils.py
@@ -1,12 +1,13 @@
 from collections import OrderedDict
 
+import pytest
 from torch import nn
 
 from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
 from vllm.utils import LRUCache
 
 
-def test_parse_fine_tuned_lora_name():
+def test_parse_fine_tuned_lora_name_valid():
     fixture = {
         ("base_model.model.lm_head.lora_A.weight", "lm_head", True),
         ("base_model.model.lm_head.lora_B.weight", "lm_head", False),
@@ -35,6 +36,17 @@ def test_parse_fine_tuned_lora_name():
         assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
 
 
+def test_parse_fine_tuned_lora_name_invalid():
+    fixture = {
+        "weight",
+        "base_model.weight",
+        "base_model.model.weight",
+    }
+    for name in fixture:
+        with pytest.raises(ValueError, match="unsupported LoRA weight"):
+            parse_fine_tuned_lora_name(name)
+
+
 def test_replace_submodule():
     model = nn.Sequential(
         OrderedDict([
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index b0198a50b1c52..4a86c16cf64db 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -94,13 +94,12 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
         is_lora_a whether the tensor is lora_a or lora_b.
     """
     parts = name.split(".")
-    assert parts[0] == "base_model"
-    assert parts[1] == "model"
-    if parts[-1] == "weight":
-        assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
-        return ".".join(parts[2:-2]), parts[-2] == "lora_A"
-    if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
-        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
+    if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
+        if parts[-1] == "weight":
+            if parts[-2] == "lora_A" or parts[-2] == "lora_B":
+                return ".".join(parts[2:-2]), parts[-2] == "lora_A"
+        elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
+            return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
 
-    raise ValueError(f"{name} is unsupported format")
+    raise ValueError(f"{name} is unsupported LoRA weight")
 
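
A minimal sketch of the behavior change for reference (the lm_head name is taken from the test fixture above; the embed_tokens name is illustrative, not from this patch):

from vllm.lora.utils import parse_fine_tuned_lora_name

# Well-formed names still parse: the module path is everything between
# "base_model.model." and the lora_A/lora_B suffix, and the returned
# bool flags whether the tensor is lora_A.
assert parse_fine_tuned_lora_name(
    "base_model.model.lm_head.lora_A.weight") == ("lm_head", True)

# Embedding weights carry the A/B tag in the final name component.
# ("embed_tokens" here is an illustrative module name.)
assert parse_fine_tuned_lora_name(
    "base_model.model.embed_tokens.lora_embedding_B") == ("embed_tokens", False)

# Malformed names now raise ValueError with a matchable message, where
# the old code tripped a bare AssertionError on the removed asserts.
try:
    parse_fine_tuned_lora_name("base_model.weight")
except ValueError as err:
    print(err)  # base_model.weight is unsupported LoRA weight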