Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tensor parallelism with SGMV to use true rank of the LoRA after splitting #324

Merged
merged 7 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions server/lorax_server/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ def load_batched_adapter_weights(
lora_a_list = [pad_rank(w, dim=1, world_size=self.world_size) for w in lora_a_list]
lora_b_list = [pad_rank(w, dim=0, world_size=self.world_size) for w in lora_b_list]

if lora_b_list:
if lora_a_list:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious what the significance of swapping from lora_b_list to lora_a_list here is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is the same, actually.

# update rank if it was padded
padded_rank = lora_b_list[0].size(0)
padded_rank = lora_a_list[0].size(1)
adapter_config.r = padded_rank

q_lora_merged = MergedLoraWeights(
Expand Down
11 changes: 6 additions & 5 deletions server/lorax_server/utils/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,12 +536,12 @@ def forward_layer_type(
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r // self.process_group.size(),
r,
)

if self.process_group.size() > 1:
v = self.collect_lora_a(v)

lora_b_sgmv_cutlass(
proj,
v,
Expand Down Expand Up @@ -571,13 +571,14 @@ def forward_lora(
adapter_mask: torch.Tensor,
) -> torch.Tensor:
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
lora_a = orient_for_rank(lora_a, data.adapter_index_configs[adapter_index].r)
a_out = input @ lora_a
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]

lora_a = orient_for_rank(lora_a, lora_b.size(0))

a_out = input @ lora_a
if self.process_group.size() > 1:
a_out = self.collect_lora_a(a_out)

lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
result = (a_out @ lora_b) * adapter_mask
return result

Expand Down
16 changes: 14 additions & 2 deletions server/lorax_server/utils/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,18 @@ def can_vectorize(self, pg: ProcessGroup) -> bool:

@dataclass
class AdapterBatchMetadata:
# [batch_size]
adapter_indices: torch.Tensor

# [num_adapters]
adapter_set: Set[int]

# [num_segments + 1]
adapter_segments: torch.Tensor

# [num_segments]
# maps from segment index to adapter index, i.e.:
# segment_indices[s] == adapter_indices[i]
segment_indices: List[int]


Expand Down Expand Up @@ -96,9 +105,12 @@ def __init__(
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1

# [num_layers, hidden_size, r]
weights_a = [
orient_for_rank(w, adapter_config.r).contiguous()
orient_for_rank(w, w.size(1)).contiguous()
for w in weights_a
]
self.weights_a = torch.stack(weights_a)
Expand Down Expand Up @@ -184,7 +196,7 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData:
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in self.lora_weights:
continue
rank_indices[self.lora_weights[adapter_idx].weights_a.size(2)].append(segment_idx)
rank_indices[self.lora_weights[adapter_idx].lora_a_r].append(segment_idx)

rank_data = {}
for rank, indices in rank_indices.items():
Expand Down
69 changes: 69 additions & 0 deletions server/tests/utils/test_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import List
from unittest import mock
import pytest

import torch
from peft import LoraConfig

from lorax_server.utils.lora import AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.sgmv import MIN_RANK_CUSTOM


@pytest.mark.parametrize("lora_ranks", [
[8, 16],
[32, 64],
])
def test_batched_lora_weights(lora_ranks: List[int]):
# batch meta is hardcoded with this assumption below
assert len(lora_ranks) == 2

batched_weights = BatchedLoraWeights()
assert batched_weights.is_empty()

h = 1024
for idx, lora_rank in enumerate(lora_ranks):
weights = MergedLoraWeights(
weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)],
weights_b=[torch.randn((lora_rank, h), dtype=torch.float16)],
adapter_config=LoraConfig(r=lora_rank),
)
assert weights.lora_a_r == lora_rank
assert weights.lora_b_r == lora_rank

batched_weights.add_adapter(idx, weights)

assert not batched_weights.is_empty()
assert len(batched_weights.lora_weights) == 2

meta = AdapterBatchMetadata(
adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64),
adapter_set={0, 1},
adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64),
segment_indices=[0, 1, 0, 1],
)

with mock.patch("lorax_server.utils.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
data = batched_weights.get_data(meta)

assert len(data.lora_a) == 2
assert data.lora_a.keys() == meta.adapter_set
assert data.lora_a[0].shape == ((1, h, lora_ranks[0]) if lora_ranks[0] < MIN_RANK_CUSTOM else (1, lora_ranks[0], h))
assert data.lora_a[1].shape == ((1, h, lora_ranks[1]) if lora_ranks[1] < MIN_RANK_CUSTOM else (1, lora_ranks[1], h))

assert len(data.lora_b) == 2
assert data.lora_b.keys() == meta.adapter_set
assert data.lora_b[0].shape == (1, lora_ranks[0], h)
assert data.lora_b[1].shape == (1, lora_ranks[1], h)

assert len(data.rank_data) == 2
assert data.rank_data.keys() == set(lora_ranks)
for lora_rank, rd in data.rank_data.items():
assert rd.rank == lora_rank

# shape in all cases is the number of segments with this rank
assert rd.lora_a_ptr.shape == (2,)
assert rd.lora_b_ptr.shape == (2,)
assert rd.segment_starts.shape == (2,)
assert rd.segment_ends.shape == (2,)

print(data)
Loading