-
Notifications
You must be signed in to change notification settings - Fork 150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix tensor parallelism with SGMV to use true rank of the LoRA after splitting #324
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from typing import List | ||
from unittest import mock | ||
import pytest | ||
|
||
import torch | ||
from peft import LoraConfig | ||
|
||
from lorax_server.utils.lora import AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights | ||
from lorax_server.utils.sgmv import MIN_RANK_CUSTOM | ||
|
||
|
||
@pytest.mark.parametrize("lora_ranks", [ | ||
[8, 16], | ||
[32, 64], | ||
]) | ||
def test_batched_lora_weights(lora_ranks: List[int]): | ||
# batch meta is hardcoded with this assumption below | ||
assert len(lora_ranks) == 2 | ||
|
||
batched_weights = BatchedLoraWeights() | ||
assert batched_weights.is_empty() | ||
|
||
h = 1024 | ||
for idx, lora_rank in enumerate(lora_ranks): | ||
weights = MergedLoraWeights( | ||
weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)], | ||
weights_b=[torch.randn((lora_rank, h), dtype=torch.float16)], | ||
adapter_config=LoraConfig(r=lora_rank), | ||
) | ||
assert weights.lora_a_r == lora_rank | ||
assert weights.lora_b_r == lora_rank | ||
|
||
batched_weights.add_adapter(idx, weights) | ||
|
||
assert not batched_weights.is_empty() | ||
assert len(batched_weights.lora_weights) == 2 | ||
|
||
meta = AdapterBatchMetadata( | ||
adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64), | ||
adapter_set={0, 1}, | ||
adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64), | ||
segment_indices=[0, 1, 0, 1], | ||
) | ||
|
||
with mock.patch("lorax_server.utils.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))): | ||
data = batched_weights.get_data(meta) | ||
|
||
assert len(data.lora_a) == 2 | ||
assert data.lora_a.keys() == meta.adapter_set | ||
assert data.lora_a[0].shape == ((1, h, lora_ranks[0]) if lora_ranks[0] < MIN_RANK_CUSTOM else (1, lora_ranks[0], h)) | ||
assert data.lora_a[1].shape == ((1, h, lora_ranks[1]) if lora_ranks[1] < MIN_RANK_CUSTOM else (1, lora_ranks[1], h)) | ||
|
||
assert len(data.lora_b) == 2 | ||
assert data.lora_b.keys() == meta.adapter_set | ||
assert data.lora_b[0].shape == (1, lora_ranks[0], h) | ||
assert data.lora_b[1].shape == (1, lora_ranks[1], h) | ||
|
||
assert len(data.rank_data) == 2 | ||
assert data.rank_data.keys() == set(lora_ranks) | ||
for lora_rank, rd in data.rank_data.items(): | ||
assert rd.rank == lora_rank | ||
|
||
# shape in all cases is the number of segments with this rank | ||
assert rd.lora_a_ptr.shape == (2,) | ||
assert rd.lora_b_ptr.shape == (2,) | ||
assert rd.segment_starts.shape == (2,) | ||
assert rd.segment_ends.shape == (2,) | ||
|
||
print(data) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious what the significance of swapping from
lora_b_list
tolora_a_list
here is.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is the same, actually.