diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py
index eea5301e7..54d571b29 100644
--- a/server/lorax_server/adapters/lora.py
+++ b/server/lorax_server/adapters/lora.py
@@ -255,8 +255,11 @@ def load(
         lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
         lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}
 
-        max_rank = max(adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights)
+        segment_ranks = [adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights]
+        if not segment_ranks:
+            return None
+        max_rank = max(segment_ranks)
 
         if prefill or max_rank > BGMV_MAX_RANK:
             use_sgmv = True
             lora_a_ptr = torch.tensor(
diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py
index aa1a98366..fcbe6507f 100644
--- a/server/tests/utils/test_lora.py
+++ b/server/tests/utils/test_lora.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Dict, List, Optional, Type
 from unittest import mock
 
 import pytest
@@ -7,10 +7,40 @@
 
 from lorax_server.adapters.lora import LoraWeights
 from lorax_server.adapters.types import LORA
-from lorax_server.adapters.weights import AdapterBatchMetadata, LayerAdapterWeights
+from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights
 from lorax_server.utils.sgmv import MIN_RANK_CUSTOM
 
 
+class FakeAdapterWeights(AdapterWeights):
+    @classmethod
+    def get_batch_types(cls) -> List[Type["FakeBatchAdapterWeights"]]:
+        return [FakeBatchAdapterWeights]
+
+    @property
+    def speculative_tokens(self) -> int:
+        return 0
+
+
+class FakeBatchAdapterWeights(BatchAdapterWeights):
+    @classmethod
+    def has_adapter(self, adapter_index: int) -> bool:
+        return False
+
+    @classmethod
+    def key(cls) -> str:
+        return "fake"
+
+    @classmethod
+    def load(
+        cls,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: "AdapterBatchMetadata",
+        prefill: bool,
+        prefill_head_indices: torch.Tensor,
+    ) -> Optional["BatchAdapterWeights"]:
+        return None
+
+
 @pytest.mark.parametrize(
     "lora_ranks",
     [
@@ -71,4 +101,39 @@ def test_batched_lora_weights(lora_ranks: List[int]):
         assert rd.segment_starts.shape == (2,)
         assert rd.segment_ends.shape == (2,)
 
+
+def test_batched_lora_weights_no_segments():
+    batched_weights = LayerAdapterWeights()
+    assert batched_weights.is_empty()
+
+    h = 1024
+
+    # fake weights
+    idx = 0
+    weights = FakeAdapterWeights()
+    batched_weights.add_adapter(idx, weights)
+
+    # lora weights
+    idx = 1
+    lora_rank = 16
+    weights = LoraWeights(
+        weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)],
+        weights_b=[torch.randn((lora_rank, h), dtype=torch.float16)],
+        adapter_config=LoraConfig(r=lora_rank),
+    )
+    batched_weights.add_adapter(idx, weights)
+
+    assert not batched_weights.is_empty()
+    assert len(batched_weights.adapter_weights) == 2
+
+    meta = AdapterBatchMetadata(
+        adapter_indices=torch.tensor([0, 0, 0, 0], dtype=torch.int64),
+        adapter_set={0, 1},
+        adapter_segments=torch.tensor([0, 4], dtype=torch.int64),
+        segment_indices=[0],
+    )
+
+    with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
+        data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA)
+    print(data)
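
Note on the lora.py change: max() over an empty sequence raises ValueError, which is what the old line did whenever none of the batch's segment indices mapped to a LoRA adapter (exactly the situation the new test constructs, where the only segment belongs to a fake non-LoRA adapter). The guard returns None instead so the LoRA entry is simply absent from the batch data. A minimal standalone sketch of the same guard pattern, using a hypothetical max_segment_rank helper and an adapter_ranks dict standing in for the real adapter-weight objects:

from typing import Dict, List, Optional


def max_segment_rank(adapter_ranks: Dict[int, int], segment_indices: List[int]) -> Optional[int]:
    # adapter_ranks: hypothetical map of adapter index -> LoRA rank.
    segment_ranks = [adapter_ranks[idx] for idx in segment_indices if idx in adapter_ranks]
    if not segment_ranks:
        # No LoRA adapters referenced by this batch: bail out with None
        # instead of letting max() raise ValueError on an empty sequence.
        return None
    return max(segment_ranks)


assert max_segment_rank({1: 16}, [0]) is None          # no LoRA segments -> None
assert max_segment_rank({0: 8, 1: 16}, [0, 1]) == 16   # otherwise the largest rank present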