diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py
index 82423f4b4..1c611afcf 100644
--- a/server/lorax_server/models/model.py
+++ b/server/lorax_server/models/model.py
@@ -261,9 +261,9 @@ def load_batched_adapter_weights(
         lora_a_list = [pad_rank(w, dim=1, world_size=self.world_size) for w in lora_a_list]
         lora_b_list = [pad_rank(w, dim=0, world_size=self.world_size) for w in lora_b_list]
 
-        if lora_b_list:
+        if lora_a_list:
             # update rank if it was padded
-            padded_rank = lora_b_list[0].size(0)
+            padded_rank = lora_a_list[0].size(1)
             adapter_config.r = padded_rank
 
         q_lora_merged = MergedLoraWeights(
diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index 0c48c7d33..3bf042c42 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -536,12 +536,12 @@ def forward_layer_type(
                     rank_segments.segment_starts,
                     rank_segments.segment_ends,
                     self.layer_id,
-                    r // self.process_group.size(),
+                    r,
                 )
 
                 if self.process_group.size() > 1:
                     v = self.collect_lora_a(v)
-
+
                 lora_b_sgmv_cutlass(
                     proj,
                     v,
@@ -571,13 +571,14 @@ def forward_lora(
         adapter_mask: torch.Tensor,
     ) -> torch.Tensor:
         lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
-        lora_a = orient_for_rank(lora_a, data.adapter_index_configs[adapter_index].r)
-        a_out = input @ lora_a
+        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
+        lora_a = orient_for_rank(lora_a, lora_b.size(0))
+
+        a_out = input @ lora_a
 
         if self.process_group.size() > 1:
             a_out = self.collect_lora_a(a_out)
 
-        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
         result = (a_out @ lora_b) * adapter_mask
         return result
 
diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py
index 533759bf3..48cb2b30e 100644
--- a/server/lorax_server/utils/lora.py
+++ b/server/lorax_server/utils/lora.py
@@ -54,9 +54,18 @@ def can_vectorize(self, pg: ProcessGroup) -> bool:
 
 @dataclass
 class AdapterBatchMetadata:
+    # [batch_size]
     adapter_indices: torch.Tensor
+
+    # [num_adapters]
     adapter_set: Set[int]
+
+    # [num_segments + 1]
     adapter_segments: torch.Tensor
+
+    # [num_segments]
+    # maps from segment index to adapter index, i.e.:
+    # segment_indices[s] == adapter_indices[i]
     segment_indices: List[int]
 
 
@@ -96,9 +105,12 @@ def __init__(
         weights_b: List[torch.Tensor],
         adapter_config: LoraConfig,
     ):
+        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
+        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
+
         # [num_layers, hidden_size, r]
         weights_a = [
-            orient_for_rank(w, adapter_config.r).contiguous()
+            orient_for_rank(w, w.size(1)).contiguous()
             for w in weights_a
         ]
         self.weights_a = torch.stack(weights_a)
@@ -184,7 +196,7 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData:
         for segment_idx, adapter_idx in enumerate(segment_indices):
             if adapter_idx not in self.lora_weights:
                 continue
-            rank_indices[self.lora_weights[adapter_idx].weights_a.size(2)].append(segment_idx)
+            rank_indices[self.lora_weights[adapter_idx].lora_a_r].append(segment_idx)
 
         rank_data = {}
         for rank, indices in rank_indices.items():
diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py
new file mode 100644
index 000000000..fc5a595d1
--- /dev/null
+++ b/server/tests/utils/test_lora.py
@@ -0,0 +1,69 @@
+from typing import List
+from unittest import mock
+import pytest
+
+import torch
+from peft import LoraConfig
+
+from lorax_server.utils.lora import AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
+from lorax_server.utils.sgmv import MIN_RANK_CUSTOM
+
+
+@pytest.mark.parametrize("lora_ranks", [
+    [8, 16],
+    [32, 64],
+])
+def test_batched_lora_weights(lora_ranks: List[int]):
+    # batch meta is hardcoded with this assumption below
+    assert len(lora_ranks) == 2
+
+    batched_weights = BatchedLoraWeights()
+    assert batched_weights.is_empty()
+
+    h = 1024
+    for idx, lora_rank in enumerate(lora_ranks):
+        weights = MergedLoraWeights(
+            weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)],
+            weights_b=[torch.randn((lora_rank, h), dtype=torch.float16)],
+            adapter_config=LoraConfig(r=lora_rank),
+        )
+        assert weights.lora_a_r == lora_rank
+        assert weights.lora_b_r == lora_rank
+
+        batched_weights.add_adapter(idx, weights)
+
+    assert not batched_weights.is_empty()
+    assert len(batched_weights.lora_weights) == 2
+
+    meta = AdapterBatchMetadata(
+        adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64),
+        adapter_set={0, 1},
+        adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64),
+        segment_indices=[0, 1, 0, 1],
+    )
+
+    with mock.patch("lorax_server.utils.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
+        data = batched_weights.get_data(meta)
+
+    assert len(data.lora_a) == 2
+    assert data.lora_a.keys() == meta.adapter_set
+    assert data.lora_a[0].shape == ((1, h, lora_ranks[0]) if lora_ranks[0] < MIN_RANK_CUSTOM else (1, lora_ranks[0], h))
+    assert data.lora_a[1].shape == ((1, h, lora_ranks[1]) if lora_ranks[1] < MIN_RANK_CUSTOM else (1, lora_ranks[1], h))
+
+    assert len(data.lora_b) == 2
+    assert data.lora_b.keys() == meta.adapter_set
+    assert data.lora_b[0].shape == (1, lora_ranks[0], h)
+    assert data.lora_b[1].shape == (1, lora_ranks[1], h)
+
+    assert len(data.rank_data) == 2
+    assert data.rank_data.keys() == set(lora_ranks)
+    for lora_rank, rd in data.rank_data.items():
+        assert rd.rank == lora_rank
+
+        # shape in all cases is the number of segments with this rank
+        assert rd.lora_a_ptr.shape == (2,)
+        assert rd.lora_b_ptr.shape == (2,)
+        assert rd.segment_starts.shape == (2,)
+        assert rd.segment_ends.shape == (2,)
+
+    print(data)
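
Note on the test's shape assertions: they encode the orientation convention that the rank-aware batching above relies on, namely that lora_a keeps its [h, r] layout for ranks below MIN_RANK_CUSTOM and is transposed to [r, h] otherwise. The snippet below is a minimal standalone sketch of that convention, not the actual lorax_server.utils.sgmv implementation; the threshold value is an assumption, and the real function may also check an upper bound on the rank.

    import torch

    # Assumed placeholder value; the real constant is defined alongside the
    # SGMV kernels in lorax_server.utils.sgmv.
    MIN_RANK_CUSTOM = 16

    def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
        # Ranks below the custom-kernel threshold keep lora_a in its original
        # [h, r] layout; larger ranks are transposed to [r, h].
        if rank < MIN_RANK_CUSTOM:
            return t
        return t.transpose(-2, -1)

    # Example: a rank-8 lora_a of shape (1024, 8) is returned unchanged, while a
    # rank-32 lora_a of shape (1024, 32) becomes (32, 1024), matching the shape
    # assertions in test_batched_lora_weights.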